1 /*
2  * Copyright (C) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "mdns_protocol_impl.h"
17 
#include <arpa/inet.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <poll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <algorithm>
#include <cerrno>
#include <cstddef>
#include <iostream>
#include <random>

#include "mdns_manager.h"
#include "mdns_packet_parser.h"
#include "net_conn_client.h"
#include "netmgr_ext_log_wrapper.h"

#include "securec.h"
32 
33 namespace OHOS {
34 namespace NetManagerStandard {
35 
// Minimum spacing, in milliseconds, between two rounds of the periodic Browse() task.
constexpr uint32_t DEFAULT_INTEVAL_MS{2000};
// A LIVE browse result not refreshed for this many milliseconds is probed and may be reported lost.
constexpr uint32_t DEFAULT_LOST_MS{10 * 1000};
// TTL written into every outgoing resource record (DNS TTLs are expressed in seconds).
constexpr uint32_t DEFAULT_TTL{120};
// Top bit of the rrclass field: the mDNS cache-flush bit OR'ed into every outgoing record class.
constexpr uint16_t MDNS_FLUSH_CACHE_BIT{1U << 15};

// Progress markers used while answering queries; the highest phase reached decides
// whether the host address still has to be appended as an additional record.
constexpr int PHASE_PTR{1};
constexpr int PHASE_SRV{PHASE_PTR + 1};
constexpr int PHASE_DOMAIN{PHASE_SRV + 1};
44 
// Renders an in_addr or in6_addr held in a std::any as presentation text.
// Returns an empty string if inet_ntop fails, or if the std::any holds neither type.
std::string AddrToString(const std::any &addr)
{
    char text[INET6_ADDRSTRLEN] = {0};
    const char *converted = text;
    if (const auto *v4 = std::any_cast<in_addr>(&addr); v4 != nullptr) {
        converted = inet_ntop(AF_INET, v4, text, sizeof(text));
    } else if (const auto *v6 = std::any_cast<in6_addr>(&addr); v6 != nullptr) {
        converted = inet_ntop(AF_INET6, v6, text, sizeof(text));
    }
    // For an unrecognised payload `text` is still all zeros, which yields "".
    return (converted == nullptr) ? std::string{} : std::string(text);
}
59 
// Wall-clock time in milliseconds since the Unix epoch (std::chrono::system_clock).
int64_t MilliSecondsSinceEpoch()
{
    using std::chrono::milliseconds;
    using std::chrono::system_clock;
    const auto sinceEpoch = system_clock::now().time_since_epoch();
    return std::chrono::duration_cast<milliseconds>(sinceEpoch).count();
}
65 
MDnsProtocolImpl()66 MDnsProtocolImpl::MDnsProtocolImpl()
67 {
68     Init();
69 }
70 
Init()71 void MDnsProtocolImpl::Init()
72 {
73     NETMGR_EXT_LOG_D("mdns_log MDnsProtocolImpl init");
74     listener_.Stop();
75     listener_.CloseAllSocket();
76 
77     if (config_.configAllIface) {
78         listener_.OpenSocketForEachIface(config_.ipv6Support, config_.configLo);
79     } else {
80         listener_.OpenSocketForDefault(config_.ipv6Support);
81     }
82     listener_.SetReceiveHandler(
83         [this](int sock, const MDnsPayload &payload) { return this->ReceivePacket(sock, payload); });
84     listener_.SetFinishedHandler([this](int sock) {
85         std::lock_guard<std::recursive_mutex> guard(mutex_);
86         RunTaskQueue(taskQueue_);
87     });
88     listener_.Start();
89 
90     taskQueue_.clear();
91     taskOnChange_.clear();
92     AddTask([this]() { return Browse(); }, false);
93 }
94 
Browse()95 bool MDnsProtocolImpl::Browse()
96 {
97     if (lastRunTime != -1 && MilliSecondsSinceEpoch() - lastRunTime < DEFAULT_INTEVAL_MS) {
98         return false;
99     }
100     lastRunTime = MilliSecondsSinceEpoch();
101     std::lock_guard<std::recursive_mutex> guard(mutex_);
102     for (auto &&[key, res] : browserMap_) {
103         NETMGR_EXT_LOG_D("mdns_log Browse browserMap_ key[%{public}s] res.size[%{public}zu]", key.c_str(), res.size());
104         if (nameCbMap_.find(key) != nameCbMap_.end() &&
105             !MDnsManager::GetInstance().IsAvailableCallback(nameCbMap_[key])) {
106             continue;
107         }
108         handleOfflineService(key, res);
109         MDnsPayloadParser parser;
110         MDnsMessage msg{};
111         msg.questions.emplace_back(DNSProto::Question{
112             .name = key,
113             .qtype = DNSProto::RRTYPE_PTR,
114             .qclass = DNSProto::RRCLASS_IN,
115         });
116         listener_.MulticastAll(parser.ToBytes(msg));
117     }
118     return false;
119 }
120 
ConnectControl(int32_t sockfd, sockaddr* serverAddr)121 int32_t MDnsProtocolImpl::ConnectControl(int32_t sockfd, sockaddr* serverAddr)
122 {
123     uint32_t flags = static_cast<uint32_t>(fcntl(sockfd, F_GETFL, 0));
124     fcntl(sockfd, F_SETFL, flags | O_NONBLOCK);
125     int32_t ret = connect(sockfd, serverAddr, sizeof(sockaddr));
126     if ((ret < 0) && (errno != EINPROGRESS)) {
127         NETMGR_EXT_LOG_E("connect error: %{public}d", errno);
128         return NETMANAGER_EXT_ERR_INTERNAL;
129     }
130     if (ret == 0) {
131         fcntl(sockfd, F_SETFL, flags); /* restore file status flags */
132         NETMGR_EXT_LOG_I("connect success.");
133         return NETMANAGER_EXT_SUCCESS;
134     }
135 
136     fd_set rset;
137     FD_ZERO(&rset);
138     FD_SET(sockfd, &rset);
139     fd_set wset = rset;
140     timeval tval {1, 0};
141     ret = select(sockfd + 1, &rset, &wset, NULL, &tval);
142     if (ret < 0) { // select error.
143         NETMGR_EXT_LOG_E("select error: %{public}d", errno);
144         return NETMANAGER_EXT_ERR_INTERNAL;
145     }
146     if (ret == 0) { // timeout
147         NETMGR_EXT_LOG_E("connect timeout...");
148         return NETMANAGER_EXT_ERR_INTERNAL;
149     }
150     if (!FD_ISSET(sockfd, &rset) && !FD_ISSET(sockfd, &wset)) {
151         NETMGR_EXT_LOG_E("select error: sockfd not set");
152         return NETMANAGER_EXT_ERR_INTERNAL;
153     }
154 
155     int32_t result = NETMANAGER_EXT_ERR_INTERNAL;
156     socklen_t len = sizeof(result);
157     if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &result, &len) < 0) {
158         NETMGR_EXT_LOG_E("getsockopt error: %{public}d", errno);
159         return NETMANAGER_EXT_ERR_INTERNAL;
160     }
161     if (result != 0) { // connect failed.
162         NETMGR_EXT_LOG_E("connect failed. error: %{public}d", result);
163         return NETMANAGER_EXT_ERR_INTERNAL;
164     }
165     fcntl(sockfd, F_SETFL, flags); /* restore file status flags */
166     NETMGR_EXT_LOG_I("lost but connect success.");
167     return NETMANAGER_EXT_SUCCESS;
168 }
169 
IsConnectivity(const std::string &ip, int32_t port)170 bool MDnsProtocolImpl::IsConnectivity(const std::string &ip, int32_t port)
171 {
172     if (ip.empty()) {
173         NETMGR_EXT_LOG_E("ip is empty");
174         return false;
175     }
176 
177     int32_t sockfd = socket(AF_INET, SOCK_STREAM, 0);
178     if (sockfd < 0) {
179         NETMGR_EXT_LOG_E("create socket error: %{public}d", errno);
180         return false;
181     }
182 
183     struct sockaddr_in serverAddr;
184     if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) {
185         NETMGR_EXT_LOG_E("memset_s serverAddr failed!");
186         close(sockfd);
187         return false;
188     }
189 
190     serverAddr.sin_family = AF_INET;
191     serverAddr.sin_addr.s_addr = inet_addr(ip.c_str());
192     serverAddr.sin_port = htons(port);
193     if (ConnectControl(sockfd, (struct sockaddr*)&serverAddr) != NETMANAGER_EXT_SUCCESS) {
194         NETMGR_EXT_LOG_I("connect error: %{public}d", errno);
195         close(sockfd);
196         return false;
197     }
198 
199     close(sockfd);
200     return true;
201 }
202 
handleOfflineService(const std::string &key, std::vector<Result> &res)203 void MDnsProtocolImpl::handleOfflineService(const std::string &key, std::vector<Result> &res)
204 {
205     NETMGR_EXT_LOG_D("mdns_log handleOfflineService key:[%{public}s]", key.c_str());
206     for (auto it = res.begin(); it != res.end();) {
207         if (lastRunTime - it->refrehTime > DEFAULT_LOST_MS && it->state == State::LIVE) {
208             std::string fullName = Decorated(it->serviceName + MDNS_DOMAIN_SPLITER_STR + it->serviceType);
209             if ((cacheMap_.find(fullName) != cacheMap_.end()) &&
210                 IsConnectivity(cacheMap_[fullName].addr, cacheMap_[fullName].port)) {
211                 it++;
212                 continue;
213             }
214 
215             it->state = State::DEAD;
216             if (nameCbMap_.find(key) != nameCbMap_.end() && nameCbMap_[key] != nullptr) {
217                 NETMGR_EXT_LOG_W("mdns_log HandleServiceLost");
218                 nameCbMap_[key]->HandleServiceLost(ConvertResultToInfo(*it), NETMANAGER_EXT_SUCCESS);
219             }
220             it = res.erase(it);
221             cacheMap_.erase(fullName);
222         } else {
223             it++;
224         }
225     }
226 }
227 
SetConfig(const MDnsConfig &config)228 void MDnsProtocolImpl::SetConfig(const MDnsConfig &config)
229 {
230     config_ = config;
231 }
232 
GetConfig() const233 const MDnsConfig &MDnsProtocolImpl::GetConfig() const
234 {
235     return config_;
236 }
237 
Decorated(const std::string &name) const238 std::string MDnsProtocolImpl::Decorated(const std::string &name) const
239 {
240     return name + config_.topDomain;
241 }
242 
Register(const Result &info)243 int32_t MDnsProtocolImpl::Register(const Result &info)
244 {
245     NETMGR_EXT_LOG_D("mdns_log Register");
246     if (!(IsNameValid(info.serviceName) && IsTypeValid(info.serviceType) && IsPortValid(info.port))) {
247         return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
248     }
249     std::string name = Decorated(info.serviceName + MDNS_DOMAIN_SPLITER_STR + info.serviceType);
250     if (!IsDomainValid(name)) {
251         return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
252     }
253     {
254         std::lock_guard<std::recursive_mutex> guard(mutex_);
255         if (srvMap_.find(name) != srvMap_.end()) {
256             return NET_MDNS_ERR_SERVICE_INSTANCE_DUPLICATE;
257         }
258         srvMap_.emplace(name, info);
259     }
260     return Announce(info, false);
261 }
262 
UnRegister(const std::string &key)263 int32_t MDnsProtocolImpl::UnRegister(const std::string &key)
264 {
265     NETMGR_EXT_LOG_D("mdns_log UnRegister");
266     std::string name = Decorated(key);
267     std::lock_guard<std::recursive_mutex> guard(mutex_);
268     if (srvMap_.find(name) != srvMap_.end()) {
269         Announce(srvMap_[name], true);
270         srvMap_.erase(name);
271         return NETMANAGER_EXT_SUCCESS;
272     }
273     return NET_MDNS_ERR_SERVICE_INSTANCE_NOT_FOUND;
274 }
275 
DiscoveryFromCache(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)276 bool MDnsProtocolImpl::DiscoveryFromCache(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
277 {
278     NETMGR_EXT_LOG_D("mdns_log DiscoveryFromCache");
279     std::string name = Decorated(serviceType);
280     std::lock_guard<std::recursive_mutex> guard(mutex_);
281     if (!IsBrowserAvailable(name)) {
282         return false;
283     }
284 
285     if (browserMap_.find(name) == browserMap_.end()) {
286         NETMGR_EXT_LOG_D("mdns_log DiscoveryFromCache browserMap_ not find name");
287         return false;
288     }
289 
290     for (auto &res : browserMap_[name]) {
291         if (res.state == State::REMOVE || res.state == State::DEAD) {
292             continue;
293         }
294         AddTask([cb, info = ConvertResultToInfo(res)]() {
295             NETMGR_EXT_LOG_W("mdns_log DiscoveryFromCache ConvertResultToInfo HandleServiceFound");
296             if (MDnsManager::GetInstance().IsAvailableCallback(cb)) {
297                 cb->HandleServiceFound(info, NETMANAGER_EXT_SUCCESS);
298             }
299             return true;
300         });
301     }
302     return true;
303 }
304 
DiscoveryFromNet(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)305 bool MDnsProtocolImpl::DiscoveryFromNet(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
306 {
307     NETMGR_EXT_LOG_D("mdns_log DiscoveryFromNet");
308     std::string name = Decorated(serviceType);
309     std::lock_guard<std::recursive_mutex> guard(mutex_);
310     browserMap_.insert({name, std::vector<Result>{}});
311     nameCbMap_[name] = cb;
312     // key is serviceTYpe
313     AddEvent(name, [this, name, cb]() {
314         std::lock_guard<std::recursive_mutex> guard(mutex_);
315         if (!IsBrowserAvailable(name)) {
316             return false;
317         }
318         if (!MDnsManager::GetInstance().IsAvailableCallback(cb)) {
319             return true;
320         }
321         for (auto &res : browserMap_[name]) {
322             std::string fullName = Decorated(res.serviceName + MDNS_DOMAIN_SPLITER_STR + res.serviceType);
323             NETMGR_EXT_LOG_W("mdns_log DiscoveryFromNet name:[%{public}s] fullName:[%{public}s]", name.c_str(),
324                              fullName.c_str());
325             if (cacheMap_.find(fullName) == cacheMap_.end() ||
326                 (res.state == State::ADD || res.state == State::REFRESH)) {
327                 NETMGR_EXT_LOG_W("mdns_log HandleServiceFound");
328                 cb->HandleServiceFound(ConvertResultToInfo(res), NETMANAGER_EXT_SUCCESS);
329                 res.state = State::LIVE;
330             }
331             if (res.state == State::REMOVE) {
332                 res.state = State::DEAD;
333                 NETMGR_EXT_LOG_D("mdns_log HandleServiceLost");
334                 cb->HandleServiceLost(ConvertResultToInfo(res), NETMANAGER_EXT_SUCCESS);
335                 if (cacheMap_.find(fullName) != cacheMap_.end()) {
336                     cacheMap_.erase(fullName);
337                 }
338             }
339         }
340         return false;
341     });
342 
343     AddTask([=]() {
344             MDnsPayloadParser parser;
345             MDnsMessage msg{};
346             msg.questions.emplace_back(DNSProto::Question{
347                 .name = name,
348                 .qtype = DNSProto::RRTYPE_PTR,
349                 .qclass = DNSProto::RRCLASS_IN,
350             });
351             listener_.MulticastAll(parser.ToBytes(msg));
352             return true;
353         }, false);
354     return true;
355 }
356 
Discovery(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)357 int32_t MDnsProtocolImpl::Discovery(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
358 {
359     NETMGR_EXT_LOG_D("mdns_log Discovery");
360     DiscoveryFromCache(serviceType, cb);
361     DiscoveryFromNet(serviceType, cb);
362     return NETMANAGER_EXT_SUCCESS;
363 }
364 
ResolveInstanceFromCache(const std::string &name, const sptr<IResolveCallback> &cb)365 bool MDnsProtocolImpl::ResolveInstanceFromCache(const std::string &name, const sptr<IResolveCallback> &cb)
366 {
367     NETMGR_EXT_LOG_D("mdns_log ResolveInstanceFromCache");
368     std::lock_guard<std::recursive_mutex> guard(mutex_);
369     if (!IsInstanceCacheAvailable(name)) {
370         NETMGR_EXT_LOG_W("mdns_log ResolveInstanceFromCache cacheMap_ has no element [%{public}s]", name.c_str());
371         return false;
372     }
373 
374     NETMGR_EXT_LOG_I("mdns_log rr.name : [%{public}s]", name.c_str());
375     Result r = cacheMap_[name];
376     if (IsDomainCacheAvailable(r.domain)) {
377         r.ipv6 = cacheMap_[r.domain].ipv6;
378         r.addr = cacheMap_[r.domain].addr;
379 
380         NETMGR_EXT_LOG_D("mdns_log Add Task DomainCache Available, [%{public}s]", r.domain.c_str());
381         AddTask([cb, info = ConvertResultToInfo(r)]() {
382             if (nullptr != cb) {
383                 cb->HandleResolveResult(info, NETMANAGER_EXT_SUCCESS);
384             }
385             return true;
386         });
387     } else {
388         ResolveFromNet(r.domain, nullptr);
389         NETMGR_EXT_LOG_D("mdns_log Add Event DomainCache UnAvailable, [%{public}s]", r.domain.c_str());
390         AddEvent(r.domain, [this, cb, r]() mutable {
391             if (!IsDomainCacheAvailable(r.domain)) {
392                 return false;
393             }
394             r.ipv6 = cacheMap_[r.domain].ipv6;
395             r.addr = cacheMap_[r.domain].addr;
396             if (nullptr != cb) {
397                 cb->HandleResolveResult(ConvertResultToInfo(r), NETMANAGER_EXT_SUCCESS);
398             }
399             return true;
400         });
401     }
402     return true;
403 }
404 
ResolveInstanceFromNet(const std::string &name, const sptr<IResolveCallback> &cb)405 bool MDnsProtocolImpl::ResolveInstanceFromNet(const std::string &name, const sptr<IResolveCallback> &cb)
406 {
407     NETMGR_EXT_LOG_D("mdns_log ResolveInstanceFromNet");
408     {
409         std::lock_guard<std::recursive_mutex> guard(mutex_);
410         cacheMap_[name].state = State::ADD;
411         ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
412     }
413     MDnsPayloadParser parser;
414     MDnsMessage msg{};
415     msg.questions.emplace_back(DNSProto::Question{
416         .name = name,
417         .qtype = DNSProto::RRTYPE_SRV,
418         .qclass = DNSProto::RRCLASS_IN,
419     });
420     msg.questions.emplace_back(DNSProto::Question{
421         .name = name,
422         .qtype = DNSProto::RRTYPE_TXT,
423         .qclass = DNSProto::RRCLASS_IN,
424     });
425     msg.header.qdcount = msg.questions.size();
426     AddEvent(name, [this, name, cb]() { return ResolveInstanceFromCache(name, cb); });
427     ssize_t size = listener_.MulticastAll(parser.ToBytes(msg));
428     return size > 0;
429 }
430 
ResolveFromCache(const std::string &domain, const sptr<IResolveCallback> &cb)431 bool MDnsProtocolImpl::ResolveFromCache(const std::string &domain, const sptr<IResolveCallback> &cb)
432 {
433     NETMGR_EXT_LOG_D("mdns_log ResolveFromCache");
434     std::lock_guard<std::recursive_mutex> guard(mutex_);
435     if (!IsDomainCacheAvailable(domain)) {
436         return false;
437     }
438     AddTask([this, cb, info = ConvertResultToInfo(cacheMap_[domain])]() {
439         if (nullptr != cb) {
440             cb->HandleResolveResult(info, NETMANAGER_EXT_SUCCESS);
441         }
442         return true;
443     });
444     return true;
445 }
446 
ResolveFromNet(const std::string &domain, const sptr<IResolveCallback> &cb)447 bool MDnsProtocolImpl::ResolveFromNet(const std::string &domain, const sptr<IResolveCallback> &cb)
448 {
449     NETMGR_EXT_LOG_D("mdns_log ResolveFromNet");
450     {
451         std::lock_guard<std::recursive_mutex> guard(mutex_);
452         cacheMap_[domain];
453         cacheMap_[domain].domain = domain;
454     }
455     MDnsPayloadParser parser;
456     MDnsMessage msg{};
457     msg.questions.emplace_back(DNSProto::Question{
458         .name = domain,
459         .qtype = DNSProto::RRTYPE_A,
460         .qclass = DNSProto::RRCLASS_IN,
461     });
462     msg.questions.emplace_back(DNSProto::Question{
463         .name = domain,
464         .qtype = DNSProto::RRTYPE_AAAA,
465         .qclass = DNSProto::RRCLASS_IN,
466     });
467     // key is serviceName
468     AddEvent(domain, [this, cb, domain]() { return ResolveFromCache(domain, cb); });
469     ssize_t size = listener_.MulticastAll(parser.ToBytes(msg));
470     return size > 0;
471 }
472 
ResolveInstance(const std::string &instance, const sptr<IResolveCallback> &cb)473 int32_t MDnsProtocolImpl::ResolveInstance(const std::string &instance, const sptr<IResolveCallback> &cb)
474 {
475     NETMGR_EXT_LOG_D("mdns_log execute ResolveInstance");
476     if (!IsInstanceValid(instance)) {
477         return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
478     }
479     std::string name = Decorated(instance);
480     if (!IsDomainValid(name)) {
481         return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
482     }
483     if (ResolveInstanceFromCache(name, cb)) {
484         return NETMANAGER_EXT_SUCCESS;
485     }
486     return ResolveInstanceFromNet(name, cb) ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
487 }
488 
Announce(const Result &info, bool off)489 int32_t MDnsProtocolImpl::Announce(const Result &info, bool off)
490 {
491     NETMGR_EXT_LOG_I("mdns_log Announce message");
492     MDnsMessage response{};
493     response.header.flags = DNSProto::MDNS_ANSWER_FLAGS;
494     std::string name = Decorated(info.serviceName + MDNS_DOMAIN_SPLITER_STR + info.serviceType);
495     response.answers.emplace_back(DNSProto::ResourceRecord{.name = Decorated(info.serviceType),
496                                                            .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_PTR),
497                                                            .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
498                                                            .ttl = off ? 0U : DEFAULT_TTL,
499                                                            .rdata = name});
500     response.answers.emplace_back(DNSProto::ResourceRecord{.name = name,
501                                                            .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_SRV),
502                                                            .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
503                                                            .ttl = off ? 0U : DEFAULT_TTL,
504                                                            .rdata = DNSProto::RDataSrv{
505                                                                .priority = 0,
506                                                                .weight = 0,
507                                                                .port = static_cast<uint16_t>(info.port),
508                                                                .name = GetHostDomain(),
509                                                            }});
510     response.answers.emplace_back(DNSProto::ResourceRecord{.name = name,
511                                                            .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_TXT),
512                                                            .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
513                                                            .ttl = off ? 0U : DEFAULT_TTL,
514                                                            .rdata = info.txt});
515     MDnsPayloadParser parser;
516     ssize_t size = listener_.MulticastAll(parser.ToBytes(response));
517     return size > 0 ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
518 }
519 
ReceivePacket(int sock, const MDnsPayload &payload)520 void MDnsProtocolImpl::ReceivePacket(int sock, const MDnsPayload &payload)
521 {
522     if (payload.size() == 0) {
523         return;
524     }
525     MDnsPayloadParser parser;
526     MDnsMessage msg = parser.FromBytes(payload);
527     if (parser.GetError() != 0) {
528         NETMGR_EXT_LOG_E("parser payload failed");
529         return;
530     }
531     if ((msg.header.flags & DNSProto::HEADER_FLAGS_QR_MASK) == 0) {
532         ProcessQuestion(sock, msg);
533     } else {
534         ProcessAnswer(sock, msg);
535     }
536 }
537 
AppendRecord(std::vector<DNSProto::ResourceRecord> &rrlist, DNSProto::RRType type, const std::string &name, const std::any &rdata)538 void MDnsProtocolImpl::AppendRecord(std::vector<DNSProto::ResourceRecord> &rrlist, DNSProto::RRType type,
539                                     const std::string &name, const std::any &rdata)
540 {
541     rrlist.emplace_back(DNSProto::ResourceRecord{.name = name,
542                                                  .rtype = static_cast<uint16_t>(type),
543                                                  .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
544                                                  .ttl = DEFAULT_TTL,
545                                                  .rdata = rdata});
546 }
547 
ProcessQuestion(int sock, const MDnsMessage &msg)548 void MDnsProtocolImpl::ProcessQuestion(int sock, const MDnsMessage &msg)
549 {
550     const sockaddr *saddrIf = listener_.GetSockAddr(sock);
551     if (saddrIf == nullptr) {
552         NETMGR_EXT_LOG_W("mdns_log ProcessQuestion saddrIf is null");
553         return;
554     }
555     std::any anyAddr;
556     DNSProto::RRType anyAddrType;
557     if (saddrIf->sa_family == AF_INET6) {
558         anyAddr = reinterpret_cast<const sockaddr_in6 *>(saddrIf)->sin6_addr;
559         anyAddrType = DNSProto::RRTYPE_AAAA;
560     } else {
561         anyAddr = reinterpret_cast<const sockaddr_in *>(saddrIf)->sin_addr;
562         anyAddrType = DNSProto::RRTYPE_A;
563     }
564     int phase = 0;
565     MDnsMessage response{};
566     response.header.flags = DNSProto::MDNS_ANSWER_FLAGS;
567     for (size_t i = 0; i < msg.header.qdcount; ++i) {
568         ProcessQuestionRecord(anyAddr, anyAddrType, msg.questions[i], phase, response);
569     }
570     if (phase < PHASE_DOMAIN) {
571         AppendRecord(response.additional, anyAddrType, GetHostDomain(), anyAddr);
572     }
573 
574     if (phase != 0 && response.answers.size() > 0) {
575         listener_.Multicast(sock, MDnsPayloadParser().ToBytes(response));
576     }
577 }
578 
ProcessQuestionRecord(const std::any &anyAddr, const DNSProto::RRType &anyAddrType, const DNSProto::Question &qu, int &phase, MDnsMessage &response)579 void MDnsProtocolImpl::ProcessQuestionRecord(const std::any &anyAddr, const DNSProto::RRType &anyAddrType,
580                                              const DNSProto::Question &qu, int &phase, MDnsMessage &response)
581 {
582     NETMGR_EXT_LOG_D("mdns_log ProcessQuestionRecord");
583     std::lock_guard<std::recursive_mutex> guard(mutex_);
584     std::string name = qu.name;
585     if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_PTR) {
586         std::for_each(srvMap_.begin(), srvMap_.end(), [&](const auto &elem) -> void {
587             if (EndsWith(elem.first, name)) {
588                 AppendRecord(response.answers, DNSProto::RRTYPE_PTR, name, elem.first);
589                 AppendRecord(response.additional, DNSProto::RRTYPE_SRV, elem.first,
590                              DNSProto::RDataSrv{
591                                  .priority = 0,
592                                  .weight = 0,
593                                  .port = static_cast<uint16_t>(elem.second.port),
594                                  .name = GetHostDomain(),
595                              });
596                 AppendRecord(response.additional, DNSProto::RRTYPE_TXT, elem.first, elem.second.txt);
597             }
598         });
599         phase = std::max(phase, PHASE_PTR);
600     }
601     if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_SRV) {
602         auto iter = srvMap_.find(name);
603         if (iter == srvMap_.end()) {
604             return;
605         }
606         AppendRecord(response.answers, DNSProto::RRTYPE_SRV, name,
607                      DNSProto::RDataSrv{
608                          .priority = 0,
609                          .weight = 0,
610                          .port = static_cast<uint16_t>(iter->second.port),
611                          .name = GetHostDomain(),
612                      });
613         phase = std::max(phase, PHASE_SRV);
614     }
615     if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_TXT) {
616         auto iter = srvMap_.find(name);
617         if (iter == srvMap_.end()) {
618             return;
619         }
620         AppendRecord(response.answers, DNSProto::RRTYPE_TXT, name, iter->second.txt);
621         phase = std::max(phase, PHASE_SRV);
622     }
623     if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_A || qu.qtype == DNSProto::RRTYPE_AAAA) {
624         if (name != GetHostDomain() || (qu.qtype != DNSProto::RRTYPE_ANY && anyAddrType != qu.qtype)) {
625             return;
626         }
627         AppendRecord(response.answers, anyAddrType, name, anyAddr);
628         phase = std::max(phase, PHASE_DOMAIN);
629     }
630 }
631 
ProcessAnswer(int sock, const MDnsMessage &msg)632 void MDnsProtocolImpl::ProcessAnswer(int sock, const MDnsMessage &msg)
633 {
634     const sockaddr *saddrIf = listener_.GetSockAddr(sock);
635     if (saddrIf == nullptr) {
636         return;
637     }
638     bool v6 = (saddrIf->sa_family == AF_INET6);
639     std::set<std::string> changed;
640     for (const auto &answer : msg.answers) {
641         ProcessAnswerRecord(v6, answer, changed);
642     }
643     for (const auto &i : msg.additional) {
644         ProcessAnswerRecord(v6, i, changed);
645     }
646     for (const auto &i : changed) {
647         std::lock_guard<std::recursive_mutex> guard(mutex_);
648         RunTaskQueue(taskOnChange_[i]);
649         KillCache(i);
650     }
651 }
652 
UpdatePtr(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)653 void MDnsProtocolImpl::UpdatePtr(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
654 {
655     const std::string *data = std::any_cast<std::string>(&rr.rdata);
656     if (data == nullptr) {
657         return;
658     }
659 
660     std::string name = rr.name;
661     if (browserMap_.find(name) == browserMap_.end()) {
662         return;
663     }
664     auto &results = browserMap_[name];
665     std::string srvName;
666     std::string srvType;
667     ExtractNameAndType(*data, srvName, srvType);
668     if (srvName.empty() || srvType.empty()) {
669         return;
670     }
671     auto res =
672         std::find_if(results.begin(), results.end(), [&](const auto &elem) { return elem.serviceName == srvName; });
673     if (res == results.end()) {
674         results.emplace_back(Result{
675             .serviceName = srvName,
676             .serviceType = srvType,
677             .state = State::ADD,
678         });
679     }
680     res = std::find_if(results.begin(), results.end(), [&](const auto &elem) { return elem.serviceName == srvName; });
681     if (res->serviceName != srvName || res->state == State::DEAD) {
682         res->state = State::REFRESH;
683         res->serviceName = srvName;
684     }
685     if (rr.ttl == 0) {
686         res->state = State::REMOVE;
687     }
688     if (res->state != State::LIVE && res->state != State::DEAD) {
689         changed.emplace(name);
690     }
691     res->ttl = rr.ttl;
692     res->refrehTime = MilliSecondsSinceEpoch();
693 }
694 
UpdateSrv(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)695 void MDnsProtocolImpl::UpdateSrv(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
696 {
697     const DNSProto::RDataSrv *srv = std::any_cast<DNSProto::RDataSrv>(&rr.rdata);
698     if (srv == nullptr) {
699         return;
700     }
701     std::string name = rr.name;
702     if (cacheMap_.find(name) == cacheMap_.end()) {
703         ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
704         cacheMap_[name].state = State::ADD;
705         cacheMap_[name].domain = srv->name;
706         cacheMap_[name].port = srv->port;
707     }
708     Result &result = cacheMap_[name];
709     if (result.domain != srv->name || result.port != srv->port || result.state == State::DEAD) {
710         if (result.state != State::ADD) {
711             result.state = State::REFRESH;
712         }
713         result.domain = srv->name;
714         result.port = srv->port;
715     }
716     if (rr.ttl == 0) {
717         result.state = State::REMOVE;
718     }
719     if (result.state != State::LIVE && result.state != State::DEAD) {
720         changed.emplace(name);
721     }
722     result.ttl = rr.ttl;
723     result.refrehTime = MilliSecondsSinceEpoch();
724 }
725 
UpdateTxt(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)726 void MDnsProtocolImpl::UpdateTxt(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
727 {
728     const TxtRecordEncoded *txt = std::any_cast<TxtRecordEncoded>(&rr.rdata);
729     if (txt == nullptr) {
730         return;
731     }
732     std::string name = rr.name;
733     if (cacheMap_.find(name) == cacheMap_.end()) {
734         ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
735         cacheMap_[name].state = State::ADD;
736         cacheMap_[name].txt = *txt;
737     }
738     Result &result = cacheMap_[name];
739     if (result.txt != *txt || result.state == State::DEAD) {
740         if (result.state != State::ADD) {
741             result.state = State::REFRESH;
742         }
743         result.txt = *txt;
744     }
745     if (rr.ttl == 0) {
746         result.state = State::REMOVE;
747     }
748     if (result.state != State::LIVE && result.state != State::DEAD) {
749         changed.emplace(name);
750     }
751     result.ttl = rr.ttl;
752     result.refrehTime = MilliSecondsSinceEpoch();
753 }
754 
UpdateAddr(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)755 void MDnsProtocolImpl::UpdateAddr(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
756 {
757     if (v6 != (rr.rtype == DNSProto::RRTYPE_AAAA)) {
758         return;
759     }
760     const std::string addr = AddrToString(rr.rdata);
761     bool v6rr = (rr.rtype == DNSProto::RRTYPE_AAAA);
762     if (addr.empty()) {
763         return;
764     }
765     std::string name = rr.name;
766     if (cacheMap_.find(name) == cacheMap_.end()) {
767         ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
768         cacheMap_[name].state = State::ADD;
769         cacheMap_[name].ipv6 = v6rr;
770         cacheMap_[name].addr = addr;
771     }
772     Result &result = cacheMap_[name];
773     if (result.addr != addr || result.ipv6 != v6rr || result.state == State::DEAD) {
774         result.state = State::REFRESH;
775         result.addr = addr;
776         result.ipv6 = v6rr;
777     }
778     if (rr.ttl == 0) {
779         result.state = State::REMOVE;
780     }
781     if (result.state != State::LIVE && result.state != State::DEAD) {
782         changed.emplace(name);
783     }
784     result.ttl = rr.ttl;
785     result.refrehTime = MilliSecondsSinceEpoch();
786 }
787 
ProcessAnswerRecord(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)788 void MDnsProtocolImpl::ProcessAnswerRecord(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
789 {
790     NETMGR_EXT_LOG_D("mdns_log ProcessAnswerRecord, type=[%{public}d]", rr.rtype);
791     std::lock_guard<std::recursive_mutex> guard(mutex_);
792     std::string name = rr.name;
793     if (cacheMap_.find(name) == cacheMap_.end() && browserMap_.find(name) == browserMap_.end() &&
794         srvMap_.find(name) != srvMap_.end()) {
795         return;
796     }
797     if (rr.rtype == DNSProto::RRTYPE_PTR) {
798         UpdatePtr(v6, rr, changed);
799     } else if (rr.rtype == DNSProto::RRTYPE_SRV) {
800         UpdateSrv(v6, rr, changed);
801     } else if (rr.rtype == DNSProto::RRTYPE_TXT) {
802         UpdateTxt(v6, rr, changed);
803     } else if (rr.rtype == DNSProto::RRTYPE_A || rr.rtype == DNSProto::RRTYPE_AAAA) {
804         UpdateAddr(v6, rr, changed);
805     } else {
806         NETMGR_EXT_LOG_D("mdns_log Unknown packet received, type=[%{public}d]", rr.rtype);
807     }
808 }
809 
GetHostDomain()810 std::string MDnsProtocolImpl::GetHostDomain()
811 {
812     if (config_.hostname.empty()) {
813         char buffer[MDNS_MAX_DOMAIN_LABEL];
814         if (gethostname(buffer, sizeof(buffer)) == 0) {
815             config_.hostname = buffer;
816             static auto uid = []() {
817                 std::random_device rd;
818                 return rd();
819             }();
820             config_.hostname += std::to_string(uid);
821         }
822     }
823     return Decorated(config_.hostname);
824 }
825 
AddTask(const Task &task, bool atonce)826 void MDnsProtocolImpl::AddTask(const Task &task, bool atonce)
827 {
828     {
829         std::lock_guard<std::recursive_mutex> guard(mutex_);
830         taskQueue_.emplace_back(task);
831     }
832     if (atonce) {
833         listener_.TriggerRefresh();
834     }
835 }
836 
ConvertResultToInfo(const MDnsProtocolImpl::Result &result)837 MDnsServiceInfo MDnsProtocolImpl::ConvertResultToInfo(const MDnsProtocolImpl::Result &result)
838 {
839     MDnsServiceInfo info;
840     info.name = result.serviceName;
841     info.type = result.serviceType;
842     if (!result.addr.empty()) {
843         info.family = result.ipv6 ? MDnsServiceInfo::IPV6 : MDnsServiceInfo::IPV4;
844     }
845     info.addr = result.addr;
846     info.port = result.port;
847     info.txtRecord = result.txt;
848     return info;
849 }
850 
IsCacheAvailable(const std::string &key)851 bool MDnsProtocolImpl::IsCacheAvailable(const std::string &key)
852 {
853     constexpr int64_t ms2S = 1000LL;
854     NETMGR_EXT_LOG_D("mdns_log IsCacheAvailable, ttl=[%{public}u]", cacheMap_[key].ttl);
855     return cacheMap_.find(key) != cacheMap_.end() &&
856            (ms2S * cacheMap_[key].ttl) > static_cast<uint32_t>(MilliSecondsSinceEpoch() - cacheMap_[key].refrehTime);
857 }
858 
IsDomainCacheAvailable(const std::string &key)859 bool MDnsProtocolImpl::IsDomainCacheAvailable(const std::string &key)
860 {
861     return IsCacheAvailable(key) && !cacheMap_[key].addr.empty();
862 }
863 
IsInstanceCacheAvailable(const std::string &key)864 bool MDnsProtocolImpl::IsInstanceCacheAvailable(const std::string &key)
865 {
866     return IsCacheAvailable(key) && !cacheMap_[key].domain.empty();
867 }
868 
IsBrowserAvailable(const std::string &key)869 bool MDnsProtocolImpl::IsBrowserAvailable(const std::string &key)
870 {
871     return browserMap_.find(key) != browserMap_.end() && !browserMap_[key].empty();
872 }
873 
AddEvent(const std::string &key, const Task &task)874 void MDnsProtocolImpl::AddEvent(const std::string &key, const Task &task)
875 {
876     std::lock_guard<std::recursive_mutex> guard(mutex_);
877     taskOnChange_[key].emplace_back(task);
878 }
879 
RunTaskQueue(std::list<Task> &queue)880 void MDnsProtocolImpl::RunTaskQueue(std::list<Task> &queue)
881 {
882     std::list<Task> tmp;
883     for (auto &&func : queue) {
884         if (!func()) {
885             tmp.emplace_back(func);
886         }
887     }
888     tmp.swap(queue);
889 }
890 
KillCache(const std::string &key)891 void MDnsProtocolImpl::KillCache(const std::string &key)
892 {
893     NETMGR_EXT_LOG_D("mdns_log KillCache");
894     if (IsBrowserAvailable(key) && browserMap_.find(key) != browserMap_.end()) {
895         for (auto it = browserMap_[key].begin(); it != browserMap_[key].end();) {
896             KillBrowseCache(key, it);
897         }
898     }
899     if (IsCacheAvailable(key)) {
900         std::lock_guard<std::recursive_mutex> guard(mutex_);
901         auto &elem = cacheMap_[key];
902         if (elem.state == State::REMOVE) {
903             elem.state = State::DEAD;
904             cacheMap_.erase(key);
905         } else if (elem.state == State::ADD || elem.state == State::REFRESH) {
906             elem.state = State::LIVE;
907         }
908     }
909 }
910 
KillBrowseCache(const std::string &key, std::vector<Result>::iterator &it)911 void MDnsProtocolImpl::KillBrowseCache(const std::string &key, std::vector<Result>::iterator &it)
912 {
913     NETMGR_EXT_LOG_D("mdns_log KillBrowseCache");
914     if (it->state == State::REMOVE) {
915         it->state = State::DEAD;
916         if (nameCbMap_.find(key) != nameCbMap_.end()) {
917             NETMGR_EXT_LOG_D("mdns_log HandleServiceLost");
918             nameCbMap_[key]->HandleServiceLost(ConvertResultToInfo(*it), NETMANAGER_EXT_SUCCESS);
919         }
920         std::string fullName = Decorated(it->serviceName + MDNS_DOMAIN_SPLITER_STR + it->serviceType);
921         cacheMap_.erase(fullName);
922         it = browserMap_[key].erase(it);
923     } else if (it->state == State::ADD || it->state == State::REFRESH) {
924         it->state = State::LIVE;
925         it++;
926     } else {
927         it++;
928     }
929 }
930 
StopCbMap(const std::string &serviceType)931 int32_t MDnsProtocolImpl::StopCbMap(const std::string &serviceType)
932 {
933     NETMGR_EXT_LOG_D("mdns_log StopCbMap");
934     std::lock_guard<std::recursive_mutex> guard(mutex_);
935     std::string name = Decorated(serviceType);
936     sptr<IDiscoveryCallback> cb = nullptr;
937     if (nameCbMap_.find(name) != nameCbMap_.end()) {
938         cb = nameCbMap_[name];
939         nameCbMap_.erase(name);
940     }
941     taskOnChange_.erase(name);
942     auto it = browserMap_.find(name);
943     if (it != browserMap_.end()) {
944         if (cb != nullptr) {
945             NETMGR_EXT_LOG_I("mdns_log StopCbMap res size:[%{public}zu]", it->second.size());
946             for (auto &&res : it->second) {
947                 NETMGR_EXT_LOG_W("mdns_log HandleServiceLost");
948                 cb->HandleServiceLost(ConvertResultToInfo(res), NETMANAGER_EXT_SUCCESS);
949             }
950         }
951         browserMap_.erase(name);
952     }
953     return NETMANAGER_SUCCESS;
954 }
955 } // namespace NetManagerStandard
956 } // namespace OHOS
957