1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "dns_param_cache.h"

#include <algorithm>
#include <iterator>

#include "netmanager_base_common_utils.h"
#ifdef FEATURE_NET_FIREWALL_ENABLE
#include "bpf_netfirewall.h"
#include "netfirewall_parcel.h"
#include <ctime>
#endif
26 
27 namespace OHOS::nmd {
28 using namespace OHOS::NetManagerStandard::CommonUtils;
29 namespace {
GetVectorData(const std::vector<std::string> &data, std::string &result)30 void GetVectorData(const std::vector<std::string> &data, std::string &result)
31 {
32     result.append("{ ");
33     std::for_each(data.begin(), data.end(), [&result](const auto &str) { result.append(ToAnonymousIp(str) + ", "); });
34     result.append("}\n");
35 }
36 constexpr int RES_TIMEOUT = 5000;    // min. milliseconds between retries
37 constexpr int RES_DEFAULT_RETRY = 2; // Default
38 } // namespace
39 
DnsParamCache()40 DnsParamCache::DnsParamCache() : defaultNetId_(0) {}
41 
GetInstance()42 DnsParamCache &DnsParamCache::GetInstance()
43 {
44     static DnsParamCache instance;
45     return instance;
46 }
47 
SelectNameservers(const std::vector<std::string> &servers)48 std::vector<std::string> DnsParamCache::SelectNameservers(const std::vector<std::string> &servers)
49 {
50     std::vector<std::string> res = servers;
51     if (res.size() > MAX_SERVER_NUM - 1) {
52         res.resize(MAX_SERVER_NUM - 1);
53     }
54     return res;
55 }
56 
CreateCacheForNet(uint16_t netId)57 int32_t DnsParamCache::CreateCacheForNet(uint16_t netId)
58 {
59     NETNATIVE_LOG_D("DnsParamCache::CreateCacheForNet, netid:%{public}d,", netId);
60     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
61     auto it = serverConfigMap_.find(netId);
62     if (it != serverConfigMap_.end()) {
63         NETNATIVE_LOGE("DnsParamCache::CreateCacheForNet, netid already exist, no need to create");
64         return -EEXIST;
65     }
66     serverConfigMap_[netId].SetNetId(netId);
67     return 0;
68 }
69 
DestroyNetworkCache(uint16_t netId)70 int32_t DnsParamCache::DestroyNetworkCache(uint16_t netId)
71 {
72     NETNATIVE_LOG_D("DnsParamCache::CreateCacheForNet, netid:%{public}d,", netId);
73     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
74     auto it = serverConfigMap_.find(netId);
75     if (it == serverConfigMap_.end()) {
76         return -ENOENT;
77     }
78     serverConfigMap_.erase(it);
79     if (defaultNetId_ == netId) {
80         defaultNetId_ = 0;
81     }
82     return 0;
83 }
84 
SetResolverConfig(uint16_t netId, uint16_t baseTimeoutMsec, uint8_t retryCount, const std::vector<std::string> &servers, const std::vector<std::string> &domains)85 int32_t DnsParamCache::SetResolverConfig(uint16_t netId, uint16_t baseTimeoutMsec, uint8_t retryCount,
86                                          const std::vector<std::string> &servers,
87                                          const std::vector<std::string> &domains)
88 {
89     std::vector<std::string> nameservers = SelectNameservers(servers);
90     NETNATIVE_LOG_D("DnsParamCache::SetResolverConfig, netid:%{public}d, numServers:%{public}d,", netId,
91                     static_cast<int>(nameservers.size()));
92 
93     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
94 
95     // select_domains
96     auto it = serverConfigMap_.find(netId);
97     if (it == serverConfigMap_.end()) {
98         NETNATIVE_LOGE("DnsParamCache::SetResolverConfig failed, netid is non-existent");
99         return -ENOENT;
100     }
101 
102     auto oldDnsServers = it->second.GetServers();
103     std::sort(oldDnsServers.begin(), oldDnsServers.end());
104 
105     auto newDnsServers = servers;
106     std::sort(newDnsServers.begin(), newDnsServers.end());
107 
108     if (oldDnsServers != newDnsServers) {
109         it->second.GetCache().Clear();
110     }
111 
112     it->second.SetNetId(netId);
113     it->second.SetServers(servers);
114     it->second.SetDomains(domains);
115     if (retryCount == 0) {
116         it->second.SetRetryCount(RES_DEFAULT_RETRY);
117     } else {
118         it->second.SetRetryCount(retryCount);
119     }
120     if (baseTimeoutMsec == 0) {
121         it->second.SetTimeoutMsec(RES_TIMEOUT);
122     } else {
123         it->second.SetTimeoutMsec(baseTimeoutMsec);
124     }
125     return 0;
126 }
127 
SetDefaultNetwork(uint16_t netId)128 void DnsParamCache::SetDefaultNetwork(uint16_t netId)
129 {
130     defaultNetId_ = netId;
131 }
132 
EnableIpv6(uint16_t netId)133 void DnsParamCache::EnableIpv6(uint16_t netId)
134 {
135     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
136     auto it = serverConfigMap_.find(netId);
137     if (it == serverConfigMap_.end()) {
138         DNS_CONFIG_PRINT("get Config failed: netid is not have netid:%{public}d,", netId);
139         return;
140     }
141 
142     it->second.EnableIpv6();
143 }
144 
IsIpv6Enable(uint16_t netId)145 bool DnsParamCache::IsIpv6Enable(uint16_t netId)
146 {
147     if (netId == 0) {
148         netId = defaultNetId_;
149     }
150 
151     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
152     auto it = serverConfigMap_.find(netId);
153     if (it == serverConfigMap_.end()) {
154         DNS_CONFIG_PRINT("get Config failed: netid is not have netid:%{public}d,", netId);
155         return false;
156     }
157 
158     return it->second.IsIpv6Enable();
159 }
160 
GetResolverConfig(uint16_t netId, std::vector<std::string> &servers, std::vector<std::string> &domains, uint16_t &baseTimeoutMsec, uint8_t &retryCount)161 int32_t DnsParamCache::GetResolverConfig(uint16_t netId, std::vector<std::string> &servers,
162                                          std::vector<std::string> &domains, uint16_t &baseTimeoutMsec,
163                                          uint8_t &retryCount)
164 {
165     NETNATIVE_LOG_D("DnsParamCache::GetResolverConfig no uid");
166     if (netId == 0) {
167         netId = defaultNetId_;
168         NETNATIVE_LOG_D("defaultNetId_ = [%{public}u]", netId);
169     }
170 
171     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
172     auto it = serverConfigMap_.find(netId);
173     if (it == serverConfigMap_.end()) {
174         DNS_CONFIG_PRINT("get Config failed: netid is not have netid:%{public}d,", netId);
175         return -ENOENT;
176     }
177 
178     servers = it->second.GetServers();
179 #ifdef FEATURE_NET_FIREWALL_ENABLE
180     std::vector<std::string> dns;
181     if (GetDnsServersByAppUid(GetCallingUid(), dns)) {
182         DNS_CONFIG_PRINT("GetResolverConfig hit netfirewall");
183         servers.assign(dns.begin(), dns.end());
184     }
185 #endif
186     domains = it->second.GetDomains();
187     baseTimeoutMsec = it->second.GetTimeoutMsec();
188     retryCount = it->second.GetRetryCount();
189 
190     return 0;
191 }
192 
GetResolverConfig(uint16_t netId, uint32_t uid, std::vector<std::string> &servers, std::vector<std::string> &domains, uint16_t &baseTimeoutMsec, uint8_t &retryCount)193 int32_t DnsParamCache::GetResolverConfig(uint16_t netId, uint32_t uid, std::vector<std::string> &servers,
194                                          std::vector<std::string> &domains, uint16_t &baseTimeoutMsec,
195                                          uint8_t &retryCount)
196 {
197     NETNATIVE_LOG_D("DnsParamCache::GetResolverConfig has uid");
198     if (netId == 0) {
199         netId = defaultNetId_;
200         NETNATIVE_LOG_D("defaultNetId_ = [%{public}u]", netId);
201     }
202 
203     {
204         std::lock_guard<ffrt::mutex> guard(cacheMutex_);
205         for (auto mem : vpnUidRanges_) {
206             if (static_cast<int32_t>(uid) >= mem.begin_ && static_cast<int32_t>(uid) <= mem.end_) {
207                 NETNATIVE_LOG_D("is vpn hap");
208                 auto it = serverConfigMap_.find(vpnNetId_);
209                 if (it == serverConfigMap_.end()) {
210                     NETNATIVE_LOG_D("vpn get Config failed: not have vpnnetid:%{public}d,", vpnNetId_);
211                     break;
212                 }
213                 servers = it->second.GetServers();
214 #ifdef FEATURE_NET_FIREWALL_ENABLE
215                 std::vector<std::string> dns;
216                 if (GetDnsServersByAppUid(GetCallingUid(), dns)) {
217                     DNS_CONFIG_PRINT("GetResolverConfig hit netfirewall");
218                     servers.assign(dns.begin(), dns.end());
219                 }
220 #endif
221                 domains = it->second.GetDomains();
222                 baseTimeoutMsec = it->second.GetTimeoutMsec();
223                 retryCount = it->second.GetRetryCount();
224                 return 0;
225             }
226         }
227     }
228     return GetResolverConfig(netId, servers, domains, baseTimeoutMsec, retryCount);
229 }
230 
GetDefaultNetwork() const231 int32_t DnsParamCache::GetDefaultNetwork() const
232 {
233     return defaultNetId_;
234 }
235 
SetDnsCache(uint16_t netId, const std::string &hostName, const AddrInfo &addrInfo)236 void DnsParamCache::SetDnsCache(uint16_t netId, const std::string &hostName, const AddrInfo &addrInfo)
237 {
238     if (netId == 0) {
239         netId = defaultNetId_;
240     }
241     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
242 #ifdef FEATURE_NET_FIREWALL_ENABLE
243     int32_t appUid = static_cast<int32_t>(GetCallingUid());
244     bool isMatchAllow = false;
245     if (IsInterceptDomain(appUid, hostName, isMatchAllow)) {
246         DNS_CONFIG_PRINT("SetDnsCache failed: domain was Intercepted: %{public}s,", hostName.c_str());
247         return;
248     }
249     if (isMatchAllow && (addrInfo.aiFamily == AF_INET || addrInfo.aiFamily == AF_INET6)) {
250         NetAddrInfo netInfo;
251         netInfo.aiFamily = addrInfo.aiFamily;
252         if (addrInfo.aiFamily == AF_INET) {
253             netInfo.aiAddr.sin = addrInfo.aiAddr.sin.sin_addr;
254         } else {
255             memcpy_s(&netInfo.aiAddr.sin6, sizeof(addrInfo.aiAddr.sin6.sin6_addr), &addrInfo.aiAddr.sin6.sin6_addr,
256                      sizeof(addrInfo.aiAddr.sin6.sin6_addr));
257         }
258         OHOS::NetManagerStandard::NetsysBpfNetFirewall::GetInstance()->AddDomainCache(netInfo);
259     }
260 #endif
261     auto it = serverConfigMap_.find(netId);
262     if (it == serverConfigMap_.end()) {
263         DNS_CONFIG_PRINT("SetDnsCache failed: netid is not have netid:%{public}d,", netId);
264         return;
265     }
266 
267     it->second.GetCache().Put(hostName, addrInfo);
268 }
269 
GetDnsCache(uint16_t netId, const std::string &hostName)270 std::vector<AddrInfo> DnsParamCache::GetDnsCache(uint16_t netId, const std::string &hostName)
271 {
272     if (netId == 0) {
273         netId = defaultNetId_;
274     }
275 
276     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
277 #ifdef FEATURE_NET_FIREWALL_ENABLE
278     int32_t appUid = static_cast<int32_t>(GetCallingUid());
279     bool isMatchAllow = false;
280     if (IsInterceptDomain(appUid, hostName, isMatchAllow)) {
281         NotifyDomianIntercept(appUid, hostName);
282         AddrInfo fakeAddr = { 0 };
283         fakeAddr.aiFamily = AF_UNSPEC;
284         fakeAddr.aiAddr.sin.sin_family = AF_UNSPEC;
285         fakeAddr.aiAddr.sin.sin_addr.s_addr = INADDR_NONE;
286         fakeAddr.aiAddrLen = sizeof(struct sockaddr_in);
287         return { fakeAddr };
288     }
289 #endif
290 
291     auto it = serverConfigMap_.find(netId);
292     if (it == serverConfigMap_.end()) {
293         DNS_CONFIG_PRINT("GetDnsCache failed: netid is not have netid:%{public}d,", netId);
294         return {};
295     }
296 
297     return it->second.GetCache().Get(hostName);
298 }
299 
SetCacheDelayed(uint16_t netId, const std::string &hostName)300 void DnsParamCache::SetCacheDelayed(uint16_t netId, const std::string &hostName)
301 {
302     if (netId == 0) {
303         netId = defaultNetId_;
304     }
305 
306     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
307     auto it = serverConfigMap_.find(netId);
308     if (it == serverConfigMap_.end()) {
309         DNS_CONFIG_PRINT("SetCacheDelayed failed: netid is not have netid:%{public}d,", netId);
310         return;
311     }
312 
313     it->second.SetCacheDelayed(hostName);
314 }
315 
AddUidRange(uint32_t netId, const std::vector<NetManagerStandard::UidRange> &uidRanges)316 int32_t DnsParamCache::AddUidRange(uint32_t netId, const std::vector<NetManagerStandard::UidRange> &uidRanges)
317 {
318     std::lock_guard<ffrt::mutex> guard(uidRangeMutex_);
319     NETNATIVE_LOG_D("DnsParamCache::AddUidRange size = [%{public}zu]", uidRanges.size());
320     vpnNetId_ = netId;
321     auto middle = vpnUidRanges_.insert(vpnUidRanges_.end(), uidRanges.begin(), uidRanges.end());
322     std::inplace_merge(vpnUidRanges_.begin(), middle, vpnUidRanges_.end());
323     return 0;
324 }
325 
DelUidRange(uint32_t netId, const std::vector<NetManagerStandard::UidRange> &uidRanges)326 int32_t DnsParamCache::DelUidRange(uint32_t netId, const std::vector<NetManagerStandard::UidRange> &uidRanges)
327 {
328     std::lock_guard<ffrt::mutex> guard(uidRangeMutex_);
329     NETNATIVE_LOG_D("DnsParamCache::DelUidRange size = [%{public}zu]", uidRanges.size());
330     vpnNetId_ = 0;
331     auto end = std::set_difference(vpnUidRanges_.begin(), vpnUidRanges_.end(), uidRanges.begin(),
332                                    uidRanges.end(), vpnUidRanges_.begin());
333     vpnUidRanges_.erase(end, vpnUidRanges_.end());
334     return 0;
335 }
336 
IsVpnOpen() const337 bool DnsParamCache::IsVpnOpen() const
338 {
339     return vpnUidRanges_.size();
340 }
341 
342 #ifdef FEATURE_NET_FIREWALL_ENABLE
GetUserId(int32_t appUid)343 int32_t DnsParamCache::GetUserId(int32_t appUid)
344 {
345     int32_t userId = appUid / USER_ID_DIVIDOR;
346     return userId > 0 ? userId : currentUserId_;
347 }
348 
GetDnsServersByAppUid(int32_t appUid, std::vector<std::string> &servers)349 bool DnsParamCache::GetDnsServersByAppUid(int32_t appUid, std::vector<std::string> &servers)
350 {
351     if (netFirewallDnsRuleMap_.empty()) {
352         return false;
353     }
354     DNS_CONFIG_PRINT("GetDnsServersByAppUid: appUid=%{public}d", appUid);
355     auto it = netFirewallDnsRuleMap_.find(appUid);
356     if (it == netFirewallDnsRuleMap_.end()) {
357         // if appUid not found, try to find invalid appUid=0;
358         it = netFirewallDnsRuleMap_.find(0);
359     }
360     if (it != netFirewallDnsRuleMap_.end()) {
361         int32_t userId = GetUserId(appUid);
362         std::vector<sptr<NetFirewallDnsRule>> rules = it->second;
363         for (const auto &rule : rules) {
364             if (rule->userId != userId) {
365                 continue;
366             }
367             servers.emplace_back(rule->primaryDns);
368             servers.emplace_back(rule->standbyDns);
369         }
370         return true;
371     }
372     return false;
373 }
374 
SetFirewallRules(NetFirewallRuleType type, const std::vector<sptr<NetFirewallBaseRule>> &ruleList, bool isFinish)375 int32_t DnsParamCache::SetFirewallRules(NetFirewallRuleType type,
376                                         const std::vector<sptr<NetFirewallBaseRule>> &ruleList, bool isFinish)
377 {
378     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
379     NETNATIVE_LOGI("SetFirewallRules: size=%{public}zu isFinish=%{public}" PRId32, ruleList.size(), isFinish);
380     if (ruleList.empty()) {
381         NETNATIVE_LOGE("SetFirewallRules: rules is empty");
382         return -1;
383     }
384     int32_t ret = 0;
385     switch (type) {
386         case NetFirewallRuleType::RULE_DNS: {
387             for (const auto &rule : ruleList) {
388                 firewallDnsRules_.emplace_back(firewall_rule_cast<NetFirewallDnsRule>(rule));
389             }
390             if (isFinish) {
391                 ret = SetFirewallDnsRules(firewallDnsRules_);
392                 firewallDnsRules_.clear();
393             }
394             break;
395         }
396         case NetFirewallRuleType::RULE_DOMAIN: {
397             for (const auto &rule : ruleList) {
398                 firewallDomainRules_.emplace_back(firewall_rule_cast<NetFirewallDomainRule>(rule));
399             }
400             if (isFinish) {
401                 ret = SetFirewallDomainRules(firewallDomainRules_);
402                 firewallDomainRules_.clear();
403                 OHOS::NetManagerStandard::NetsysBpfNetFirewall::GetInstance()->ClearDomainCache();
404             }
405             break;
406         }
407         default:
408             break;
409     }
410     return ret;
411 }
412 
SetFirewallDnsRules(const std::vector<sptr<NetFirewallDnsRule>> &ruleList)413 int32_t DnsParamCache::SetFirewallDnsRules(const std::vector<sptr<NetFirewallDnsRule>> &ruleList)
414 {
415     for (const auto &rule : ruleList) {
416         std::vector<sptr<NetFirewallDnsRule>> rules;
417         auto it = netFirewallDnsRuleMap_.find(rule->appUid);
418         if (it != netFirewallDnsRuleMap_.end()) {
419             rules = it->second;
420         }
421         rules.emplace_back(std::move(rule));
422         netFirewallDnsRuleMap_.emplace(rule->appUid, std::move(rules));
423     }
424     return 0;
425 }
426 
GetFirewallRuleAction(int32_t appUid, const std::vector<sptr<NetFirewallDomainRule>> &rules)427 FirewallRuleAction DnsParamCache::GetFirewallRuleAction(int32_t appUid,
428                                                         const std::vector<sptr<NetFirewallDomainRule>> &rules)
429 {
430     int32_t userId = GetUserId(appUid);
431     for (const auto &rule : rules) {
432         if (rule->userId != userId) {
433             continue;
434         }
435         if ((rule->appUid && appUid == rule->appUid) || !rule->appUid) {
436             return rule->ruleAction;
437         }
438     }
439 
440     return FirewallRuleAction::RULE_INVALID;
441 }
442 
checkEmpty4InterceptDomain(const std::string &hostName)443 bool DnsParamCache::checkEmpty4InterceptDomain(const std::string &hostName)
444 {
445     if (hostName.empty()) {
446         return true;
447     }
448     if (!netFirewallDomainRulesAllowMap_.empty() || !netFirewallDomainRulesDenyMap_.empty()) {
449         return false;
450     }
451     if (domainAllowLsmTrie_ && !domainAllowLsmTrie_->Empty()) {
452         return false;
453     }
454     return !domainDenyLsmTrie_ || domainDenyLsmTrie_->Empty();
455 }
456 
IsInterceptDomain(int32_t appUid, const std::string &hostName, bool &isMatchAllow)457 bool DnsParamCache::IsInterceptDomain(int32_t appUid, const std::string &hostName, bool &isMatchAllow)
458 {
459     if (checkEmpty4InterceptDomain(hostName)) {
460         return false;
461     }
462     std::string host = hostName.substr(0, hostName.find(' '));
463     DNS_CONFIG_PRINT("IsInterceptDomain: appUid: %{public}d, hostName: %{private}s", appUid, host.c_str());
464     std::transform(host.begin(), host.end(), host.begin(), ::tolower);
465     std::vector<sptr<NetFirewallDomainRule>> rules;
466     FirewallRuleAction exactAllowAction = FirewallRuleAction::RULE_INVALID;
467     auto it = netFirewallDomainRulesAllowMap_.find(host);
468     if (it != netFirewallDomainRulesAllowMap_.end()) {
469         rules = it->second;
470         exactAllowAction = GetFirewallRuleAction(appUid, rules);
471     }
472     FirewallRuleAction exactDenyAction = FirewallRuleAction::RULE_INVALID;
473     auto iter = netFirewallDomainRulesDenyMap_.find(host);
474     if (iter != netFirewallDomainRulesDenyMap_.end()) {
475         rules = iter->second;
476         exactDenyAction = GetFirewallRuleAction(appUid, rules);
477     }
478     FirewallRuleAction wildcardAllowAction = FirewallRuleAction::RULE_INVALID;
479     if (domainAllowLsmTrie_->LongestSuffixMatch(host, rules)) {
480         wildcardAllowAction = GetFirewallRuleAction(appUid, rules);
481     }
482     FirewallRuleAction wildcardDenyAction = FirewallRuleAction::RULE_INVALID;
483     if (domainDenyLsmTrie_->LongestSuffixMatch(host, rules)) {
484         wildcardDenyAction = GetFirewallRuleAction(appUid, rules);
485     }
486     isMatchAllow = (exactAllowAction != FirewallRuleAction::RULE_INVALID) ||
487                    (wildcardAllowAction != FirewallRuleAction::RULE_INVALID);
488     bool isDeny = (exactDenyAction != FirewallRuleAction::RULE_INVALID) ||
489                   (wildcardDenyAction != FirewallRuleAction::RULE_INVALID);
490     if (isMatchAllow) {
491         // Apply default rules in case of conflict
492         return isDeny && (firewallDefaultAction_ == FirewallRuleAction::RULE_DENY);
493     }
494     return isDeny;
495 }
496 
SetFirewallDefaultAction(FirewallRuleAction inDefault, FirewallRuleAction outDefault)497 int32_t DnsParamCache::SetFirewallDefaultAction(FirewallRuleAction inDefault, FirewallRuleAction outDefault)
498 {
499     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
500     DNS_CONFIG_PRINT("SetFirewallDefaultAction: firewallDefaultAction_: %{public}d", (int)outDefault);
501     firewallDefaultAction_ = outDefault;
502     return 0;
503 }
504 
BuildFirewallDomainLsmTrie(const sptr<NetFirewallDomainRule> &rule, const std::string &domain)505 void DnsParamCache::BuildFirewallDomainLsmTrie(const sptr<NetFirewallDomainRule> &rule, const std::string &domain)
506 {
507     std::vector<sptr<NetFirewallDomainRule>> rules;
508     std::string suffix(domain);
509     auto wildcardCharIndex = suffix.find('*');
510     if (wildcardCharIndex != std::string::npos) {
511         suffix = suffix.substr(wildcardCharIndex + 1);
512     }
513     DNS_CONFIG_PRINT("BuildFirewallDomainLsmTrie: suffix: %{public}s", suffix.c_str());
514     std::transform(suffix.begin(), suffix.end(), suffix.begin(), ::tolower);
515     if (rule->ruleAction == FirewallRuleAction::RULE_DENY) {
516         if (domainDenyLsmTrie_->LongestSuffixMatch(suffix, rules)) {
517             rules.emplace_back(std::move(rule));
518             domainDenyLsmTrie_->Update(suffix, rules);
519             return;
520         }
521         rules.emplace_back(std::move(rule));
522         domainDenyLsmTrie_->Insert(suffix, rules);
523     } else {
524         if (domainAllowLsmTrie_->LongestSuffixMatch(suffix, rules)) {
525             rules.emplace_back(std::move(rule));
526             domainAllowLsmTrie_->Update(suffix, rules);
527             return;
528         }
529         rules.emplace_back(std::move(rule));
530         domainAllowLsmTrie_->Insert(suffix, rules);
531     }
532 }
533 
BuildFirewallDomainMap(const sptr<NetFirewallDomainRule> &rule, const std::string &raw)534 void DnsParamCache::BuildFirewallDomainMap(const sptr<NetFirewallDomainRule> &rule, const std::string &raw)
535 {
536     DNS_CONFIG_PRINT("BuildFirewallDomainMap: domain: %{public}s", raw.c_str());
537     std::string domain(raw);
538     std::vector<sptr<NetFirewallDomainRule>> rules;
539     std::transform(domain.begin(), domain.end(), domain.begin(), ::tolower);
540     if (rule->ruleAction == FirewallRuleAction::RULE_DENY) {
541         auto it = netFirewallDomainRulesDenyMap_.find(domain);
542         if (it != netFirewallDomainRulesDenyMap_.end()) {
543             rules = it->second;
544         }
545 
546         rules.emplace_back(std::move(rule));
547         netFirewallDomainRulesDenyMap_.emplace(domain, std::move(rules));
548     } else {
549         auto it = netFirewallDomainRulesAllowMap_.find(domain);
550         if (it != netFirewallDomainRulesAllowMap_.end()) {
551             rules = it->second;
552         }
553 
554         rules.emplace_back(rule);
555         netFirewallDomainRulesAllowMap_.emplace(domain, std::move(rules));
556     }
557 }
558 
SetFirewallDomainRules(const std::vector<sptr<NetFirewallDomainRule>> &ruleList)559 int32_t DnsParamCache::SetFirewallDomainRules(const std::vector<sptr<NetFirewallDomainRule>> &ruleList)
560 {
561     if (!domainAllowLsmTrie_) {
562         domainAllowLsmTrie_ =
563             std::make_shared<NetManagerStandard::SuffixMatchTrie<std::vector<sptr<NetFirewallDomainRule>>>>();
564     }
565     if (!domainDenyLsmTrie_) {
566         domainDenyLsmTrie_ =
567             std::make_shared<NetManagerStandard::SuffixMatchTrie<std::vector<sptr<NetFirewallDomainRule>>>>();
568     }
569     for (const auto &rule : ruleList) {
570         for (const auto &param : rule->domains) {
571             if (param.isWildcard) {
572                 BuildFirewallDomainLsmTrie(rule, param.domain);
573             } else {
574                 BuildFirewallDomainMap(rule, param.domain);
575             }
576         }
577     }
578     return 0;
579 }
580 
ClearFirewallRules(NetFirewallRuleType type)581 int32_t DnsParamCache::ClearFirewallRules(NetFirewallRuleType type)
582 {
583     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
584     switch (type) {
585         case NetFirewallRuleType::RULE_DNS:
586             firewallDnsRules_.clear();
587             netFirewallDnsRuleMap_.clear();
588             break;
589         case NetFirewallRuleType::RULE_DOMAIN: {
590             firewallDomainRules_.clear();
591             netFirewallDomainRulesAllowMap_.clear();
592             netFirewallDomainRulesDenyMap_.clear();
593             if (domainAllowLsmTrie_) {
594                 domainAllowLsmTrie_ = nullptr;
595             }
596             if (domainDenyLsmTrie_) {
597                 domainDenyLsmTrie_ = nullptr;
598             }
599             OHOS::NetManagerStandard::NetsysBpfNetFirewall::GetInstance()->ClearDomainCache();
600             break;
601         }
602         case NetFirewallRuleType::RULE_ALL: {
603             firewallDnsRules_.clear();
604             netFirewallDnsRuleMap_.clear();
605             firewallDomainRules_.clear();
606             netFirewallDomainRulesAllowMap_.clear();
607             netFirewallDomainRulesDenyMap_.clear();
608             if (domainAllowLsmTrie_) {
609                 domainAllowLsmTrie_ = nullptr;
610             }
611             if (domainDenyLsmTrie_) {
612                 domainDenyLsmTrie_ = nullptr;
613             }
614             OHOS::NetManagerStandard::NetsysBpfNetFirewall::GetInstance()->ClearDomainCache();
615             break;
616         }
617         default:
618             break;
619     }
620     return 0;
621 }
622 
NotifyDomianIntercept(int32_t appUid, const std::string &hostName)623 void DnsParamCache::NotifyDomianIntercept(int32_t appUid, const std::string &hostName)
624 {
625     if (hostName.empty()) {
626         return;
627     }
628     std::string host = hostName.substr(0, hostName.find(' '));
629     NETNATIVE_LOGI("NotifyDomianIntercept: appUid: %{public}d, hostName: %{private}s", appUid, host.c_str());
630     sptr<NetManagerStandard::InterceptRecord> record = new (std::nothrow) NetManagerStandard::InterceptRecord();
631     record->time = (int32_t)time(NULL);
632     record->appUid = appUid;
633     record->domain = host;
634 
635     if (oldRecord_ != nullptr && (record->time - oldRecord_->time) < INTERCEPT_BUFF_INTERVAL_SEC) {
636         if (record->appUid == oldRecord_->appUid && record->domain == oldRecord_->domain) {
637             return;
638         }
639     }
640     oldRecord_ = record;
641     for (const auto &callback : callbacks_) {
642         callback->OnIntercept(record);
643     }
644 }
645 
RegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> &callback)646 int32_t DnsParamCache::RegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> &callback)
647 {
648     if (!callback) {
649         return -1;
650     }
651 
652     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
653     callbacks_.emplace_back(callback);
654 
655     return 0;
656 }
657 
UnRegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> &callback)658 int32_t DnsParamCache::UnRegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> &callback)
659 {
660     if (!callback) {
661         return -1;
662     }
663 
664     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
665     for (auto it = callbacks_.begin(); it != callbacks_.end(); ++it) {
666         if (*it == callback) {
667             callbacks_.erase(it);
668             return 0;
669         }
670     }
671     return -1;
672 }
673 #endif
674 
GetDumpInfo(std::string &info)675 void DnsParamCache::GetDumpInfo(std::string &info)
676 {
677     std::string dnsData;
678     static const std::string TAB = "  ";
679     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
680     std::for_each(serverConfigMap_.begin(), serverConfigMap_.end(), [&dnsData](const auto &serverConfig) {
681         dnsData.append(TAB + "NetId: " + std::to_string(serverConfig.second.GetNetId()) + "\n");
682         dnsData.append(TAB + "TimeoutMsec: " + std::to_string(serverConfig.second.GetTimeoutMsec()) + "\n");
683         dnsData.append(TAB + "RetryCount: " + std::to_string(serverConfig.second.GetRetryCount()) + "\n");
684         dnsData.append(TAB + "Servers:");
685         GetVectorData(serverConfig.second.GetServers(), dnsData);
686         dnsData.append(TAB + "Domains:");
687         GetVectorData(serverConfig.second.GetDomains(), dnsData);
688     });
689     info.append(dnsData);
690 }
691 } // namespace OHOS::nmd
692