1/*
2 * Copyright (C) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#include "mdns_protocol_impl.h"
17
18#include <arpa/inet.h>
19#include <cstddef>
20#include <iostream>
21#include <random>
22#include <sys/types.h>
23#include <unistd.h>
24#include <fcntl.h>
25
26#include "mdns_manager.h"
27#include "mdns_packet_parser.h"
28#include "net_conn_client.h"
29#include "netmgr_ext_log_wrapper.h"
30
31#include "securec.h"
32
33namespace OHOS {
34namespace NetManagerStandard {
35
// Browse() returns early when it is invoked again within this many milliseconds of its last run.
constexpr uint32_t DEFAULT_INTEVAL_MS{2000};
// A LIVE browse result whose last refresh is older than this many milliseconds is examined
// by handleOfflineService() and may be reported as lost.
constexpr uint32_t DEFAULT_LOST_MS{10000};
// TTL written into the resource records built in this file (Announce() uses 0 when withdrawing).
constexpr uint32_t DEFAULT_TTL{120};
// ORed into the rclass of every resource record built in this file.
constexpr uint16_t MDNS_FLUSH_CACHE_BIT{0x8000};

// Values ProcessQuestionRecord() raises `phase` to, by the kind of question it answered.
// ProcessQuestion() skips the additional host-address record once PHASE_DOMAIN was reached.
constexpr int PHASE_PTR{1};
constexpr int PHASE_SRV{2};
constexpr int PHASE_DOMAIN{3};
44
// Renders an address held in a std::any as text.
// addr: expected to hold either an in_addr (IPv4) or an in6_addr (IPv6).
// Returns the textual address, or an empty string when inet_ntop fails or when
// addr holds neither of the two types.
std::string AddrToString(const std::any &addr)
{
    char text[INET6_ADDRSTRLEN] = {0};
    // Points at the buffer unless a conversion is attempted and fails.
    const char *converted = text;
    if (const auto *v4 = std::any_cast<in_addr>(&addr); v4 != nullptr) {
        converted = inet_ntop(AF_INET, v4, text, sizeof(text));
    } else if (const auto *v6 = std::any_cast<in6_addr>(&addr); v6 != nullptr) {
        converted = inet_ntop(AF_INET6, v6, text, sizeof(text));
    }
    return (converted == nullptr) ? std::string{} : std::string(text);
}
59
// Returns the current std::chrono::system_clock time as whole milliseconds since the clock's epoch.
int64_t MilliSecondsSinceEpoch()
{
    using std::chrono::milliseconds;
    const auto sinceEpoch = std::chrono::system_clock::now().time_since_epoch();
    return std::chrono::duration_cast<milliseconds>(sinceEpoch).count();
}
65
66MDnsProtocolImpl::MDnsProtocolImpl()
67{
68    Init();
69}
70
71void MDnsProtocolImpl::Init()
72{
73    NETMGR_EXT_LOG_D("mdns_log MDnsProtocolImpl init");
74    listener_.Stop();
75    listener_.CloseAllSocket();
76
77    if (config_.configAllIface) {
78        listener_.OpenSocketForEachIface(config_.ipv6Support, config_.configLo);
79    } else {
80        listener_.OpenSocketForDefault(config_.ipv6Support);
81    }
82    listener_.SetReceiveHandler(
83        [this](int sock, const MDnsPayload &payload) { return this->ReceivePacket(sock, payload); });
84    listener_.SetFinishedHandler([this](int sock) {
85        std::lock_guard<std::recursive_mutex> guard(mutex_);
86        RunTaskQueue(taskQueue_);
87    });
88    listener_.Start();
89
90    taskQueue_.clear();
91    taskOnChange_.clear();
92    AddTask([this]() { return Browse(); }, false);
93}
94
95bool MDnsProtocolImpl::Browse()
96{
97    if (lastRunTime != -1 && MilliSecondsSinceEpoch() - lastRunTime < DEFAULT_INTEVAL_MS) {
98        return false;
99    }
100    lastRunTime = MilliSecondsSinceEpoch();
101    std::lock_guard<std::recursive_mutex> guard(mutex_);
102    for (auto &&[key, res] : browserMap_) {
103        NETMGR_EXT_LOG_D("mdns_log Browse browserMap_ key[%{public}s] res.size[%{public}zu]", key.c_str(), res.size());
104        if (nameCbMap_.find(key) != nameCbMap_.end() &&
105            !MDnsManager::GetInstance().IsAvailableCallback(nameCbMap_[key])) {
106            continue;
107        }
108        handleOfflineService(key, res);
109        MDnsPayloadParser parser;
110        MDnsMessage msg{};
111        msg.questions.emplace_back(DNSProto::Question{
112            .name = key,
113            .qtype = DNSProto::RRTYPE_PTR,
114            .qclass = DNSProto::RRCLASS_IN,
115        });
116        listener_.MulticastAll(parser.ToBytes(msg));
117    }
118    return false;
119}
120
121int32_t MDnsProtocolImpl::ConnectControl(int32_t sockfd, sockaddr* serverAddr)
122{
123    uint32_t flags = static_cast<uint32_t>(fcntl(sockfd, F_GETFL, 0));
124    fcntl(sockfd, F_SETFL, flags | O_NONBLOCK);
125    int32_t ret = connect(sockfd, serverAddr, sizeof(sockaddr));
126    if ((ret < 0) && (errno != EINPROGRESS)) {
127        NETMGR_EXT_LOG_E("connect error: %{public}d", errno);
128        return NETMANAGER_EXT_ERR_INTERNAL;
129    }
130    if (ret == 0) {
131        fcntl(sockfd, F_SETFL, flags); /* restore file status flags */
132        NETMGR_EXT_LOG_I("connect success.");
133        return NETMANAGER_EXT_SUCCESS;
134    }
135
136    fd_set rset;
137    FD_ZERO(&rset);
138    FD_SET(sockfd, &rset);
139    fd_set wset = rset;
140    timeval tval {1, 0};
141    ret = select(sockfd + 1, &rset, &wset, NULL, &tval);
142    if (ret < 0) { // select error.
143        NETMGR_EXT_LOG_E("select error: %{public}d", errno);
144        return NETMANAGER_EXT_ERR_INTERNAL;
145    }
146    if (ret == 0) { // timeout
147        NETMGR_EXT_LOG_E("connect timeout...");
148        return NETMANAGER_EXT_ERR_INTERNAL;
149    }
150    if (!FD_ISSET(sockfd, &rset) && !FD_ISSET(sockfd, &wset)) {
151        NETMGR_EXT_LOG_E("select error: sockfd not set");
152        return NETMANAGER_EXT_ERR_INTERNAL;
153    }
154
155    int32_t result = NETMANAGER_EXT_ERR_INTERNAL;
156    socklen_t len = sizeof(result);
157    if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &result, &len) < 0) {
158        NETMGR_EXT_LOG_E("getsockopt error: %{public}d", errno);
159        return NETMANAGER_EXT_ERR_INTERNAL;
160    }
161    if (result != 0) { // connect failed.
162        NETMGR_EXT_LOG_E("connect failed. error: %{public}d", result);
163        return NETMANAGER_EXT_ERR_INTERNAL;
164    }
165    fcntl(sockfd, F_SETFL, flags); /* restore file status flags */
166    NETMGR_EXT_LOG_I("lost but connect success.");
167    return NETMANAGER_EXT_SUCCESS;
168}
169
170bool MDnsProtocolImpl::IsConnectivity(const std::string &ip, int32_t port)
171{
172    if (ip.empty()) {
173        NETMGR_EXT_LOG_E("ip is empty");
174        return false;
175    }
176
177    int32_t sockfd = socket(AF_INET, SOCK_STREAM, 0);
178    if (sockfd < 0) {
179        NETMGR_EXT_LOG_E("create socket error: %{public}d", errno);
180        return false;
181    }
182
183    struct sockaddr_in serverAddr;
184    if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) {
185        NETMGR_EXT_LOG_E("memset_s serverAddr failed!");
186        close(sockfd);
187        return false;
188    }
189
190    serverAddr.sin_family = AF_INET;
191    serverAddr.sin_addr.s_addr = inet_addr(ip.c_str());
192    serverAddr.sin_port = htons(port);
193    if (ConnectControl(sockfd, (struct sockaddr*)&serverAddr) != NETMANAGER_EXT_SUCCESS) {
194        NETMGR_EXT_LOG_I("connect error: %{public}d", errno);
195        close(sockfd);
196        return false;
197    }
198
199    close(sockfd);
200    return true;
201}
202
203void MDnsProtocolImpl::handleOfflineService(const std::string &key, std::vector<Result> &res)
204{
205    NETMGR_EXT_LOG_D("mdns_log handleOfflineService key:[%{public}s]", key.c_str());
206    for (auto it = res.begin(); it != res.end();) {
207        if (lastRunTime - it->refrehTime > DEFAULT_LOST_MS && it->state == State::LIVE) {
208            std::string fullName = Decorated(it->serviceName + MDNS_DOMAIN_SPLITER_STR + it->serviceType);
209            if ((cacheMap_.find(fullName) != cacheMap_.end()) &&
210                IsConnectivity(cacheMap_[fullName].addr, cacheMap_[fullName].port)) {
211                it++;
212                continue;
213            }
214
215            it->state = State::DEAD;
216            if (nameCbMap_.find(key) != nameCbMap_.end() && nameCbMap_[key] != nullptr) {
217                NETMGR_EXT_LOG_W("mdns_log HandleServiceLost");
218                nameCbMap_[key]->HandleServiceLost(ConvertResultToInfo(*it), NETMANAGER_EXT_SUCCESS);
219            }
220            it = res.erase(it);
221            cacheMap_.erase(fullName);
222        } else {
223            it++;
224        }
225    }
226}
227
228void MDnsProtocolImpl::SetConfig(const MDnsConfig &config)
229{
230    config_ = config;
231}
232
233const MDnsConfig &MDnsProtocolImpl::GetConfig() const
234{
235    return config_;
236}
237
238std::string MDnsProtocolImpl::Decorated(const std::string &name) const
239{
240    return name + config_.topDomain;
241}
242
243int32_t MDnsProtocolImpl::Register(const Result &info)
244{
245    NETMGR_EXT_LOG_D("mdns_log Register");
246    if (!(IsNameValid(info.serviceName) && IsTypeValid(info.serviceType) && IsPortValid(info.port))) {
247        return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
248    }
249    std::string name = Decorated(info.serviceName + MDNS_DOMAIN_SPLITER_STR + info.serviceType);
250    if (!IsDomainValid(name)) {
251        return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
252    }
253    {
254        std::lock_guard<std::recursive_mutex> guard(mutex_);
255        if (srvMap_.find(name) != srvMap_.end()) {
256            return NET_MDNS_ERR_SERVICE_INSTANCE_DUPLICATE;
257        }
258        srvMap_.emplace(name, info);
259    }
260    return Announce(info, false);
261}
262
263int32_t MDnsProtocolImpl::UnRegister(const std::string &key)
264{
265    NETMGR_EXT_LOG_D("mdns_log UnRegister");
266    std::string name = Decorated(key);
267    std::lock_guard<std::recursive_mutex> guard(mutex_);
268    if (srvMap_.find(name) != srvMap_.end()) {
269        Announce(srvMap_[name], true);
270        srvMap_.erase(name);
271        return NETMANAGER_EXT_SUCCESS;
272    }
273    return NET_MDNS_ERR_SERVICE_INSTANCE_NOT_FOUND;
274}
275
276bool MDnsProtocolImpl::DiscoveryFromCache(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
277{
278    NETMGR_EXT_LOG_D("mdns_log DiscoveryFromCache");
279    std::string name = Decorated(serviceType);
280    std::lock_guard<std::recursive_mutex> guard(mutex_);
281    if (!IsBrowserAvailable(name)) {
282        return false;
283    }
284
285    if (browserMap_.find(name) == browserMap_.end()) {
286        NETMGR_EXT_LOG_D("mdns_log DiscoveryFromCache browserMap_ not find name");
287        return false;
288    }
289
290    for (auto &res : browserMap_[name]) {
291        if (res.state == State::REMOVE || res.state == State::DEAD) {
292            continue;
293        }
294        AddTask([cb, info = ConvertResultToInfo(res)]() {
295            NETMGR_EXT_LOG_W("mdns_log DiscoveryFromCache ConvertResultToInfo HandleServiceFound");
296            if (MDnsManager::GetInstance().IsAvailableCallback(cb)) {
297                cb->HandleServiceFound(info, NETMANAGER_EXT_SUCCESS);
298            }
299            return true;
300        });
301    }
302    return true;
303}
304
305bool MDnsProtocolImpl::DiscoveryFromNet(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
306{
307    NETMGR_EXT_LOG_D("mdns_log DiscoveryFromNet");
308    std::string name = Decorated(serviceType);
309    std::lock_guard<std::recursive_mutex> guard(mutex_);
310    browserMap_.insert({name, std::vector<Result>{}});
311    nameCbMap_[name] = cb;
312    // key is serviceTYpe
313    AddEvent(name, [this, name, cb]() {
314        std::lock_guard<std::recursive_mutex> guard(mutex_);
315        if (!IsBrowserAvailable(name)) {
316            return false;
317        }
318        if (!MDnsManager::GetInstance().IsAvailableCallback(cb)) {
319            return true;
320        }
321        for (auto &res : browserMap_[name]) {
322            std::string fullName = Decorated(res.serviceName + MDNS_DOMAIN_SPLITER_STR + res.serviceType);
323            NETMGR_EXT_LOG_W("mdns_log DiscoveryFromNet name:[%{public}s] fullName:[%{public}s]", name.c_str(),
324                             fullName.c_str());
325            if (cacheMap_.find(fullName) == cacheMap_.end() ||
326                (res.state == State::ADD || res.state == State::REFRESH)) {
327                NETMGR_EXT_LOG_W("mdns_log HandleServiceFound");
328                cb->HandleServiceFound(ConvertResultToInfo(res), NETMANAGER_EXT_SUCCESS);
329                res.state = State::LIVE;
330            }
331            if (res.state == State::REMOVE) {
332                res.state = State::DEAD;
333                NETMGR_EXT_LOG_D("mdns_log HandleServiceLost");
334                cb->HandleServiceLost(ConvertResultToInfo(res), NETMANAGER_EXT_SUCCESS);
335                if (cacheMap_.find(fullName) != cacheMap_.end()) {
336                    cacheMap_.erase(fullName);
337                }
338            }
339        }
340        return false;
341    });
342
343    AddTask([=]() {
344            MDnsPayloadParser parser;
345            MDnsMessage msg{};
346            msg.questions.emplace_back(DNSProto::Question{
347                .name = name,
348                .qtype = DNSProto::RRTYPE_PTR,
349                .qclass = DNSProto::RRCLASS_IN,
350            });
351            listener_.MulticastAll(parser.ToBytes(msg));
352            return true;
353        }, false);
354    return true;
355}
356
357int32_t MDnsProtocolImpl::Discovery(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
358{
359    NETMGR_EXT_LOG_D("mdns_log Discovery");
360    DiscoveryFromCache(serviceType, cb);
361    DiscoveryFromNet(serviceType, cb);
362    return NETMANAGER_EXT_SUCCESS;
363}
364
365bool MDnsProtocolImpl::ResolveInstanceFromCache(const std::string &name, const sptr<IResolveCallback> &cb)
366{
367    NETMGR_EXT_LOG_D("mdns_log ResolveInstanceFromCache");
368    std::lock_guard<std::recursive_mutex> guard(mutex_);
369    if (!IsInstanceCacheAvailable(name)) {
370        NETMGR_EXT_LOG_W("mdns_log ResolveInstanceFromCache cacheMap_ has no element [%{public}s]", name.c_str());
371        return false;
372    }
373
374    NETMGR_EXT_LOG_I("mdns_log rr.name : [%{public}s]", name.c_str());
375    Result r = cacheMap_[name];
376    if (IsDomainCacheAvailable(r.domain)) {
377        r.ipv6 = cacheMap_[r.domain].ipv6;
378        r.addr = cacheMap_[r.domain].addr;
379
380        NETMGR_EXT_LOG_D("mdns_log Add Task DomainCache Available, [%{public}s]", r.domain.c_str());
381        AddTask([cb, info = ConvertResultToInfo(r)]() {
382            if (nullptr != cb) {
383                cb->HandleResolveResult(info, NETMANAGER_EXT_SUCCESS);
384            }
385            return true;
386        });
387    } else {
388        ResolveFromNet(r.domain, nullptr);
389        NETMGR_EXT_LOG_D("mdns_log Add Event DomainCache UnAvailable, [%{public}s]", r.domain.c_str());
390        AddEvent(r.domain, [this, cb, r]() mutable {
391            if (!IsDomainCacheAvailable(r.domain)) {
392                return false;
393            }
394            r.ipv6 = cacheMap_[r.domain].ipv6;
395            r.addr = cacheMap_[r.domain].addr;
396            if (nullptr != cb) {
397                cb->HandleResolveResult(ConvertResultToInfo(r), NETMANAGER_EXT_SUCCESS);
398            }
399            return true;
400        });
401    }
402    return true;
403}
404
405bool MDnsProtocolImpl::ResolveInstanceFromNet(const std::string &name, const sptr<IResolveCallback> &cb)
406{
407    NETMGR_EXT_LOG_D("mdns_log ResolveInstanceFromNet");
408    {
409        std::lock_guard<std::recursive_mutex> guard(mutex_);
410        cacheMap_[name].state = State::ADD;
411        ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
412    }
413    MDnsPayloadParser parser;
414    MDnsMessage msg{};
415    msg.questions.emplace_back(DNSProto::Question{
416        .name = name,
417        .qtype = DNSProto::RRTYPE_SRV,
418        .qclass = DNSProto::RRCLASS_IN,
419    });
420    msg.questions.emplace_back(DNSProto::Question{
421        .name = name,
422        .qtype = DNSProto::RRTYPE_TXT,
423        .qclass = DNSProto::RRCLASS_IN,
424    });
425    msg.header.qdcount = msg.questions.size();
426    AddEvent(name, [this, name, cb]() { return ResolveInstanceFromCache(name, cb); });
427    ssize_t size = listener_.MulticastAll(parser.ToBytes(msg));
428    return size > 0;
429}
430
431bool MDnsProtocolImpl::ResolveFromCache(const std::string &domain, const sptr<IResolveCallback> &cb)
432{
433    NETMGR_EXT_LOG_D("mdns_log ResolveFromCache");
434    std::lock_guard<std::recursive_mutex> guard(mutex_);
435    if (!IsDomainCacheAvailable(domain)) {
436        return false;
437    }
438    AddTask([this, cb, info = ConvertResultToInfo(cacheMap_[domain])]() {
439        if (nullptr != cb) {
440            cb->HandleResolveResult(info, NETMANAGER_EXT_SUCCESS);
441        }
442        return true;
443    });
444    return true;
445}
446
447bool MDnsProtocolImpl::ResolveFromNet(const std::string &domain, const sptr<IResolveCallback> &cb)
448{
449    NETMGR_EXT_LOG_D("mdns_log ResolveFromNet");
450    {
451        std::lock_guard<std::recursive_mutex> guard(mutex_);
452        cacheMap_[domain];
453        cacheMap_[domain].domain = domain;
454    }
455    MDnsPayloadParser parser;
456    MDnsMessage msg{};
457    msg.questions.emplace_back(DNSProto::Question{
458        .name = domain,
459        .qtype = DNSProto::RRTYPE_A,
460        .qclass = DNSProto::RRCLASS_IN,
461    });
462    msg.questions.emplace_back(DNSProto::Question{
463        .name = domain,
464        .qtype = DNSProto::RRTYPE_AAAA,
465        .qclass = DNSProto::RRCLASS_IN,
466    });
467    // key is serviceName
468    AddEvent(domain, [this, cb, domain]() { return ResolveFromCache(domain, cb); });
469    ssize_t size = listener_.MulticastAll(parser.ToBytes(msg));
470    return size > 0;
471}
472
473int32_t MDnsProtocolImpl::ResolveInstance(const std::string &instance, const sptr<IResolveCallback> &cb)
474{
475    NETMGR_EXT_LOG_D("mdns_log execute ResolveInstance");
476    if (!IsInstanceValid(instance)) {
477        return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
478    }
479    std::string name = Decorated(instance);
480    if (!IsDomainValid(name)) {
481        return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
482    }
483    if (ResolveInstanceFromCache(name, cb)) {
484        return NETMANAGER_EXT_SUCCESS;
485    }
486    return ResolveInstanceFromNet(name, cb) ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
487}
488
489int32_t MDnsProtocolImpl::Announce(const Result &info, bool off)
490{
491    NETMGR_EXT_LOG_I("mdns_log Announce message");
492    MDnsMessage response{};
493    response.header.flags = DNSProto::MDNS_ANSWER_FLAGS;
494    std::string name = Decorated(info.serviceName + MDNS_DOMAIN_SPLITER_STR + info.serviceType);
495    response.answers.emplace_back(DNSProto::ResourceRecord{.name = Decorated(info.serviceType),
496                                                           .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_PTR),
497                                                           .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
498                                                           .ttl = off ? 0U : DEFAULT_TTL,
499                                                           .rdata = name});
500    response.answers.emplace_back(DNSProto::ResourceRecord{.name = name,
501                                                           .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_SRV),
502                                                           .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
503                                                           .ttl = off ? 0U : DEFAULT_TTL,
504                                                           .rdata = DNSProto::RDataSrv{
505                                                               .priority = 0,
506                                                               .weight = 0,
507                                                               .port = static_cast<uint16_t>(info.port),
508                                                               .name = GetHostDomain(),
509                                                           }});
510    response.answers.emplace_back(DNSProto::ResourceRecord{.name = name,
511                                                           .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_TXT),
512                                                           .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
513                                                           .ttl = off ? 0U : DEFAULT_TTL,
514                                                           .rdata = info.txt});
515    MDnsPayloadParser parser;
516    ssize_t size = listener_.MulticastAll(parser.ToBytes(response));
517    return size > 0 ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
518}
519
520void MDnsProtocolImpl::ReceivePacket(int sock, const MDnsPayload &payload)
521{
522    if (payload.size() == 0) {
523        return;
524    }
525    MDnsPayloadParser parser;
526    MDnsMessage msg = parser.FromBytes(payload);
527    if (parser.GetError() != 0) {
528        NETMGR_EXT_LOG_E("parser payload failed");
529        return;
530    }
531    if ((msg.header.flags & DNSProto::HEADER_FLAGS_QR_MASK) == 0) {
532        ProcessQuestion(sock, msg);
533    } else {
534        ProcessAnswer(sock, msg);
535    }
536}
537
538void MDnsProtocolImpl::AppendRecord(std::vector<DNSProto::ResourceRecord> &rrlist, DNSProto::RRType type,
539                                    const std::string &name, const std::any &rdata)
540{
541    rrlist.emplace_back(DNSProto::ResourceRecord{.name = name,
542                                                 .rtype = static_cast<uint16_t>(type),
543                                                 .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
544                                                 .ttl = DEFAULT_TTL,
545                                                 .rdata = rdata});
546}
547
548void MDnsProtocolImpl::ProcessQuestion(int sock, const MDnsMessage &msg)
549{
550    const sockaddr *saddrIf = listener_.GetSockAddr(sock);
551    if (saddrIf == nullptr) {
552        NETMGR_EXT_LOG_W("mdns_log ProcessQuestion saddrIf is null");
553        return;
554    }
555    std::any anyAddr;
556    DNSProto::RRType anyAddrType;
557    if (saddrIf->sa_family == AF_INET6) {
558        anyAddr = reinterpret_cast<const sockaddr_in6 *>(saddrIf)->sin6_addr;
559        anyAddrType = DNSProto::RRTYPE_AAAA;
560    } else {
561        anyAddr = reinterpret_cast<const sockaddr_in *>(saddrIf)->sin_addr;
562        anyAddrType = DNSProto::RRTYPE_A;
563    }
564    int phase = 0;
565    MDnsMessage response{};
566    response.header.flags = DNSProto::MDNS_ANSWER_FLAGS;
567    for (size_t i = 0; i < msg.header.qdcount; ++i) {
568        ProcessQuestionRecord(anyAddr, anyAddrType, msg.questions[i], phase, response);
569    }
570    if (phase < PHASE_DOMAIN) {
571        AppendRecord(response.additional, anyAddrType, GetHostDomain(), anyAddr);
572    }
573
574    if (phase != 0 && response.answers.size() > 0) {
575        listener_.Multicast(sock, MDnsPayloadParser().ToBytes(response));
576    }
577}
578
579void MDnsProtocolImpl::ProcessQuestionRecord(const std::any &anyAddr, const DNSProto::RRType &anyAddrType,
580                                             const DNSProto::Question &qu, int &phase, MDnsMessage &response)
581{
582    NETMGR_EXT_LOG_D("mdns_log ProcessQuestionRecord");
583    std::lock_guard<std::recursive_mutex> guard(mutex_);
584    std::string name = qu.name;
585    if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_PTR) {
586        std::for_each(srvMap_.begin(), srvMap_.end(), [&](const auto &elem) -> void {
587            if (EndsWith(elem.first, name)) {
588                AppendRecord(response.answers, DNSProto::RRTYPE_PTR, name, elem.first);
589                AppendRecord(response.additional, DNSProto::RRTYPE_SRV, elem.first,
590                             DNSProto::RDataSrv{
591                                 .priority = 0,
592                                 .weight = 0,
593                                 .port = static_cast<uint16_t>(elem.second.port),
594                                 .name = GetHostDomain(),
595                             });
596                AppendRecord(response.additional, DNSProto::RRTYPE_TXT, elem.first, elem.second.txt);
597            }
598        });
599        phase = std::max(phase, PHASE_PTR);
600    }
601    if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_SRV) {
602        auto iter = srvMap_.find(name);
603        if (iter == srvMap_.end()) {
604            return;
605        }
606        AppendRecord(response.answers, DNSProto::RRTYPE_SRV, name,
607                     DNSProto::RDataSrv{
608                         .priority = 0,
609                         .weight = 0,
610                         .port = static_cast<uint16_t>(iter->second.port),
611                         .name = GetHostDomain(),
612                     });
613        phase = std::max(phase, PHASE_SRV);
614    }
615    if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_TXT) {
616        auto iter = srvMap_.find(name);
617        if (iter == srvMap_.end()) {
618            return;
619        }
620        AppendRecord(response.answers, DNSProto::RRTYPE_TXT, name, iter->second.txt);
621        phase = std::max(phase, PHASE_SRV);
622    }
623    if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_A || qu.qtype == DNSProto::RRTYPE_AAAA) {
624        if (name != GetHostDomain() || (qu.qtype != DNSProto::RRTYPE_ANY && anyAddrType != qu.qtype)) {
625            return;
626        }
627        AppendRecord(response.answers, anyAddrType, name, anyAddr);
628        phase = std::max(phase, PHASE_DOMAIN);
629    }
630}
631
632void MDnsProtocolImpl::ProcessAnswer(int sock, const MDnsMessage &msg)
633{
634    const sockaddr *saddrIf = listener_.GetSockAddr(sock);
635    if (saddrIf == nullptr) {
636        return;
637    }
638    bool v6 = (saddrIf->sa_family == AF_INET6);
639    std::set<std::string> changed;
640    for (const auto &answer : msg.answers) {
641        ProcessAnswerRecord(v6, answer, changed);
642    }
643    for (const auto &i : msg.additional) {
644        ProcessAnswerRecord(v6, i, changed);
645    }
646    for (const auto &i : changed) {
647        std::lock_guard<std::recursive_mutex> guard(mutex_);
648        RunTaskQueue(taskOnChange_[i]);
649        KillCache(i);
650    }
651}
652
653void MDnsProtocolImpl::UpdatePtr(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
654{
655    const std::string *data = std::any_cast<std::string>(&rr.rdata);
656    if (data == nullptr) {
657        return;
658    }
659
660    std::string name = rr.name;
661    if (browserMap_.find(name) == browserMap_.end()) {
662        return;
663    }
664    auto &results = browserMap_[name];
665    std::string srvName;
666    std::string srvType;
667    ExtractNameAndType(*data, srvName, srvType);
668    if (srvName.empty() || srvType.empty()) {
669        return;
670    }
671    auto res =
672        std::find_if(results.begin(), results.end(), [&](const auto &elem) { return elem.serviceName == srvName; });
673    if (res == results.end()) {
674        results.emplace_back(Result{
675            .serviceName = srvName,
676            .serviceType = srvType,
677            .state = State::ADD,
678        });
679    }
680    res = std::find_if(results.begin(), results.end(), [&](const auto &elem) { return elem.serviceName == srvName; });
681    if (res->serviceName != srvName || res->state == State::DEAD) {
682        res->state = State::REFRESH;
683        res->serviceName = srvName;
684    }
685    if (rr.ttl == 0) {
686        res->state = State::REMOVE;
687    }
688    if (res->state != State::LIVE && res->state != State::DEAD) {
689        changed.emplace(name);
690    }
691    res->ttl = rr.ttl;
692    res->refrehTime = MilliSecondsSinceEpoch();
693}
694
695void MDnsProtocolImpl::UpdateSrv(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
696{
697    const DNSProto::RDataSrv *srv = std::any_cast<DNSProto::RDataSrv>(&rr.rdata);
698    if (srv == nullptr) {
699        return;
700    }
701    std::string name = rr.name;
702    if (cacheMap_.find(name) == cacheMap_.end()) {
703        ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
704        cacheMap_[name].state = State::ADD;
705        cacheMap_[name].domain = srv->name;
706        cacheMap_[name].port = srv->port;
707    }
708    Result &result = cacheMap_[name];
709    if (result.domain != srv->name || result.port != srv->port || result.state == State::DEAD) {
710        if (result.state != State::ADD) {
711            result.state = State::REFRESH;
712        }
713        result.domain = srv->name;
714        result.port = srv->port;
715    }
716    if (rr.ttl == 0) {
717        result.state = State::REMOVE;
718    }
719    if (result.state != State::LIVE && result.state != State::DEAD) {
720        changed.emplace(name);
721    }
722    result.ttl = rr.ttl;
723    result.refrehTime = MilliSecondsSinceEpoch();
724}
725
726void MDnsProtocolImpl::UpdateTxt(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
727{
728    const TxtRecordEncoded *txt = std::any_cast<TxtRecordEncoded>(&rr.rdata);
729    if (txt == nullptr) {
730        return;
731    }
732    std::string name = rr.name;
733    if (cacheMap_.find(name) == cacheMap_.end()) {
734        ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
735        cacheMap_[name].state = State::ADD;
736        cacheMap_[name].txt = *txt;
737    }
738    Result &result = cacheMap_[name];
739    if (result.txt != *txt || result.state == State::DEAD) {
740        if (result.state != State::ADD) {
741            result.state = State::REFRESH;
742        }
743        result.txt = *txt;
744    }
745    if (rr.ttl == 0) {
746        result.state = State::REMOVE;
747    }
748    if (result.state != State::LIVE && result.state != State::DEAD) {
749        changed.emplace(name);
750    }
751    result.ttl = rr.ttl;
752    result.refrehTime = MilliSecondsSinceEpoch();
753}
754
755void MDnsProtocolImpl::UpdateAddr(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
756{
757    if (v6 != (rr.rtype == DNSProto::RRTYPE_AAAA)) {
758        return;
759    }
760    const std::string addr = AddrToString(rr.rdata);
761    bool v6rr = (rr.rtype == DNSProto::RRTYPE_AAAA);
762    if (addr.empty()) {
763        return;
764    }
765    std::string name = rr.name;
766    if (cacheMap_.find(name) == cacheMap_.end()) {
767        ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
768        cacheMap_[name].state = State::ADD;
769        cacheMap_[name].ipv6 = v6rr;
770        cacheMap_[name].addr = addr;
771    }
772    Result &result = cacheMap_[name];
773    if (result.addr != addr || result.ipv6 != v6rr || result.state == State::DEAD) {
774        result.state = State::REFRESH;
775        result.addr = addr;
776        result.ipv6 = v6rr;
777    }
778    if (rr.ttl == 0) {
779        result.state = State::REMOVE;
780    }
781    if (result.state != State::LIVE && result.state != State::DEAD) {
782        changed.emplace(name);
783    }
784    result.ttl = rr.ttl;
785    result.refrehTime = MilliSecondsSinceEpoch();
786}
787
788void MDnsProtocolImpl::ProcessAnswerRecord(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
789{
790    NETMGR_EXT_LOG_D("mdns_log ProcessAnswerRecord, type=[%{public}d]", rr.rtype);
791    std::lock_guard<std::recursive_mutex> guard(mutex_);
792    std::string name = rr.name;
793    if (cacheMap_.find(name) == cacheMap_.end() && browserMap_.find(name) == browserMap_.end() &&
794        srvMap_.find(name) != srvMap_.end()) {
795        return;
796    }
797    if (rr.rtype == DNSProto::RRTYPE_PTR) {
798        UpdatePtr(v6, rr, changed);
799    } else if (rr.rtype == DNSProto::RRTYPE_SRV) {
800        UpdateSrv(v6, rr, changed);
801    } else if (rr.rtype == DNSProto::RRTYPE_TXT) {
802        UpdateTxt(v6, rr, changed);
803    } else if (rr.rtype == DNSProto::RRTYPE_A || rr.rtype == DNSProto::RRTYPE_AAAA) {
804        UpdateAddr(v6, rr, changed);
805    } else {
806        NETMGR_EXT_LOG_D("mdns_log Unknown packet received, type=[%{public}d]", rr.rtype);
807    }
808}
809
810std::string MDnsProtocolImpl::GetHostDomain()
811{
812    if (config_.hostname.empty()) {
813        char buffer[MDNS_MAX_DOMAIN_LABEL];
814        if (gethostname(buffer, sizeof(buffer)) == 0) {
815            config_.hostname = buffer;
816            static auto uid = []() {
817                std::random_device rd;
818                return rd();
819            }();
820            config_.hostname += std::to_string(uid);
821        }
822    }
823    return Decorated(config_.hostname);
824}
825
826void MDnsProtocolImpl::AddTask(const Task &task, bool atonce)
827{
828    {
829        std::lock_guard<std::recursive_mutex> guard(mutex_);
830        taskQueue_.emplace_back(task);
831    }
832    if (atonce) {
833        listener_.TriggerRefresh();
834    }
835}
836
837MDnsServiceInfo MDnsProtocolImpl::ConvertResultToInfo(const MDnsProtocolImpl::Result &result)
838{
839    MDnsServiceInfo info;
840    info.name = result.serviceName;
841    info.type = result.serviceType;
842    if (!result.addr.empty()) {
843        info.family = result.ipv6 ? MDnsServiceInfo::IPV6 : MDnsServiceInfo::IPV4;
844    }
845    info.addr = result.addr;
846    info.port = result.port;
847    info.txtRecord = result.txt;
848    return info;
849}
850
851bool MDnsProtocolImpl::IsCacheAvailable(const std::string &key)
852{
853    constexpr int64_t ms2S = 1000LL;
854    NETMGR_EXT_LOG_D("mdns_log IsCacheAvailable, ttl=[%{public}u]", cacheMap_[key].ttl);
855    return cacheMap_.find(key) != cacheMap_.end() &&
856           (ms2S * cacheMap_[key].ttl) > static_cast<uint32_t>(MilliSecondsSinceEpoch() - cacheMap_[key].refrehTime);
857}
858
859bool MDnsProtocolImpl::IsDomainCacheAvailable(const std::string &key)
860{
861    return IsCacheAvailable(key) && !cacheMap_[key].addr.empty();
862}
863
864bool MDnsProtocolImpl::IsInstanceCacheAvailable(const std::string &key)
865{
866    return IsCacheAvailable(key) && !cacheMap_[key].domain.empty();
867}
868
869bool MDnsProtocolImpl::IsBrowserAvailable(const std::string &key)
870{
871    return browserMap_.find(key) != browserMap_.end() && !browserMap_[key].empty();
872}
873
874void MDnsProtocolImpl::AddEvent(const std::string &key, const Task &task)
875{
876    std::lock_guard<std::recursive_mutex> guard(mutex_);
877    taskOnChange_[key].emplace_back(task);
878}
879
880void MDnsProtocolImpl::RunTaskQueue(std::list<Task> &queue)
881{
882    std::list<Task> tmp;
883    for (auto &&func : queue) {
884        if (!func()) {
885            tmp.emplace_back(func);
886        }
887    }
888    tmp.swap(queue);
889}
890
891void MDnsProtocolImpl::KillCache(const std::string &key)
892{
893    NETMGR_EXT_LOG_D("mdns_log KillCache");
894    if (IsBrowserAvailable(key) && browserMap_.find(key) != browserMap_.end()) {
895        for (auto it = browserMap_[key].begin(); it != browserMap_[key].end();) {
896            KillBrowseCache(key, it);
897        }
898    }
899    if (IsCacheAvailable(key)) {
900        std::lock_guard<std::recursive_mutex> guard(mutex_);
901        auto &elem = cacheMap_[key];
902        if (elem.state == State::REMOVE) {
903            elem.state = State::DEAD;
904            cacheMap_.erase(key);
905        } else if (elem.state == State::ADD || elem.state == State::REFRESH) {
906            elem.state = State::LIVE;
907        }
908    }
909}
910
911void MDnsProtocolImpl::KillBrowseCache(const std::string &key, std::vector<Result>::iterator &it)
912{
913    NETMGR_EXT_LOG_D("mdns_log KillBrowseCache");
914    if (it->state == State::REMOVE) {
915        it->state = State::DEAD;
916        if (nameCbMap_.find(key) != nameCbMap_.end()) {
917            NETMGR_EXT_LOG_D("mdns_log HandleServiceLost");
918            nameCbMap_[key]->HandleServiceLost(ConvertResultToInfo(*it), NETMANAGER_EXT_SUCCESS);
919        }
920        std::string fullName = Decorated(it->serviceName + MDNS_DOMAIN_SPLITER_STR + it->serviceType);
921        cacheMap_.erase(fullName);
922        it = browserMap_[key].erase(it);
923    } else if (it->state == State::ADD || it->state == State::REFRESH) {
924        it->state = State::LIVE;
925        it++;
926    } else {
927        it++;
928    }
929}
930
931int32_t MDnsProtocolImpl::StopCbMap(const std::string &serviceType)
932{
933    NETMGR_EXT_LOG_D("mdns_log StopCbMap");
934    std::lock_guard<std::recursive_mutex> guard(mutex_);
935    std::string name = Decorated(serviceType);
936    sptr<IDiscoveryCallback> cb = nullptr;
937    if (nameCbMap_.find(name) != nameCbMap_.end()) {
938        cb = nameCbMap_[name];
939        nameCbMap_.erase(name);
940    }
941    taskOnChange_.erase(name);
942    auto it = browserMap_.find(name);
943    if (it != browserMap_.end()) {
944        if (cb != nullptr) {
945            NETMGR_EXT_LOG_I("mdns_log StopCbMap res size:[%{public}zu]", it->second.size());
946            for (auto &&res : it->second) {
947                NETMGR_EXT_LOG_W("mdns_log HandleServiceLost");
948                cb->HandleServiceLost(ConvertResultToInfo(res), NETMANAGER_EXT_SUCCESS);
949            }
950        }
951        browserMap_.erase(name);
952    }
953    return NETMANAGER_SUCCESS;
954}
955} // namespace NetManagerStandard
956} // namespace OHOS
957