1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "dns_param_cache.h"

#include <algorithm>
#include <iterator>

#include "netmanager_base_common_utils.h"
#ifdef FEATURE_NET_FIREWALL_ENABLE
#include "bpf_netfirewall.h"
#include "netfirewall_parcel.h"
#include <ctime>
#endif
26
27 namespace OHOS::nmd {
28 using namespace OHOS::NetManagerStandard::CommonUtils;
29 namespace {
GetVectorData(const std::vector<std::string> &data, std::string &result)30 void GetVectorData(const std::vector<std::string> &data, std::string &result)
31 {
32 result.append("{ ");
33 std::for_each(data.begin(), data.end(), [&result](const auto &str) { result.append(ToAnonymousIp(str) + ", "); });
34 result.append("}\n");
35 }
36 constexpr int RES_TIMEOUT = 5000; // min. milliseconds between retries
37 constexpr int RES_DEFAULT_RETRY = 2; // Default
38 } // namespace
39
DnsParamCache()40 DnsParamCache::DnsParamCache() : defaultNetId_(0) {}
41
GetInstance()42 DnsParamCache &DnsParamCache::GetInstance()
43 {
44 static DnsParamCache instance;
45 return instance;
46 }
47
SelectNameservers(const std::vector<std::string> &servers)48 std::vector<std::string> DnsParamCache::SelectNameservers(const std::vector<std::string> &servers)
49 {
50 std::vector<std::string> res = servers;
51 if (res.size() > MAX_SERVER_NUM - 1) {
52 res.resize(MAX_SERVER_NUM - 1);
53 }
54 return res;
55 }
56
CreateCacheForNet(uint16_t netId)57 int32_t DnsParamCache::CreateCacheForNet(uint16_t netId)
58 {
59 NETNATIVE_LOG_D("DnsParamCache::CreateCacheForNet, netid:%{public}d,", netId);
60 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
61 auto it = serverConfigMap_.find(netId);
62 if (it != serverConfigMap_.end()) {
63 NETNATIVE_LOGE("DnsParamCache::CreateCacheForNet, netid already exist, no need to create");
64 return -EEXIST;
65 }
66 serverConfigMap_[netId].SetNetId(netId);
67 return 0;
68 }
69
DestroyNetworkCache(uint16_t netId)70 int32_t DnsParamCache::DestroyNetworkCache(uint16_t netId)
71 {
72 NETNATIVE_LOG_D("DnsParamCache::CreateCacheForNet, netid:%{public}d,", netId);
73 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
74 auto it = serverConfigMap_.find(netId);
75 if (it == serverConfigMap_.end()) {
76 return -ENOENT;
77 }
78 serverConfigMap_.erase(it);
79 if (defaultNetId_ == netId) {
80 defaultNetId_ = 0;
81 }
82 return 0;
83 }
84
SetResolverConfig(uint16_t netId, uint16_t baseTimeoutMsec, uint8_t retryCount, const std::vector<std::string> &servers, const std::vector<std::string> &domains)85 int32_t DnsParamCache::SetResolverConfig(uint16_t netId, uint16_t baseTimeoutMsec, uint8_t retryCount,
86 const std::vector<std::string> &servers,
87 const std::vector<std::string> &domains)
88 {
89 std::vector<std::string> nameservers = SelectNameservers(servers);
90 NETNATIVE_LOG_D("DnsParamCache::SetResolverConfig, netid:%{public}d, numServers:%{public}d,", netId,
91 static_cast<int>(nameservers.size()));
92
93 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
94
95 // select_domains
96 auto it = serverConfigMap_.find(netId);
97 if (it == serverConfigMap_.end()) {
98 NETNATIVE_LOGE("DnsParamCache::SetResolverConfig failed, netid is non-existent");
99 return -ENOENT;
100 }
101
102 auto oldDnsServers = it->second.GetServers();
103 std::sort(oldDnsServers.begin(), oldDnsServers.end());
104
105 auto newDnsServers = servers;
106 std::sort(newDnsServers.begin(), newDnsServers.end());
107
108 if (oldDnsServers != newDnsServers) {
109 it->second.GetCache().Clear();
110 }
111
112 it->second.SetNetId(netId);
113 it->second.SetServers(servers);
114 it->second.SetDomains(domains);
115 if (retryCount == 0) {
116 it->second.SetRetryCount(RES_DEFAULT_RETRY);
117 } else {
118 it->second.SetRetryCount(retryCount);
119 }
120 if (baseTimeoutMsec == 0) {
121 it->second.SetTimeoutMsec(RES_TIMEOUT);
122 } else {
123 it->second.SetTimeoutMsec(baseTimeoutMsec);
124 }
125 return 0;
126 }
127
SetDefaultNetwork(uint16_t netId)128 void DnsParamCache::SetDefaultNetwork(uint16_t netId)
129 {
130 defaultNetId_ = netId;
131 }
132
EnableIpv6(uint16_t netId)133 void DnsParamCache::EnableIpv6(uint16_t netId)
134 {
135 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
136 auto it = serverConfigMap_.find(netId);
137 if (it == serverConfigMap_.end()) {
138 DNS_CONFIG_PRINT("get Config failed: netid is not have netid:%{public}d,", netId);
139 return;
140 }
141
142 it->second.EnableIpv6();
143 }
144
IsIpv6Enable(uint16_t netId)145 bool DnsParamCache::IsIpv6Enable(uint16_t netId)
146 {
147 if (netId == 0) {
148 netId = defaultNetId_;
149 }
150
151 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
152 auto it = serverConfigMap_.find(netId);
153 if (it == serverConfigMap_.end()) {
154 DNS_CONFIG_PRINT("get Config failed: netid is not have netid:%{public}d,", netId);
155 return false;
156 }
157
158 return it->second.IsIpv6Enable();
159 }
160
GetResolverConfig(uint16_t netId, std::vector<std::string> &servers, std::vector<std::string> &domains, uint16_t &baseTimeoutMsec, uint8_t &retryCount)161 int32_t DnsParamCache::GetResolverConfig(uint16_t netId, std::vector<std::string> &servers,
162 std::vector<std::string> &domains, uint16_t &baseTimeoutMsec,
163 uint8_t &retryCount)
164 {
165 NETNATIVE_LOG_D("DnsParamCache::GetResolverConfig no uid");
166 if (netId == 0) {
167 netId = defaultNetId_;
168 NETNATIVE_LOG_D("defaultNetId_ = [%{public}u]", netId);
169 }
170
171 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
172 auto it = serverConfigMap_.find(netId);
173 if (it == serverConfigMap_.end()) {
174 DNS_CONFIG_PRINT("get Config failed: netid is not have netid:%{public}d,", netId);
175 return -ENOENT;
176 }
177
178 servers = it->second.GetServers();
179 #ifdef FEATURE_NET_FIREWALL_ENABLE
180 std::vector<std::string> dns;
181 if (GetDnsServersByAppUid(GetCallingUid(), dns)) {
182 DNS_CONFIG_PRINT("GetResolverConfig hit netfirewall");
183 servers.assign(dns.begin(), dns.end());
184 }
185 #endif
186 domains = it->second.GetDomains();
187 baseTimeoutMsec = it->second.GetTimeoutMsec();
188 retryCount = it->second.GetRetryCount();
189
190 return 0;
191 }
192
GetResolverConfig(uint16_t netId, uint32_t uid, std::vector<std::string> &servers, std::vector<std::string> &domains, uint16_t &baseTimeoutMsec, uint8_t &retryCount)193 int32_t DnsParamCache::GetResolverConfig(uint16_t netId, uint32_t uid, std::vector<std::string> &servers,
194 std::vector<std::string> &domains, uint16_t &baseTimeoutMsec,
195 uint8_t &retryCount)
196 {
197 NETNATIVE_LOG_D("DnsParamCache::GetResolverConfig has uid");
198 if (netId == 0) {
199 netId = defaultNetId_;
200 NETNATIVE_LOG_D("defaultNetId_ = [%{public}u]", netId);
201 }
202
203 {
204 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
205 for (auto mem : vpnUidRanges_) {
206 if (static_cast<int32_t>(uid) >= mem.begin_ && static_cast<int32_t>(uid) <= mem.end_) {
207 NETNATIVE_LOG_D("is vpn hap");
208 auto it = serverConfigMap_.find(vpnNetId_);
209 if (it == serverConfigMap_.end()) {
210 NETNATIVE_LOG_D("vpn get Config failed: not have vpnnetid:%{public}d,", vpnNetId_);
211 break;
212 }
213 servers = it->second.GetServers();
214 #ifdef FEATURE_NET_FIREWALL_ENABLE
215 std::vector<std::string> dns;
216 if (GetDnsServersByAppUid(GetCallingUid(), dns)) {
217 DNS_CONFIG_PRINT("GetResolverConfig hit netfirewall");
218 servers.assign(dns.begin(), dns.end());
219 }
220 #endif
221 domains = it->second.GetDomains();
222 baseTimeoutMsec = it->second.GetTimeoutMsec();
223 retryCount = it->second.GetRetryCount();
224 return 0;
225 }
226 }
227 }
228 return GetResolverConfig(netId, servers, domains, baseTimeoutMsec, retryCount);
229 }
230
GetDefaultNetwork() const231 int32_t DnsParamCache::GetDefaultNetwork() const
232 {
233 return defaultNetId_;
234 }
235
SetDnsCache(uint16_t netId, const std::string &hostName, const AddrInfo &addrInfo)236 void DnsParamCache::SetDnsCache(uint16_t netId, const std::string &hostName, const AddrInfo &addrInfo)
237 {
238 if (netId == 0) {
239 netId = defaultNetId_;
240 }
241 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
242 #ifdef FEATURE_NET_FIREWALL_ENABLE
243 int32_t appUid = static_cast<int32_t>(GetCallingUid());
244 bool isMatchAllow = false;
245 if (IsInterceptDomain(appUid, hostName, isMatchAllow)) {
246 DNS_CONFIG_PRINT("SetDnsCache failed: domain was Intercepted: %{public}s,", hostName.c_str());
247 return;
248 }
249 if (isMatchAllow && (addrInfo.aiFamily == AF_INET || addrInfo.aiFamily == AF_INET6)) {
250 NetAddrInfo netInfo;
251 netInfo.aiFamily = addrInfo.aiFamily;
252 if (addrInfo.aiFamily == AF_INET) {
253 netInfo.aiAddr.sin = addrInfo.aiAddr.sin.sin_addr;
254 } else {
255 memcpy_s(&netInfo.aiAddr.sin6, sizeof(addrInfo.aiAddr.sin6.sin6_addr), &addrInfo.aiAddr.sin6.sin6_addr,
256 sizeof(addrInfo.aiAddr.sin6.sin6_addr));
257 }
258 OHOS::NetManagerStandard::NetsysBpfNetFirewall::GetInstance()->AddDomainCache(netInfo);
259 }
260 #endif
261 auto it = serverConfigMap_.find(netId);
262 if (it == serverConfigMap_.end()) {
263 DNS_CONFIG_PRINT("SetDnsCache failed: netid is not have netid:%{public}d,", netId);
264 return;
265 }
266
267 it->second.GetCache().Put(hostName, addrInfo);
268 }
269
GetDnsCache(uint16_t netId, const std::string &hostName)270 std::vector<AddrInfo> DnsParamCache::GetDnsCache(uint16_t netId, const std::string &hostName)
271 {
272 if (netId == 0) {
273 netId = defaultNetId_;
274 }
275
276 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
277 #ifdef FEATURE_NET_FIREWALL_ENABLE
278 int32_t appUid = static_cast<int32_t>(GetCallingUid());
279 bool isMatchAllow = false;
280 if (IsInterceptDomain(appUid, hostName, isMatchAllow)) {
281 NotifyDomianIntercept(appUid, hostName);
282 AddrInfo fakeAddr = { 0 };
283 fakeAddr.aiFamily = AF_UNSPEC;
284 fakeAddr.aiAddr.sin.sin_family = AF_UNSPEC;
285 fakeAddr.aiAddr.sin.sin_addr.s_addr = INADDR_NONE;
286 fakeAddr.aiAddrLen = sizeof(struct sockaddr_in);
287 return { fakeAddr };
288 }
289 #endif
290
291 auto it = serverConfigMap_.find(netId);
292 if (it == serverConfigMap_.end()) {
293 DNS_CONFIG_PRINT("GetDnsCache failed: netid is not have netid:%{public}d,", netId);
294 return {};
295 }
296
297 return it->second.GetCache().Get(hostName);
298 }
299
SetCacheDelayed(uint16_t netId, const std::string &hostName)300 void DnsParamCache::SetCacheDelayed(uint16_t netId, const std::string &hostName)
301 {
302 if (netId == 0) {
303 netId = defaultNetId_;
304 }
305
306 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
307 auto it = serverConfigMap_.find(netId);
308 if (it == serverConfigMap_.end()) {
309 DNS_CONFIG_PRINT("SetCacheDelayed failed: netid is not have netid:%{public}d,", netId);
310 return;
311 }
312
313 it->second.SetCacheDelayed(hostName);
314 }
315
AddUidRange(uint32_t netId, const std::vector<NetManagerStandard::UidRange> &uidRanges)316 int32_t DnsParamCache::AddUidRange(uint32_t netId, const std::vector<NetManagerStandard::UidRange> &uidRanges)
317 {
318 std::lock_guard<ffrt::mutex> guard(uidRangeMutex_);
319 NETNATIVE_LOG_D("DnsParamCache::AddUidRange size = [%{public}zu]", uidRanges.size());
320 vpnNetId_ = netId;
321 auto middle = vpnUidRanges_.insert(vpnUidRanges_.end(), uidRanges.begin(), uidRanges.end());
322 std::inplace_merge(vpnUidRanges_.begin(), middle, vpnUidRanges_.end());
323 return 0;
324 }
325
DelUidRange(uint32_t netId, const std::vector<NetManagerStandard::UidRange> &uidRanges)326 int32_t DnsParamCache::DelUidRange(uint32_t netId, const std::vector<NetManagerStandard::UidRange> &uidRanges)
327 {
328 std::lock_guard<ffrt::mutex> guard(uidRangeMutex_);
329 NETNATIVE_LOG_D("DnsParamCache::DelUidRange size = [%{public}zu]", uidRanges.size());
330 vpnNetId_ = 0;
331 auto end = std::set_difference(vpnUidRanges_.begin(), vpnUidRanges_.end(), uidRanges.begin(),
332 uidRanges.end(), vpnUidRanges_.begin());
333 vpnUidRanges_.erase(end, vpnUidRanges_.end());
334 return 0;
335 }
336
IsVpnOpen() const337 bool DnsParamCache::IsVpnOpen() const
338 {
339 return vpnUidRanges_.size();
340 }
341
342 #ifdef FEATURE_NET_FIREWALL_ENABLE
GetUserId(int32_t appUid)343 int32_t DnsParamCache::GetUserId(int32_t appUid)
344 {
345 int32_t userId = appUid / USER_ID_DIVIDOR;
346 return userId > 0 ? userId : currentUserId_;
347 }
348
GetDnsServersByAppUid(int32_t appUid, std::vector<std::string> &servers)349 bool DnsParamCache::GetDnsServersByAppUid(int32_t appUid, std::vector<std::string> &servers)
350 {
351 if (netFirewallDnsRuleMap_.empty()) {
352 return false;
353 }
354 DNS_CONFIG_PRINT("GetDnsServersByAppUid: appUid=%{public}d", appUid);
355 auto it = netFirewallDnsRuleMap_.find(appUid);
356 if (it == netFirewallDnsRuleMap_.end()) {
357 // if appUid not found, try to find invalid appUid=0;
358 it = netFirewallDnsRuleMap_.find(0);
359 }
360 if (it != netFirewallDnsRuleMap_.end()) {
361 int32_t userId = GetUserId(appUid);
362 std::vector<sptr<NetFirewallDnsRule>> rules = it->second;
363 for (const auto &rule : rules) {
364 if (rule->userId != userId) {
365 continue;
366 }
367 servers.emplace_back(rule->primaryDns);
368 servers.emplace_back(rule->standbyDns);
369 }
370 return true;
371 }
372 return false;
373 }
374
SetFirewallRules(NetFirewallRuleType type, const std::vector<sptr<NetFirewallBaseRule>> &ruleList, bool isFinish)375 int32_t DnsParamCache::SetFirewallRules(NetFirewallRuleType type,
376 const std::vector<sptr<NetFirewallBaseRule>> &ruleList, bool isFinish)
377 {
378 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
379 NETNATIVE_LOGI("SetFirewallRules: size=%{public}zu isFinish=%{public}" PRId32, ruleList.size(), isFinish);
380 if (ruleList.empty()) {
381 NETNATIVE_LOGE("SetFirewallRules: rules is empty");
382 return -1;
383 }
384 int32_t ret = 0;
385 switch (type) {
386 case NetFirewallRuleType::RULE_DNS: {
387 for (const auto &rule : ruleList) {
388 firewallDnsRules_.emplace_back(firewall_rule_cast<NetFirewallDnsRule>(rule));
389 }
390 if (isFinish) {
391 ret = SetFirewallDnsRules(firewallDnsRules_);
392 firewallDnsRules_.clear();
393 }
394 break;
395 }
396 case NetFirewallRuleType::RULE_DOMAIN: {
397 for (const auto &rule : ruleList) {
398 firewallDomainRules_.emplace_back(firewall_rule_cast<NetFirewallDomainRule>(rule));
399 }
400 if (isFinish) {
401 ret = SetFirewallDomainRules(firewallDomainRules_);
402 firewallDomainRules_.clear();
403 OHOS::NetManagerStandard::NetsysBpfNetFirewall::GetInstance()->ClearDomainCache();
404 }
405 break;
406 }
407 default:
408 break;
409 }
410 return ret;
411 }
412
SetFirewallDnsRules(const std::vector<sptr<NetFirewallDnsRule>> &ruleList)413 int32_t DnsParamCache::SetFirewallDnsRules(const std::vector<sptr<NetFirewallDnsRule>> &ruleList)
414 {
415 for (const auto &rule : ruleList) {
416 std::vector<sptr<NetFirewallDnsRule>> rules;
417 auto it = netFirewallDnsRuleMap_.find(rule->appUid);
418 if (it != netFirewallDnsRuleMap_.end()) {
419 rules = it->second;
420 }
421 rules.emplace_back(std::move(rule));
422 netFirewallDnsRuleMap_.emplace(rule->appUid, std::move(rules));
423 }
424 return 0;
425 }
426
GetFirewallRuleAction(int32_t appUid, const std::vector<sptr<NetFirewallDomainRule>> &rules)427 FirewallRuleAction DnsParamCache::GetFirewallRuleAction(int32_t appUid,
428 const std::vector<sptr<NetFirewallDomainRule>> &rules)
429 {
430 int32_t userId = GetUserId(appUid);
431 for (const auto &rule : rules) {
432 if (rule->userId != userId) {
433 continue;
434 }
435 if ((rule->appUid && appUid == rule->appUid) || !rule->appUid) {
436 return rule->ruleAction;
437 }
438 }
439
440 return FirewallRuleAction::RULE_INVALID;
441 }
442
checkEmpty4InterceptDomain(const std::string &hostName)443 bool DnsParamCache::checkEmpty4InterceptDomain(const std::string &hostName)
444 {
445 if (hostName.empty()) {
446 return true;
447 }
448 if (!netFirewallDomainRulesAllowMap_.empty() || !netFirewallDomainRulesDenyMap_.empty()) {
449 return false;
450 }
451 if (domainAllowLsmTrie_ && !domainAllowLsmTrie_->Empty()) {
452 return false;
453 }
454 return !domainDenyLsmTrie_ || domainDenyLsmTrie_->Empty();
455 }
456
IsInterceptDomain(int32_t appUid, const std::string &hostName, bool &isMatchAllow)457 bool DnsParamCache::IsInterceptDomain(int32_t appUid, const std::string &hostName, bool &isMatchAllow)
458 {
459 if (checkEmpty4InterceptDomain(hostName)) {
460 return false;
461 }
462 std::string host = hostName.substr(0, hostName.find(' '));
463 DNS_CONFIG_PRINT("IsInterceptDomain: appUid: %{public}d, hostName: %{private}s", appUid, host.c_str());
464 std::transform(host.begin(), host.end(), host.begin(), ::tolower);
465 std::vector<sptr<NetFirewallDomainRule>> rules;
466 FirewallRuleAction exactAllowAction = FirewallRuleAction::RULE_INVALID;
467 auto it = netFirewallDomainRulesAllowMap_.find(host);
468 if (it != netFirewallDomainRulesAllowMap_.end()) {
469 rules = it->second;
470 exactAllowAction = GetFirewallRuleAction(appUid, rules);
471 }
472 FirewallRuleAction exactDenyAction = FirewallRuleAction::RULE_INVALID;
473 auto iter = netFirewallDomainRulesDenyMap_.find(host);
474 if (iter != netFirewallDomainRulesDenyMap_.end()) {
475 rules = iter->second;
476 exactDenyAction = GetFirewallRuleAction(appUid, rules);
477 }
478 FirewallRuleAction wildcardAllowAction = FirewallRuleAction::RULE_INVALID;
479 if (domainAllowLsmTrie_->LongestSuffixMatch(host, rules)) {
480 wildcardAllowAction = GetFirewallRuleAction(appUid, rules);
481 }
482 FirewallRuleAction wildcardDenyAction = FirewallRuleAction::RULE_INVALID;
483 if (domainDenyLsmTrie_->LongestSuffixMatch(host, rules)) {
484 wildcardDenyAction = GetFirewallRuleAction(appUid, rules);
485 }
486 isMatchAllow = (exactAllowAction != FirewallRuleAction::RULE_INVALID) ||
487 (wildcardAllowAction != FirewallRuleAction::RULE_INVALID);
488 bool isDeny = (exactDenyAction != FirewallRuleAction::RULE_INVALID) ||
489 (wildcardDenyAction != FirewallRuleAction::RULE_INVALID);
490 if (isMatchAllow) {
491 // Apply default rules in case of conflict
492 return isDeny && (firewallDefaultAction_ == FirewallRuleAction::RULE_DENY);
493 }
494 return isDeny;
495 }
496
SetFirewallDefaultAction(FirewallRuleAction inDefault, FirewallRuleAction outDefault)497 int32_t DnsParamCache::SetFirewallDefaultAction(FirewallRuleAction inDefault, FirewallRuleAction outDefault)
498 {
499 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
500 DNS_CONFIG_PRINT("SetFirewallDefaultAction: firewallDefaultAction_: %{public}d", (int)outDefault);
501 firewallDefaultAction_ = outDefault;
502 return 0;
503 }
504
BuildFirewallDomainLsmTrie(const sptr<NetFirewallDomainRule> &rule, const std::string &domain)505 void DnsParamCache::BuildFirewallDomainLsmTrie(const sptr<NetFirewallDomainRule> &rule, const std::string &domain)
506 {
507 std::vector<sptr<NetFirewallDomainRule>> rules;
508 std::string suffix(domain);
509 auto wildcardCharIndex = suffix.find('*');
510 if (wildcardCharIndex != std::string::npos) {
511 suffix = suffix.substr(wildcardCharIndex + 1);
512 }
513 DNS_CONFIG_PRINT("BuildFirewallDomainLsmTrie: suffix: %{public}s", suffix.c_str());
514 std::transform(suffix.begin(), suffix.end(), suffix.begin(), ::tolower);
515 if (rule->ruleAction == FirewallRuleAction::RULE_DENY) {
516 if (domainDenyLsmTrie_->LongestSuffixMatch(suffix, rules)) {
517 rules.emplace_back(std::move(rule));
518 domainDenyLsmTrie_->Update(suffix, rules);
519 return;
520 }
521 rules.emplace_back(std::move(rule));
522 domainDenyLsmTrie_->Insert(suffix, rules);
523 } else {
524 if (domainAllowLsmTrie_->LongestSuffixMatch(suffix, rules)) {
525 rules.emplace_back(std::move(rule));
526 domainAllowLsmTrie_->Update(suffix, rules);
527 return;
528 }
529 rules.emplace_back(std::move(rule));
530 domainAllowLsmTrie_->Insert(suffix, rules);
531 }
532 }
533
BuildFirewallDomainMap(const sptr<NetFirewallDomainRule> &rule, const std::string &raw)534 void DnsParamCache::BuildFirewallDomainMap(const sptr<NetFirewallDomainRule> &rule, const std::string &raw)
535 {
536 DNS_CONFIG_PRINT("BuildFirewallDomainMap: domain: %{public}s", raw.c_str());
537 std::string domain(raw);
538 std::vector<sptr<NetFirewallDomainRule>> rules;
539 std::transform(domain.begin(), domain.end(), domain.begin(), ::tolower);
540 if (rule->ruleAction == FirewallRuleAction::RULE_DENY) {
541 auto it = netFirewallDomainRulesDenyMap_.find(domain);
542 if (it != netFirewallDomainRulesDenyMap_.end()) {
543 rules = it->second;
544 }
545
546 rules.emplace_back(std::move(rule));
547 netFirewallDomainRulesDenyMap_.emplace(domain, std::move(rules));
548 } else {
549 auto it = netFirewallDomainRulesAllowMap_.find(domain);
550 if (it != netFirewallDomainRulesAllowMap_.end()) {
551 rules = it->second;
552 }
553
554 rules.emplace_back(rule);
555 netFirewallDomainRulesAllowMap_.emplace(domain, std::move(rules));
556 }
557 }
558
SetFirewallDomainRules(const std::vector<sptr<NetFirewallDomainRule>> &ruleList)559 int32_t DnsParamCache::SetFirewallDomainRules(const std::vector<sptr<NetFirewallDomainRule>> &ruleList)
560 {
561 if (!domainAllowLsmTrie_) {
562 domainAllowLsmTrie_ =
563 std::make_shared<NetManagerStandard::SuffixMatchTrie<std::vector<sptr<NetFirewallDomainRule>>>>();
564 }
565 if (!domainDenyLsmTrie_) {
566 domainDenyLsmTrie_ =
567 std::make_shared<NetManagerStandard::SuffixMatchTrie<std::vector<sptr<NetFirewallDomainRule>>>>();
568 }
569 for (const auto &rule : ruleList) {
570 for (const auto ¶m : rule->domains) {
571 if (param.isWildcard) {
572 BuildFirewallDomainLsmTrie(rule, param.domain);
573 } else {
574 BuildFirewallDomainMap(rule, param.domain);
575 }
576 }
577 }
578 return 0;
579 }
580
ClearFirewallRules(NetFirewallRuleType type)581 int32_t DnsParamCache::ClearFirewallRules(NetFirewallRuleType type)
582 {
583 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
584 switch (type) {
585 case NetFirewallRuleType::RULE_DNS:
586 firewallDnsRules_.clear();
587 netFirewallDnsRuleMap_.clear();
588 break;
589 case NetFirewallRuleType::RULE_DOMAIN: {
590 firewallDomainRules_.clear();
591 netFirewallDomainRulesAllowMap_.clear();
592 netFirewallDomainRulesDenyMap_.clear();
593 if (domainAllowLsmTrie_) {
594 domainAllowLsmTrie_ = nullptr;
595 }
596 if (domainDenyLsmTrie_) {
597 domainDenyLsmTrie_ = nullptr;
598 }
599 OHOS::NetManagerStandard::NetsysBpfNetFirewall::GetInstance()->ClearDomainCache();
600 break;
601 }
602 case NetFirewallRuleType::RULE_ALL: {
603 firewallDnsRules_.clear();
604 netFirewallDnsRuleMap_.clear();
605 firewallDomainRules_.clear();
606 netFirewallDomainRulesAllowMap_.clear();
607 netFirewallDomainRulesDenyMap_.clear();
608 if (domainAllowLsmTrie_) {
609 domainAllowLsmTrie_ = nullptr;
610 }
611 if (domainDenyLsmTrie_) {
612 domainDenyLsmTrie_ = nullptr;
613 }
614 OHOS::NetManagerStandard::NetsysBpfNetFirewall::GetInstance()->ClearDomainCache();
615 break;
616 }
617 default:
618 break;
619 }
620 return 0;
621 }
622
NotifyDomianIntercept(int32_t appUid, const std::string &hostName)623 void DnsParamCache::NotifyDomianIntercept(int32_t appUid, const std::string &hostName)
624 {
625 if (hostName.empty()) {
626 return;
627 }
628 std::string host = hostName.substr(0, hostName.find(' '));
629 NETNATIVE_LOGI("NotifyDomianIntercept: appUid: %{public}d, hostName: %{private}s", appUid, host.c_str());
630 sptr<NetManagerStandard::InterceptRecord> record = new (std::nothrow) NetManagerStandard::InterceptRecord();
631 record->time = (int32_t)time(NULL);
632 record->appUid = appUid;
633 record->domain = host;
634
635 if (oldRecord_ != nullptr && (record->time - oldRecord_->time) < INTERCEPT_BUFF_INTERVAL_SEC) {
636 if (record->appUid == oldRecord_->appUid && record->domain == oldRecord_->domain) {
637 return;
638 }
639 }
640 oldRecord_ = record;
641 for (const auto &callback : callbacks_) {
642 callback->OnIntercept(record);
643 }
644 }
645
RegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> &callback)646 int32_t DnsParamCache::RegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> &callback)
647 {
648 if (!callback) {
649 return -1;
650 }
651
652 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
653 callbacks_.emplace_back(callback);
654
655 return 0;
656 }
657
UnRegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> &callback)658 int32_t DnsParamCache::UnRegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> &callback)
659 {
660 if (!callback) {
661 return -1;
662 }
663
664 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
665 for (auto it = callbacks_.begin(); it != callbacks_.end(); ++it) {
666 if (*it == callback) {
667 callbacks_.erase(it);
668 return 0;
669 }
670 }
671 return -1;
672 }
673 #endif
674
GetDumpInfo(std::string &info)675 void DnsParamCache::GetDumpInfo(std::string &info)
676 {
677 std::string dnsData;
678 static const std::string TAB = " ";
679 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
680 std::for_each(serverConfigMap_.begin(), serverConfigMap_.end(), [&dnsData](const auto &serverConfig) {
681 dnsData.append(TAB + "NetId: " + std::to_string(serverConfig.second.GetNetId()) + "\n");
682 dnsData.append(TAB + "TimeoutMsec: " + std::to_string(serverConfig.second.GetTimeoutMsec()) + "\n");
683 dnsData.append(TAB + "RetryCount: " + std::to_string(serverConfig.second.GetRetryCount()) + "\n");
684 dnsData.append(TAB + "Servers:");
685 GetVectorData(serverConfig.second.GetServers(), dnsData);
686 dnsData.append(TAB + "Domains:");
687 GetVectorData(serverConfig.second.GetDomains(), dnsData);
688 });
689 info.append(dnsData);
690 }
691 } // namespace OHOS::nmd
692