1 /*
2 * Copyright (C) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "mdns_protocol_impl.h"
17
18 #include <arpa/inet.h>
19 #include <cstddef>
20 #include <iostream>
21 #include <random>
22 #include <sys/types.h>
23 #include <unistd.h>
24 #include <fcntl.h>
25
26 #include "mdns_manager.h"
27 #include "mdns_packet_parser.h"
28 #include "net_conn_client.h"
29 #include "netmgr_ext_log_wrapper.h"
30
31 #include "securec.h"
32
33 namespace OHOS {
34 namespace NetManagerStandard {
35
// Minimum spacing between two periodic Browse() passes, in milliseconds.
constexpr uint32_t DEFAULT_INTEVAL_MS{2000};
// A LIVE browse result not refreshed for longer than this (milliseconds) is checked for loss.
constexpr uint32_t DEFAULT_LOST_MS{10000};
// TTL placed in outgoing records; the cache multiplies TTLs by 1000 to compare against ms, i.e. seconds.
constexpr uint32_t DEFAULT_TTL{120};
// Bit OR-ed into the rclass of every record this module sends.
constexpr uint16_t MDNS_FLUSH_CACHE_BIT{0x8000};

// Progress markers used while answering a query: PTR < SRV/TXT < host address.
constexpr int PHASE_PTR{1};
constexpr int PHASE_SRV{2};
constexpr int PHASE_DOMAIN{3};
44
/**
 * Formats an address held in a std::any as text.
 *
 * @param addr Either an in_addr (IPv4) or an in6_addr (IPv6).
 * @return The presentation form of the address, or an empty string when the
 *         any holds neither type or inet_ntop() fails.
 */
std::string AddrToString(const std::any &addr)
{
    char text[INET6_ADDRSTRLEN] = {0};
    int family = AF_UNSPEC;
    const void *raw = nullptr;

    if (const auto *v4 = std::any_cast<in_addr>(&addr); v4 != nullptr) {
        family = AF_INET;
        raw = v4;
    } else if (const auto *v6 = std::any_cast<in6_addr>(&addr); v6 != nullptr) {
        family = AF_INET6;
        raw = v6;
    }

    if (raw != nullptr && inet_ntop(family, raw, text, sizeof(text)) == nullptr) {
        return std::string{};
    }
    // Unknown payload types fall through with the zeroed buffer, yielding "".
    return std::string(text);
}
59
/**
 * @return Wall-clock time (std::chrono::system_clock) as milliseconds since the epoch.
 */
int64_t MilliSecondsSinceEpoch()
{
    using namespace std::chrono;
    const auto sinceEpoch = system_clock::now().time_since_epoch();
    return duration_cast<milliseconds>(sinceEpoch).count();
}
65
MDnsProtocolImpl()66 MDnsProtocolImpl::MDnsProtocolImpl()
67 {
68 Init();
69 }
70
Init()71 void MDnsProtocolImpl::Init()
72 {
73 NETMGR_EXT_LOG_D("mdns_log MDnsProtocolImpl init");
74 listener_.Stop();
75 listener_.CloseAllSocket();
76
77 if (config_.configAllIface) {
78 listener_.OpenSocketForEachIface(config_.ipv6Support, config_.configLo);
79 } else {
80 listener_.OpenSocketForDefault(config_.ipv6Support);
81 }
82 listener_.SetReceiveHandler(
83 [this](int sock, const MDnsPayload &payload) { return this->ReceivePacket(sock, payload); });
84 listener_.SetFinishedHandler([this](int sock) {
85 std::lock_guard<std::recursive_mutex> guard(mutex_);
86 RunTaskQueue(taskQueue_);
87 });
88 listener_.Start();
89
90 taskQueue_.clear();
91 taskOnChange_.clear();
92 AddTask([this]() { return Browse(); }, false);
93 }
94
Browse()95 bool MDnsProtocolImpl::Browse()
96 {
97 if (lastRunTime != -1 && MilliSecondsSinceEpoch() - lastRunTime < DEFAULT_INTEVAL_MS) {
98 return false;
99 }
100 lastRunTime = MilliSecondsSinceEpoch();
101 std::lock_guard<std::recursive_mutex> guard(mutex_);
102 for (auto &&[key, res] : browserMap_) {
103 NETMGR_EXT_LOG_D("mdns_log Browse browserMap_ key[%{public}s] res.size[%{public}zu]", key.c_str(), res.size());
104 if (nameCbMap_.find(key) != nameCbMap_.end() &&
105 !MDnsManager::GetInstance().IsAvailableCallback(nameCbMap_[key])) {
106 continue;
107 }
108 handleOfflineService(key, res);
109 MDnsPayloadParser parser;
110 MDnsMessage msg{};
111 msg.questions.emplace_back(DNSProto::Question{
112 .name = key,
113 .qtype = DNSProto::RRTYPE_PTR,
114 .qclass = DNSProto::RRCLASS_IN,
115 });
116 listener_.MulticastAll(parser.ToBytes(msg));
117 }
118 return false;
119 }
120
ConnectControl(int32_t sockfd, sockaddr* serverAddr)121 int32_t MDnsProtocolImpl::ConnectControl(int32_t sockfd, sockaddr* serverAddr)
122 {
123 uint32_t flags = static_cast<uint32_t>(fcntl(sockfd, F_GETFL, 0));
124 fcntl(sockfd, F_SETFL, flags | O_NONBLOCK);
125 int32_t ret = connect(sockfd, serverAddr, sizeof(sockaddr));
126 if ((ret < 0) && (errno != EINPROGRESS)) {
127 NETMGR_EXT_LOG_E("connect error: %{public}d", errno);
128 return NETMANAGER_EXT_ERR_INTERNAL;
129 }
130 if (ret == 0) {
131 fcntl(sockfd, F_SETFL, flags); /* restore file status flags */
132 NETMGR_EXT_LOG_I("connect success.");
133 return NETMANAGER_EXT_SUCCESS;
134 }
135
136 fd_set rset;
137 FD_ZERO(&rset);
138 FD_SET(sockfd, &rset);
139 fd_set wset = rset;
140 timeval tval {1, 0};
141 ret = select(sockfd + 1, &rset, &wset, NULL, &tval);
142 if (ret < 0) { // select error.
143 NETMGR_EXT_LOG_E("select error: %{public}d", errno);
144 return NETMANAGER_EXT_ERR_INTERNAL;
145 }
146 if (ret == 0) { // timeout
147 NETMGR_EXT_LOG_E("connect timeout...");
148 return NETMANAGER_EXT_ERR_INTERNAL;
149 }
150 if (!FD_ISSET(sockfd, &rset) && !FD_ISSET(sockfd, &wset)) {
151 NETMGR_EXT_LOG_E("select error: sockfd not set");
152 return NETMANAGER_EXT_ERR_INTERNAL;
153 }
154
155 int32_t result = NETMANAGER_EXT_ERR_INTERNAL;
156 socklen_t len = sizeof(result);
157 if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &result, &len) < 0) {
158 NETMGR_EXT_LOG_E("getsockopt error: %{public}d", errno);
159 return NETMANAGER_EXT_ERR_INTERNAL;
160 }
161 if (result != 0) { // connect failed.
162 NETMGR_EXT_LOG_E("connect failed. error: %{public}d", result);
163 return NETMANAGER_EXT_ERR_INTERNAL;
164 }
165 fcntl(sockfd, F_SETFL, flags); /* restore file status flags */
166 NETMGR_EXT_LOG_I("lost but connect success.");
167 return NETMANAGER_EXT_SUCCESS;
168 }
169
IsConnectivity(const std::string &ip, int32_t port)170 bool MDnsProtocolImpl::IsConnectivity(const std::string &ip, int32_t port)
171 {
172 if (ip.empty()) {
173 NETMGR_EXT_LOG_E("ip is empty");
174 return false;
175 }
176
177 int32_t sockfd = socket(AF_INET, SOCK_STREAM, 0);
178 if (sockfd < 0) {
179 NETMGR_EXT_LOG_E("create socket error: %{public}d", errno);
180 return false;
181 }
182
183 struct sockaddr_in serverAddr;
184 if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) {
185 NETMGR_EXT_LOG_E("memset_s serverAddr failed!");
186 close(sockfd);
187 return false;
188 }
189
190 serverAddr.sin_family = AF_INET;
191 serverAddr.sin_addr.s_addr = inet_addr(ip.c_str());
192 serverAddr.sin_port = htons(port);
193 if (ConnectControl(sockfd, (struct sockaddr*)&serverAddr) != NETMANAGER_EXT_SUCCESS) {
194 NETMGR_EXT_LOG_I("connect error: %{public}d", errno);
195 close(sockfd);
196 return false;
197 }
198
199 close(sockfd);
200 return true;
201 }
202
handleOfflineService(const std::string &key, std::vector<Result> &res)203 void MDnsProtocolImpl::handleOfflineService(const std::string &key, std::vector<Result> &res)
204 {
205 NETMGR_EXT_LOG_D("mdns_log handleOfflineService key:[%{public}s]", key.c_str());
206 for (auto it = res.begin(); it != res.end();) {
207 if (lastRunTime - it->refrehTime > DEFAULT_LOST_MS && it->state == State::LIVE) {
208 std::string fullName = Decorated(it->serviceName + MDNS_DOMAIN_SPLITER_STR + it->serviceType);
209 if ((cacheMap_.find(fullName) != cacheMap_.end()) &&
210 IsConnectivity(cacheMap_[fullName].addr, cacheMap_[fullName].port)) {
211 it++;
212 continue;
213 }
214
215 it->state = State::DEAD;
216 if (nameCbMap_.find(key) != nameCbMap_.end() && nameCbMap_[key] != nullptr) {
217 NETMGR_EXT_LOG_W("mdns_log HandleServiceLost");
218 nameCbMap_[key]->HandleServiceLost(ConvertResultToInfo(*it), NETMANAGER_EXT_SUCCESS);
219 }
220 it = res.erase(it);
221 cacheMap_.erase(fullName);
222 } else {
223 it++;
224 }
225 }
226 }
227
SetConfig(const MDnsConfig &config)228 void MDnsProtocolImpl::SetConfig(const MDnsConfig &config)
229 {
230 config_ = config;
231 }
232
GetConfig() const233 const MDnsConfig &MDnsProtocolImpl::GetConfig() const
234 {
235 return config_;
236 }
237
Decorated(const std::string &name) const238 std::string MDnsProtocolImpl::Decorated(const std::string &name) const
239 {
240 return name + config_.topDomain;
241 }
242
Register(const Result &info)243 int32_t MDnsProtocolImpl::Register(const Result &info)
244 {
245 NETMGR_EXT_LOG_D("mdns_log Register");
246 if (!(IsNameValid(info.serviceName) && IsTypeValid(info.serviceType) && IsPortValid(info.port))) {
247 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
248 }
249 std::string name = Decorated(info.serviceName + MDNS_DOMAIN_SPLITER_STR + info.serviceType);
250 if (!IsDomainValid(name)) {
251 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
252 }
253 {
254 std::lock_guard<std::recursive_mutex> guard(mutex_);
255 if (srvMap_.find(name) != srvMap_.end()) {
256 return NET_MDNS_ERR_SERVICE_INSTANCE_DUPLICATE;
257 }
258 srvMap_.emplace(name, info);
259 }
260 return Announce(info, false);
261 }
262
UnRegister(const std::string &key)263 int32_t MDnsProtocolImpl::UnRegister(const std::string &key)
264 {
265 NETMGR_EXT_LOG_D("mdns_log UnRegister");
266 std::string name = Decorated(key);
267 std::lock_guard<std::recursive_mutex> guard(mutex_);
268 if (srvMap_.find(name) != srvMap_.end()) {
269 Announce(srvMap_[name], true);
270 srvMap_.erase(name);
271 return NETMANAGER_EXT_SUCCESS;
272 }
273 return NET_MDNS_ERR_SERVICE_INSTANCE_NOT_FOUND;
274 }
275
DiscoveryFromCache(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)276 bool MDnsProtocolImpl::DiscoveryFromCache(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
277 {
278 NETMGR_EXT_LOG_D("mdns_log DiscoveryFromCache");
279 std::string name = Decorated(serviceType);
280 std::lock_guard<std::recursive_mutex> guard(mutex_);
281 if (!IsBrowserAvailable(name)) {
282 return false;
283 }
284
285 if (browserMap_.find(name) == browserMap_.end()) {
286 NETMGR_EXT_LOG_D("mdns_log DiscoveryFromCache browserMap_ not find name");
287 return false;
288 }
289
290 for (auto &res : browserMap_[name]) {
291 if (res.state == State::REMOVE || res.state == State::DEAD) {
292 continue;
293 }
294 AddTask([cb, info = ConvertResultToInfo(res)]() {
295 NETMGR_EXT_LOG_W("mdns_log DiscoveryFromCache ConvertResultToInfo HandleServiceFound");
296 if (MDnsManager::GetInstance().IsAvailableCallback(cb)) {
297 cb->HandleServiceFound(info, NETMANAGER_EXT_SUCCESS);
298 }
299 return true;
300 });
301 }
302 return true;
303 }
304
DiscoveryFromNet(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)305 bool MDnsProtocolImpl::DiscoveryFromNet(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
306 {
307 NETMGR_EXT_LOG_D("mdns_log DiscoveryFromNet");
308 std::string name = Decorated(serviceType);
309 std::lock_guard<std::recursive_mutex> guard(mutex_);
310 browserMap_.insert({name, std::vector<Result>{}});
311 nameCbMap_[name] = cb;
312 // key is serviceTYpe
313 AddEvent(name, [this, name, cb]() {
314 std::lock_guard<std::recursive_mutex> guard(mutex_);
315 if (!IsBrowserAvailable(name)) {
316 return false;
317 }
318 if (!MDnsManager::GetInstance().IsAvailableCallback(cb)) {
319 return true;
320 }
321 for (auto &res : browserMap_[name]) {
322 std::string fullName = Decorated(res.serviceName + MDNS_DOMAIN_SPLITER_STR + res.serviceType);
323 NETMGR_EXT_LOG_W("mdns_log DiscoveryFromNet name:[%{public}s] fullName:[%{public}s]", name.c_str(),
324 fullName.c_str());
325 if (cacheMap_.find(fullName) == cacheMap_.end() ||
326 (res.state == State::ADD || res.state == State::REFRESH)) {
327 NETMGR_EXT_LOG_W("mdns_log HandleServiceFound");
328 cb->HandleServiceFound(ConvertResultToInfo(res), NETMANAGER_EXT_SUCCESS);
329 res.state = State::LIVE;
330 }
331 if (res.state == State::REMOVE) {
332 res.state = State::DEAD;
333 NETMGR_EXT_LOG_D("mdns_log HandleServiceLost");
334 cb->HandleServiceLost(ConvertResultToInfo(res), NETMANAGER_EXT_SUCCESS);
335 if (cacheMap_.find(fullName) != cacheMap_.end()) {
336 cacheMap_.erase(fullName);
337 }
338 }
339 }
340 return false;
341 });
342
343 AddTask([=]() {
344 MDnsPayloadParser parser;
345 MDnsMessage msg{};
346 msg.questions.emplace_back(DNSProto::Question{
347 .name = name,
348 .qtype = DNSProto::RRTYPE_PTR,
349 .qclass = DNSProto::RRCLASS_IN,
350 });
351 listener_.MulticastAll(parser.ToBytes(msg));
352 return true;
353 }, false);
354 return true;
355 }
356
Discovery(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)357 int32_t MDnsProtocolImpl::Discovery(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
358 {
359 NETMGR_EXT_LOG_D("mdns_log Discovery");
360 DiscoveryFromCache(serviceType, cb);
361 DiscoveryFromNet(serviceType, cb);
362 return NETMANAGER_EXT_SUCCESS;
363 }
364
ResolveInstanceFromCache(const std::string &name, const sptr<IResolveCallback> &cb)365 bool MDnsProtocolImpl::ResolveInstanceFromCache(const std::string &name, const sptr<IResolveCallback> &cb)
366 {
367 NETMGR_EXT_LOG_D("mdns_log ResolveInstanceFromCache");
368 std::lock_guard<std::recursive_mutex> guard(mutex_);
369 if (!IsInstanceCacheAvailable(name)) {
370 NETMGR_EXT_LOG_W("mdns_log ResolveInstanceFromCache cacheMap_ has no element [%{public}s]", name.c_str());
371 return false;
372 }
373
374 NETMGR_EXT_LOG_I("mdns_log rr.name : [%{public}s]", name.c_str());
375 Result r = cacheMap_[name];
376 if (IsDomainCacheAvailable(r.domain)) {
377 r.ipv6 = cacheMap_[r.domain].ipv6;
378 r.addr = cacheMap_[r.domain].addr;
379
380 NETMGR_EXT_LOG_D("mdns_log Add Task DomainCache Available, [%{public}s]", r.domain.c_str());
381 AddTask([cb, info = ConvertResultToInfo(r)]() {
382 if (nullptr != cb) {
383 cb->HandleResolveResult(info, NETMANAGER_EXT_SUCCESS);
384 }
385 return true;
386 });
387 } else {
388 ResolveFromNet(r.domain, nullptr);
389 NETMGR_EXT_LOG_D("mdns_log Add Event DomainCache UnAvailable, [%{public}s]", r.domain.c_str());
390 AddEvent(r.domain, [this, cb, r]() mutable {
391 if (!IsDomainCacheAvailable(r.domain)) {
392 return false;
393 }
394 r.ipv6 = cacheMap_[r.domain].ipv6;
395 r.addr = cacheMap_[r.domain].addr;
396 if (nullptr != cb) {
397 cb->HandleResolveResult(ConvertResultToInfo(r), NETMANAGER_EXT_SUCCESS);
398 }
399 return true;
400 });
401 }
402 return true;
403 }
404
ResolveInstanceFromNet(const std::string &name, const sptr<IResolveCallback> &cb)405 bool MDnsProtocolImpl::ResolveInstanceFromNet(const std::string &name, const sptr<IResolveCallback> &cb)
406 {
407 NETMGR_EXT_LOG_D("mdns_log ResolveInstanceFromNet");
408 {
409 std::lock_guard<std::recursive_mutex> guard(mutex_);
410 cacheMap_[name].state = State::ADD;
411 ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
412 }
413 MDnsPayloadParser parser;
414 MDnsMessage msg{};
415 msg.questions.emplace_back(DNSProto::Question{
416 .name = name,
417 .qtype = DNSProto::RRTYPE_SRV,
418 .qclass = DNSProto::RRCLASS_IN,
419 });
420 msg.questions.emplace_back(DNSProto::Question{
421 .name = name,
422 .qtype = DNSProto::RRTYPE_TXT,
423 .qclass = DNSProto::RRCLASS_IN,
424 });
425 msg.header.qdcount = msg.questions.size();
426 AddEvent(name, [this, name, cb]() { return ResolveInstanceFromCache(name, cb); });
427 ssize_t size = listener_.MulticastAll(parser.ToBytes(msg));
428 return size > 0;
429 }
430
ResolveFromCache(const std::string &domain, const sptr<IResolveCallback> &cb)431 bool MDnsProtocolImpl::ResolveFromCache(const std::string &domain, const sptr<IResolveCallback> &cb)
432 {
433 NETMGR_EXT_LOG_D("mdns_log ResolveFromCache");
434 std::lock_guard<std::recursive_mutex> guard(mutex_);
435 if (!IsDomainCacheAvailable(domain)) {
436 return false;
437 }
438 AddTask([this, cb, info = ConvertResultToInfo(cacheMap_[domain])]() {
439 if (nullptr != cb) {
440 cb->HandleResolveResult(info, NETMANAGER_EXT_SUCCESS);
441 }
442 return true;
443 });
444 return true;
445 }
446
ResolveFromNet(const std::string &domain, const sptr<IResolveCallback> &cb)447 bool MDnsProtocolImpl::ResolveFromNet(const std::string &domain, const sptr<IResolveCallback> &cb)
448 {
449 NETMGR_EXT_LOG_D("mdns_log ResolveFromNet");
450 {
451 std::lock_guard<std::recursive_mutex> guard(mutex_);
452 cacheMap_[domain];
453 cacheMap_[domain].domain = domain;
454 }
455 MDnsPayloadParser parser;
456 MDnsMessage msg{};
457 msg.questions.emplace_back(DNSProto::Question{
458 .name = domain,
459 .qtype = DNSProto::RRTYPE_A,
460 .qclass = DNSProto::RRCLASS_IN,
461 });
462 msg.questions.emplace_back(DNSProto::Question{
463 .name = domain,
464 .qtype = DNSProto::RRTYPE_AAAA,
465 .qclass = DNSProto::RRCLASS_IN,
466 });
467 // key is serviceName
468 AddEvent(domain, [this, cb, domain]() { return ResolveFromCache(domain, cb); });
469 ssize_t size = listener_.MulticastAll(parser.ToBytes(msg));
470 return size > 0;
471 }
472
ResolveInstance(const std::string &instance, const sptr<IResolveCallback> &cb)473 int32_t MDnsProtocolImpl::ResolveInstance(const std::string &instance, const sptr<IResolveCallback> &cb)
474 {
475 NETMGR_EXT_LOG_D("mdns_log execute ResolveInstance");
476 if (!IsInstanceValid(instance)) {
477 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
478 }
479 std::string name = Decorated(instance);
480 if (!IsDomainValid(name)) {
481 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
482 }
483 if (ResolveInstanceFromCache(name, cb)) {
484 return NETMANAGER_EXT_SUCCESS;
485 }
486 return ResolveInstanceFromNet(name, cb) ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
487 }
488
Announce(const Result &info, bool off)489 int32_t MDnsProtocolImpl::Announce(const Result &info, bool off)
490 {
491 NETMGR_EXT_LOG_I("mdns_log Announce message");
492 MDnsMessage response{};
493 response.header.flags = DNSProto::MDNS_ANSWER_FLAGS;
494 std::string name = Decorated(info.serviceName + MDNS_DOMAIN_SPLITER_STR + info.serviceType);
495 response.answers.emplace_back(DNSProto::ResourceRecord{.name = Decorated(info.serviceType),
496 .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_PTR),
497 .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
498 .ttl = off ? 0U : DEFAULT_TTL,
499 .rdata = name});
500 response.answers.emplace_back(DNSProto::ResourceRecord{.name = name,
501 .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_SRV),
502 .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
503 .ttl = off ? 0U : DEFAULT_TTL,
504 .rdata = DNSProto::RDataSrv{
505 .priority = 0,
506 .weight = 0,
507 .port = static_cast<uint16_t>(info.port),
508 .name = GetHostDomain(),
509 }});
510 response.answers.emplace_back(DNSProto::ResourceRecord{.name = name,
511 .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_TXT),
512 .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
513 .ttl = off ? 0U : DEFAULT_TTL,
514 .rdata = info.txt});
515 MDnsPayloadParser parser;
516 ssize_t size = listener_.MulticastAll(parser.ToBytes(response));
517 return size > 0 ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
518 }
519
ReceivePacket(int sock, const MDnsPayload &payload)520 void MDnsProtocolImpl::ReceivePacket(int sock, const MDnsPayload &payload)
521 {
522 if (payload.size() == 0) {
523 return;
524 }
525 MDnsPayloadParser parser;
526 MDnsMessage msg = parser.FromBytes(payload);
527 if (parser.GetError() != 0) {
528 NETMGR_EXT_LOG_E("parser payload failed");
529 return;
530 }
531 if ((msg.header.flags & DNSProto::HEADER_FLAGS_QR_MASK) == 0) {
532 ProcessQuestion(sock, msg);
533 } else {
534 ProcessAnswer(sock, msg);
535 }
536 }
537
AppendRecord(std::vector<DNSProto::ResourceRecord> &rrlist, DNSProto::RRType type, const std::string &name, const std::any &rdata)538 void MDnsProtocolImpl::AppendRecord(std::vector<DNSProto::ResourceRecord> &rrlist, DNSProto::RRType type,
539 const std::string &name, const std::any &rdata)
540 {
541 rrlist.emplace_back(DNSProto::ResourceRecord{.name = name,
542 .rtype = static_cast<uint16_t>(type),
543 .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
544 .ttl = DEFAULT_TTL,
545 .rdata = rdata});
546 }
547
ProcessQuestion(int sock, const MDnsMessage &msg)548 void MDnsProtocolImpl::ProcessQuestion(int sock, const MDnsMessage &msg)
549 {
550 const sockaddr *saddrIf = listener_.GetSockAddr(sock);
551 if (saddrIf == nullptr) {
552 NETMGR_EXT_LOG_W("mdns_log ProcessQuestion saddrIf is null");
553 return;
554 }
555 std::any anyAddr;
556 DNSProto::RRType anyAddrType;
557 if (saddrIf->sa_family == AF_INET6) {
558 anyAddr = reinterpret_cast<const sockaddr_in6 *>(saddrIf)->sin6_addr;
559 anyAddrType = DNSProto::RRTYPE_AAAA;
560 } else {
561 anyAddr = reinterpret_cast<const sockaddr_in *>(saddrIf)->sin_addr;
562 anyAddrType = DNSProto::RRTYPE_A;
563 }
564 int phase = 0;
565 MDnsMessage response{};
566 response.header.flags = DNSProto::MDNS_ANSWER_FLAGS;
567 for (size_t i = 0; i < msg.header.qdcount; ++i) {
568 ProcessQuestionRecord(anyAddr, anyAddrType, msg.questions[i], phase, response);
569 }
570 if (phase < PHASE_DOMAIN) {
571 AppendRecord(response.additional, anyAddrType, GetHostDomain(), anyAddr);
572 }
573
574 if (phase != 0 && response.answers.size() > 0) {
575 listener_.Multicast(sock, MDnsPayloadParser().ToBytes(response));
576 }
577 }
578
ProcessQuestionRecord(const std::any &anyAddr, const DNSProto::RRType &anyAddrType, const DNSProto::Question &qu, int &phase, MDnsMessage &response)579 void MDnsProtocolImpl::ProcessQuestionRecord(const std::any &anyAddr, const DNSProto::RRType &anyAddrType,
580 const DNSProto::Question &qu, int &phase, MDnsMessage &response)
581 {
582 NETMGR_EXT_LOG_D("mdns_log ProcessQuestionRecord");
583 std::lock_guard<std::recursive_mutex> guard(mutex_);
584 std::string name = qu.name;
585 if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_PTR) {
586 std::for_each(srvMap_.begin(), srvMap_.end(), [&](const auto &elem) -> void {
587 if (EndsWith(elem.first, name)) {
588 AppendRecord(response.answers, DNSProto::RRTYPE_PTR, name, elem.first);
589 AppendRecord(response.additional, DNSProto::RRTYPE_SRV, elem.first,
590 DNSProto::RDataSrv{
591 .priority = 0,
592 .weight = 0,
593 .port = static_cast<uint16_t>(elem.second.port),
594 .name = GetHostDomain(),
595 });
596 AppendRecord(response.additional, DNSProto::RRTYPE_TXT, elem.first, elem.second.txt);
597 }
598 });
599 phase = std::max(phase, PHASE_PTR);
600 }
601 if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_SRV) {
602 auto iter = srvMap_.find(name);
603 if (iter == srvMap_.end()) {
604 return;
605 }
606 AppendRecord(response.answers, DNSProto::RRTYPE_SRV, name,
607 DNSProto::RDataSrv{
608 .priority = 0,
609 .weight = 0,
610 .port = static_cast<uint16_t>(iter->second.port),
611 .name = GetHostDomain(),
612 });
613 phase = std::max(phase, PHASE_SRV);
614 }
615 if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_TXT) {
616 auto iter = srvMap_.find(name);
617 if (iter == srvMap_.end()) {
618 return;
619 }
620 AppendRecord(response.answers, DNSProto::RRTYPE_TXT, name, iter->second.txt);
621 phase = std::max(phase, PHASE_SRV);
622 }
623 if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_A || qu.qtype == DNSProto::RRTYPE_AAAA) {
624 if (name != GetHostDomain() || (qu.qtype != DNSProto::RRTYPE_ANY && anyAddrType != qu.qtype)) {
625 return;
626 }
627 AppendRecord(response.answers, anyAddrType, name, anyAddr);
628 phase = std::max(phase, PHASE_DOMAIN);
629 }
630 }
631
ProcessAnswer(int sock, const MDnsMessage &msg)632 void MDnsProtocolImpl::ProcessAnswer(int sock, const MDnsMessage &msg)
633 {
634 const sockaddr *saddrIf = listener_.GetSockAddr(sock);
635 if (saddrIf == nullptr) {
636 return;
637 }
638 bool v6 = (saddrIf->sa_family == AF_INET6);
639 std::set<std::string> changed;
640 for (const auto &answer : msg.answers) {
641 ProcessAnswerRecord(v6, answer, changed);
642 }
643 for (const auto &i : msg.additional) {
644 ProcessAnswerRecord(v6, i, changed);
645 }
646 for (const auto &i : changed) {
647 std::lock_guard<std::recursive_mutex> guard(mutex_);
648 RunTaskQueue(taskOnChange_[i]);
649 KillCache(i);
650 }
651 }
652
UpdatePtr(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)653 void MDnsProtocolImpl::UpdatePtr(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
654 {
655 const std::string *data = std::any_cast<std::string>(&rr.rdata);
656 if (data == nullptr) {
657 return;
658 }
659
660 std::string name = rr.name;
661 if (browserMap_.find(name) == browserMap_.end()) {
662 return;
663 }
664 auto &results = browserMap_[name];
665 std::string srvName;
666 std::string srvType;
667 ExtractNameAndType(*data, srvName, srvType);
668 if (srvName.empty() || srvType.empty()) {
669 return;
670 }
671 auto res =
672 std::find_if(results.begin(), results.end(), [&](const auto &elem) { return elem.serviceName == srvName; });
673 if (res == results.end()) {
674 results.emplace_back(Result{
675 .serviceName = srvName,
676 .serviceType = srvType,
677 .state = State::ADD,
678 });
679 }
680 res = std::find_if(results.begin(), results.end(), [&](const auto &elem) { return elem.serviceName == srvName; });
681 if (res->serviceName != srvName || res->state == State::DEAD) {
682 res->state = State::REFRESH;
683 res->serviceName = srvName;
684 }
685 if (rr.ttl == 0) {
686 res->state = State::REMOVE;
687 }
688 if (res->state != State::LIVE && res->state != State::DEAD) {
689 changed.emplace(name);
690 }
691 res->ttl = rr.ttl;
692 res->refrehTime = MilliSecondsSinceEpoch();
693 }
694
UpdateSrv(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)695 void MDnsProtocolImpl::UpdateSrv(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
696 {
697 const DNSProto::RDataSrv *srv = std::any_cast<DNSProto::RDataSrv>(&rr.rdata);
698 if (srv == nullptr) {
699 return;
700 }
701 std::string name = rr.name;
702 if (cacheMap_.find(name) == cacheMap_.end()) {
703 ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
704 cacheMap_[name].state = State::ADD;
705 cacheMap_[name].domain = srv->name;
706 cacheMap_[name].port = srv->port;
707 }
708 Result &result = cacheMap_[name];
709 if (result.domain != srv->name || result.port != srv->port || result.state == State::DEAD) {
710 if (result.state != State::ADD) {
711 result.state = State::REFRESH;
712 }
713 result.domain = srv->name;
714 result.port = srv->port;
715 }
716 if (rr.ttl == 0) {
717 result.state = State::REMOVE;
718 }
719 if (result.state != State::LIVE && result.state != State::DEAD) {
720 changed.emplace(name);
721 }
722 result.ttl = rr.ttl;
723 result.refrehTime = MilliSecondsSinceEpoch();
724 }
725
UpdateTxt(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)726 void MDnsProtocolImpl::UpdateTxt(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
727 {
728 const TxtRecordEncoded *txt = std::any_cast<TxtRecordEncoded>(&rr.rdata);
729 if (txt == nullptr) {
730 return;
731 }
732 std::string name = rr.name;
733 if (cacheMap_.find(name) == cacheMap_.end()) {
734 ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
735 cacheMap_[name].state = State::ADD;
736 cacheMap_[name].txt = *txt;
737 }
738 Result &result = cacheMap_[name];
739 if (result.txt != *txt || result.state == State::DEAD) {
740 if (result.state != State::ADD) {
741 result.state = State::REFRESH;
742 }
743 result.txt = *txt;
744 }
745 if (rr.ttl == 0) {
746 result.state = State::REMOVE;
747 }
748 if (result.state != State::LIVE && result.state != State::DEAD) {
749 changed.emplace(name);
750 }
751 result.ttl = rr.ttl;
752 result.refrehTime = MilliSecondsSinceEpoch();
753 }
754
UpdateAddr(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)755 void MDnsProtocolImpl::UpdateAddr(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
756 {
757 if (v6 != (rr.rtype == DNSProto::RRTYPE_AAAA)) {
758 return;
759 }
760 const std::string addr = AddrToString(rr.rdata);
761 bool v6rr = (rr.rtype == DNSProto::RRTYPE_AAAA);
762 if (addr.empty()) {
763 return;
764 }
765 std::string name = rr.name;
766 if (cacheMap_.find(name) == cacheMap_.end()) {
767 ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
768 cacheMap_[name].state = State::ADD;
769 cacheMap_[name].ipv6 = v6rr;
770 cacheMap_[name].addr = addr;
771 }
772 Result &result = cacheMap_[name];
773 if (result.addr != addr || result.ipv6 != v6rr || result.state == State::DEAD) {
774 result.state = State::REFRESH;
775 result.addr = addr;
776 result.ipv6 = v6rr;
777 }
778 if (rr.ttl == 0) {
779 result.state = State::REMOVE;
780 }
781 if (result.state != State::LIVE && result.state != State::DEAD) {
782 changed.emplace(name);
783 }
784 result.ttl = rr.ttl;
785 result.refrehTime = MilliSecondsSinceEpoch();
786 }
787
ProcessAnswerRecord(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)788 void MDnsProtocolImpl::ProcessAnswerRecord(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
789 {
790 NETMGR_EXT_LOG_D("mdns_log ProcessAnswerRecord, type=[%{public}d]", rr.rtype);
791 std::lock_guard<std::recursive_mutex> guard(mutex_);
792 std::string name = rr.name;
793 if (cacheMap_.find(name) == cacheMap_.end() && browserMap_.find(name) == browserMap_.end() &&
794 srvMap_.find(name) != srvMap_.end()) {
795 return;
796 }
797 if (rr.rtype == DNSProto::RRTYPE_PTR) {
798 UpdatePtr(v6, rr, changed);
799 } else if (rr.rtype == DNSProto::RRTYPE_SRV) {
800 UpdateSrv(v6, rr, changed);
801 } else if (rr.rtype == DNSProto::RRTYPE_TXT) {
802 UpdateTxt(v6, rr, changed);
803 } else if (rr.rtype == DNSProto::RRTYPE_A || rr.rtype == DNSProto::RRTYPE_AAAA) {
804 UpdateAddr(v6, rr, changed);
805 } else {
806 NETMGR_EXT_LOG_D("mdns_log Unknown packet received, type=[%{public}d]", rr.rtype);
807 }
808 }
809
GetHostDomain()810 std::string MDnsProtocolImpl::GetHostDomain()
811 {
812 if (config_.hostname.empty()) {
813 char buffer[MDNS_MAX_DOMAIN_LABEL];
814 if (gethostname(buffer, sizeof(buffer)) == 0) {
815 config_.hostname = buffer;
816 static auto uid = []() {
817 std::random_device rd;
818 return rd();
819 }();
820 config_.hostname += std::to_string(uid);
821 }
822 }
823 return Decorated(config_.hostname);
824 }
825
AddTask(const Task &task, bool atonce)826 void MDnsProtocolImpl::AddTask(const Task &task, bool atonce)
827 {
828 {
829 std::lock_guard<std::recursive_mutex> guard(mutex_);
830 taskQueue_.emplace_back(task);
831 }
832 if (atonce) {
833 listener_.TriggerRefresh();
834 }
835 }
836
ConvertResultToInfo(const MDnsProtocolImpl::Result &result)837 MDnsServiceInfo MDnsProtocolImpl::ConvertResultToInfo(const MDnsProtocolImpl::Result &result)
838 {
839 MDnsServiceInfo info;
840 info.name = result.serviceName;
841 info.type = result.serviceType;
842 if (!result.addr.empty()) {
843 info.family = result.ipv6 ? MDnsServiceInfo::IPV6 : MDnsServiceInfo::IPV4;
844 }
845 info.addr = result.addr;
846 info.port = result.port;
847 info.txtRecord = result.txt;
848 return info;
849 }
850
IsCacheAvailable(const std::string &key)851 bool MDnsProtocolImpl::IsCacheAvailable(const std::string &key)
852 {
853 constexpr int64_t ms2S = 1000LL;
854 NETMGR_EXT_LOG_D("mdns_log IsCacheAvailable, ttl=[%{public}u]", cacheMap_[key].ttl);
855 return cacheMap_.find(key) != cacheMap_.end() &&
856 (ms2S * cacheMap_[key].ttl) > static_cast<uint32_t>(MilliSecondsSinceEpoch() - cacheMap_[key].refrehTime);
857 }
858
IsDomainCacheAvailable(const std::string &key)859 bool MDnsProtocolImpl::IsDomainCacheAvailable(const std::string &key)
860 {
861 return IsCacheAvailable(key) && !cacheMap_[key].addr.empty();
862 }
863
IsInstanceCacheAvailable(const std::string &key)864 bool MDnsProtocolImpl::IsInstanceCacheAvailable(const std::string &key)
865 {
866 return IsCacheAvailable(key) && !cacheMap_[key].domain.empty();
867 }
868
IsBrowserAvailable(const std::string &key)869 bool MDnsProtocolImpl::IsBrowserAvailable(const std::string &key)
870 {
871 return browserMap_.find(key) != browserMap_.end() && !browserMap_[key].empty();
872 }
873
AddEvent(const std::string &key, const Task &task)874 void MDnsProtocolImpl::AddEvent(const std::string &key, const Task &task)
875 {
876 std::lock_guard<std::recursive_mutex> guard(mutex_);
877 taskOnChange_[key].emplace_back(task);
878 }
879
RunTaskQueue(std::list<Task> &queue)880 void MDnsProtocolImpl::RunTaskQueue(std::list<Task> &queue)
881 {
882 std::list<Task> tmp;
883 for (auto &&func : queue) {
884 if (!func()) {
885 tmp.emplace_back(func);
886 }
887 }
888 tmp.swap(queue);
889 }
890
KillCache(const std::string &key)891 void MDnsProtocolImpl::KillCache(const std::string &key)
892 {
893 NETMGR_EXT_LOG_D("mdns_log KillCache");
894 if (IsBrowserAvailable(key) && browserMap_.find(key) != browserMap_.end()) {
895 for (auto it = browserMap_[key].begin(); it != browserMap_[key].end();) {
896 KillBrowseCache(key, it);
897 }
898 }
899 if (IsCacheAvailable(key)) {
900 std::lock_guard<std::recursive_mutex> guard(mutex_);
901 auto &elem = cacheMap_[key];
902 if (elem.state == State::REMOVE) {
903 elem.state = State::DEAD;
904 cacheMap_.erase(key);
905 } else if (elem.state == State::ADD || elem.state == State::REFRESH) {
906 elem.state = State::LIVE;
907 }
908 }
909 }
910
KillBrowseCache(const std::string &key, std::vector<Result>::iterator &it)911 void MDnsProtocolImpl::KillBrowseCache(const std::string &key, std::vector<Result>::iterator &it)
912 {
913 NETMGR_EXT_LOG_D("mdns_log KillBrowseCache");
914 if (it->state == State::REMOVE) {
915 it->state = State::DEAD;
916 if (nameCbMap_.find(key) != nameCbMap_.end()) {
917 NETMGR_EXT_LOG_D("mdns_log HandleServiceLost");
918 nameCbMap_[key]->HandleServiceLost(ConvertResultToInfo(*it), NETMANAGER_EXT_SUCCESS);
919 }
920 std::string fullName = Decorated(it->serviceName + MDNS_DOMAIN_SPLITER_STR + it->serviceType);
921 cacheMap_.erase(fullName);
922 it = browserMap_[key].erase(it);
923 } else if (it->state == State::ADD || it->state == State::REFRESH) {
924 it->state = State::LIVE;
925 it++;
926 } else {
927 it++;
928 }
929 }
930
StopCbMap(const std::string &serviceType)931 int32_t MDnsProtocolImpl::StopCbMap(const std::string &serviceType)
932 {
933 NETMGR_EXT_LOG_D("mdns_log StopCbMap");
934 std::lock_guard<std::recursive_mutex> guard(mutex_);
935 std::string name = Decorated(serviceType);
936 sptr<IDiscoveryCallback> cb = nullptr;
937 if (nameCbMap_.find(name) != nameCbMap_.end()) {
938 cb = nameCbMap_[name];
939 nameCbMap_.erase(name);
940 }
941 taskOnChange_.erase(name);
942 auto it = browserMap_.find(name);
943 if (it != browserMap_.end()) {
944 if (cb != nullptr) {
945 NETMGR_EXT_LOG_I("mdns_log StopCbMap res size:[%{public}zu]", it->second.size());
946 for (auto &&res : it->second) {
947 NETMGR_EXT_LOG_W("mdns_log HandleServiceLost");
948 cb->HandleServiceLost(ConvertResultToInfo(res), NETMANAGER_EXT_SUCCESS);
949 }
950 }
951 browserMap_.erase(name);
952 }
953 return NETMANAGER_SUCCESS;
954 }
955 } // namespace NetManagerStandard
956 } // namespace OHOS
957