1/*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#include <securec.h>
17#include <string>
18
19#include "hilog/log.h"
20#include "netmgr_ext_log_wrapper.h"
21#include "netfirewall_db_helper.h"
22
23using namespace OHOS::NativeRdb;
namespace {
// "id" column name; not referenced by name in the part of the file reviewed here
// (presumably the auto-generated row id - TODO confirm against the table schema).
const std::string DATABASE_ID = "id";
// Column holding the rule identifier; used in predicates and to locate the rule id in result sets.
const std::string RULE_ID = "ruleId";
// Column receiving the number of domains of a RULE_DOMAIN rule.
const std::string DOMAIN_NUM = "domainNum";
// Column receiving the number of wildcard domains of a RULE_DOMAIN rule.
const std::string FUZZY_NUM = "fuzzyDomainNum";
// Fragments combined as SQL_SUM + <column> + SQL_FROM + <table> to build "SELECT SUM(col) FROM table".
const std::string SQL_SUM = "SELECT SUM(";
const std::string SQL_FROM = ") FROM ";
}
32
33namespace OHOS {
34namespace NetManagerStandard {
35NetFirewallDbHelper::NetFirewallDbHelper()
36{
37    firewallDatabase_ = NetFirewallDataBase::GetInstance();
38}
39
// Drops this helper's reference to the database object.
NetFirewallDbHelper::~NetFirewallDbHelper()
{
    firewallDatabase_ = nullptr;
}
44
// Returns the single helper instance, a function-local static constructed on the first call.
NetFirewallDbHelper &NetFirewallDbHelper::GetInstance()
{
    static NetFirewallDbHelper instance;
    return instance;
}
50
51bool NetFirewallDbHelper::DomainListToBlob(const std::vector<NetFirewallDomainParam> &vec, std::vector<uint8_t> &blob,
52    uint32_t &fuzzyNum)
53{
54    blob.clear();
55    for (const auto &param : vec) {
56        if (param.isWildcard) {
57            fuzzyNum++;
58        }
59        // 1 put isWildcard
60        blob.emplace_back(param.isWildcard ? 1 : 0);
61        // 2 for those with a string type, calculate the string size
62        uint16_t size = (uint16_t)(param.domain.length());
63        uint8_t *sizePtr = (uint8_t *)&size;
64        blob.emplace_back(sizePtr[0]);
65        blob.emplace_back(sizePtr[1]);
66        // 3 store string
67        std::vector<uint8_t> domain(param.domain.begin(), param.domain.end());
68        blob.insert(blob.end(), domain.begin(), domain.end());
69    }
70    return blob.size() > 0;
71}
72
73bool NetFirewallDbHelper::BlobToDomainList(const std::vector<uint8_t> &blob, std::vector<NetFirewallDomainParam> &vec)
74{
75    vec.clear();
76    size_t blobSize = blob.size();
77    if (blobSize < 1) {
78        return false;
79    }
80
81    size_t i = 0;
82    size_t lenSize = sizeof(uint16_t);
83    while (i < blobSize) {
84        NetFirewallDomainParam param;
85        // 1 get isWildcard
86        param.isWildcard = blob[i] ? true : false;
87        // 2 get size
88        i++;
89        if (i >= blobSize || (blobSize - i) < lenSize) {
90            return true;
91        }
92        const uint8_t *sizePtr = &blob[i];
93        uint16_t size = *((uint16_t *)sizePtr);
94        size_t index = i + lenSize;
95        if (index >= blobSize || (blobSize - index) < size) {
96            return true;
97        }
98        // 3 get string
99        auto it = blob.begin() + index;
100        param.domain = std::string(it, it + size);
101        vec.emplace_back(param);
102        i += size + lenSize;
103    }
104
105    return vec.size() > 0;
106}
107
108template <typename T> void NetFirewallDbHelper::ListToBlob(const std::vector<T> &vec, std::vector<uint8_t> &blob)
109{
110    blob.clear();
111    size_t size = sizeof(T);
112    for (const auto &param : vec) {
113        const uint8_t *data = reinterpret_cast<const uint8_t *>(&param);
114        std::vector<uint8_t> item(data, data + size);
115        // 1 store each object
116        blob.insert(blob.end(), item.begin(), item.end());
117    }
118}
119
120template <typename T> void NetFirewallDbHelper::BlobToList(const std::vector<uint8_t> &blob, std::vector<T> &vec)
121{
122    vec.clear();
123    size_t blobSize = blob.size();
124    if (blobSize < 1) {
125        return;
126    }
127
128    size_t i = 0;
129    size_t size = sizeof(T);
130    while (i < blobSize) {
131        if ((blobSize - i) < size) {
132            return;
133        }
134        T value;
135        memset_s(&value, size, 0, size);
136        memcpy_s(&value, size, &blob[i], size);
137        vec.emplace_back(value);
138        i += size;
139    }
140}
141
142int32_t NetFirewallDbHelper::FillValuesOfFirewallRule(ValuesBucket &values, const NetFirewallRule &rule)
143{
144    values.Clear();
145
146    values.PutInt(NET_FIREWALL_USER_ID, rule.userId);
147    values.PutString(NET_FIREWALL_RULE_NAME, rule.ruleName);
148    values.PutString(NET_FIREWALL_RULE_DESC, rule.ruleDescription);
149    values.PutInt(NET_FIREWALL_RULE_DIR, static_cast<int32_t>(rule.ruleDirection));
150    values.PutInt(NET_FIREWALL_RULE_ACTION, static_cast<int32_t>(rule.ruleAction));
151    values.PutInt(NET_FIREWALL_RULE_TYPE, static_cast<int32_t>(rule.ruleType));
152    values.PutInt(NET_FIREWALL_IS_ENABLED, rule.isEnabled);
153    values.PutInt(NET_FIREWALL_APP_ID, rule.appUid);
154    std::vector<uint8_t> blob;
155    std::vector<DataBaseIp> dbIPs;
156    std::vector<DataBasePort> dbPorts;
157    switch (rule.ruleType) {
158        case NetFirewallRuleType::RULE_IP: {
159            values.PutInt(NET_FIREWALL_PROTOCOL, static_cast<int32_t>(rule.protocol));
160            FirewallIpToDbIp(rule.localIps, dbIPs);
161            ListToBlob(dbIPs, blob);
162            values.PutBlob(NET_FIREWALL_LOCAL_IP, blob);
163
164            FirewallIpToDbIp(rule.remoteIps, dbIPs);
165            ListToBlob(dbIPs, blob);
166            values.PutBlob(NET_FIREWALL_REMOTE_IP, blob);
167
168            FirewallPortToDbPort(rule.localPorts, dbPorts);
169            ListToBlob(dbPorts, blob);
170            values.PutBlob(NET_FIREWALL_LOCAL_PORT, blob);
171
172            FirewallPortToDbPort(rule.remotePorts, dbPorts);
173            ListToBlob(dbPorts, blob);
174            values.PutBlob(NET_FIREWALL_REMOTE_PORT, blob);
175            break;
176        }
177        case NetFirewallRuleType::RULE_DNS: {
178            values.PutString(NET_FIREWALL_DNS_PRIMARY, rule.dns.primaryDns);
179            values.PutString(NET_FIREWALL_DNS_STANDY, rule.dns.standbyDns);
180            break;
181        }
182        case NetFirewallRuleType::RULE_DOMAIN: {
183            values.PutInt(DOMAIN_NUM, rule.domains.size());
184            uint32_t fuzzyNum = 0;
185            DomainListToBlob(rule.domains, blob, fuzzyNum);
186            values.PutInt(FUZZY_NUM, fuzzyNum);
187            values.PutBlob(NET_FIREWALL_RULE_DOMAIN, blob);
188            break;
189        }
190        default:
191            break;
192    }
193    return FIREWALL_OK;
194}
195
196
197int32_t NetFirewallDbHelper::AddFirewallRule(NativeRdb::ValuesBucket &values, const NetFirewallRule &rule)
198{
199    FillValuesOfFirewallRule(values, rule);
200    return firewallDatabase_->Insert(values, FIREWALL_TABLE_NAME);
201}
202
203int32_t NetFirewallDbHelper::AddFirewallRuleRecord(const NetFirewallRule &rule)
204{
205    std::lock_guard<std::mutex> guard(databaseMutex_);
206    ValuesBucket values;
207    int32_t ret = AddFirewallRule(values, rule);
208    if (ret < FIREWALL_OK) {
209        NETMGR_EXT_LOG_E("AddFirewallRule Insert error: %{public}d", ret);
210        (void)firewallDatabase_->RollBack();
211    }
212    return ret;
213}
214
215int32_t NetFirewallDbHelper::CheckIfNeedUpdateEx(const std::string &tableName, bool &isUpdate, int32_t ruleId,
216    NetFirewallRule &oldRule)
217{
218    std::vector<std::string> columns;
219    RdbPredicates rdbPredicates(tableName);
220    rdbPredicates.BeginWrap()->EqualTo(RULE_ID, std::to_string(ruleId))->EndWrap();
221    auto resultSet = firewallDatabase_->Query(rdbPredicates, columns);
222    if (resultSet == nullptr) {
223        NETMGR_EXT_LOG_E("Query error");
224        return FIREWALL_RDB_EXECUTE_FAILTURE;
225    }
226    int32_t rowCount = 0;
227    if (resultSet->GetRowCount(rowCount) != E_OK) {
228        NETMGR_EXT_LOG_E("GetRowCount error");
229        return FIREWALL_RDB_EXECUTE_FAILTURE;
230    }
231    std::vector<NetFirewallRule> rules;
232    GetResultRightRecordEx(resultSet, rules);
233    isUpdate = rowCount > 0 && !rules.empty();
234    if (!rules.empty()) {
235        oldRule.ruleId = rules[0].ruleId;
236        oldRule.userId = rules[0].userId;
237        oldRule.ruleType = rules[0].ruleType;
238        oldRule.isEnabled = rules[0].isEnabled;
239    }
240    return FIREWALL_OK;
241}
242
243int32_t NetFirewallDbHelper::UpdateFirewallRuleRecord(const NetFirewallRule &rule)
244{
245    std::lock_guard<std::mutex> guard(databaseMutex_);
246
247    ValuesBucket values;
248    FillValuesOfFirewallRule(values, rule);
249    int32_t changedRows = 0;
250    int32_t ret = firewallDatabase_->Update(FIREWALL_TABLE_NAME, changedRows, values, "ruleId = ?",
251        std::vector<std::string> { std::to_string(rule.ruleId) });
252    if (ret < FIREWALL_OK) {
253        NETMGR_EXT_LOG_E("Update error: %{public}d", ret);
254        (void)firewallDatabase_->RollBack();
255    }
256    return ret;
257}
258
259void NetFirewallDbHelper::GetParamRuleInfoFormResultSet(std::string &columnName, int32_t index,
260    NetFirewallRuleInfo &table)
261{
262    if (columnName == NET_FIREWALL_PROTOCOL) {
263        table.protocolIndex = index;
264        return;
265    }
266    if (columnName == NET_FIREWALL_LOCAL_IP) {
267        table.localIpsIndex = index;
268        return;
269    }
270    if (columnName == NET_FIREWALL_REMOTE_IP) {
271        table.remoteIpsIndex = index;
272        return;
273    }
274    if (columnName == NET_FIREWALL_LOCAL_PORT) {
275        table.localPortsIndex = index;
276        return;
277    }
278    if (columnName == NET_FIREWALL_REMOTE_PORT) {
279        table.remotePortsIndex = index;
280        return;
281    }
282    if (columnName == NET_FIREWALL_RULE_DOMAIN) {
283        table.domainsIndex = index;
284        return;
285    }
286    if (columnName == NET_FIREWALL_DNS_PRIMARY) {
287        table.primaryDnsIndex = index;
288        return;
289    }
290    if (columnName == NET_FIREWALL_DNS_STANDY) {
291        table.standbyDnsIndex = index;
292    }
293}
294
295int32_t NetFirewallDbHelper::GetResultSetTableInfo(const std::shared_ptr<OHOS::NativeRdb::ResultSet> &resultSet,
296    NetFirewallRuleInfo &table)
297{
298    std::vector<std::string> columnNames;
299    if (resultSet->GetRowCount(table.rowCount) != E_OK || resultSet->GetAllColumnNames(columnNames) != E_OK) {
300        NETMGR_EXT_LOG_E("get table info failed");
301        return FIREWALL_RDB_EXECUTE_FAILTURE;
302    }
303    int32_t columnNamesCount = static_cast<int32_t>(columnNames.size());
304    for (int32_t i = 0; i < columnNamesCount; i++) {
305        std::string &columnName = columnNames.at(i);
306        if (columnName == RULE_ID) {
307            table.ruleIdIndex = i;
308            continue;
309        }
310        if (columnName == NET_FIREWALL_USER_ID) {
311            table.userIdIndex = i;
312            continue;
313        }
314        if (columnName == NET_FIREWALL_RULE_NAME) {
315            table.ruleNameIndex = i;
316            continue;
317        }
318        if (columnName == NET_FIREWALL_RULE_DESC) {
319            table.ruleDescriptionIndex = i;
320            continue;
321        }
322        if (columnName == NET_FIREWALL_RULE_DIR) {
323            table.ruleDirectionIndex = i;
324            continue;
325        }
326        if (columnName == NET_FIREWALL_RULE_ACTION) {
327            table.ruleActionIndex = i;
328            continue;
329        }
330        if (columnName == NET_FIREWALL_RULE_TYPE) {
331            table.ruleTypeIndex = i;
332            continue;
333        }
334        if (columnName == NET_FIREWALL_IS_ENABLED) {
335            table.isEnabledIndex = i;
336            continue;
337        }
338        if (columnName == NET_FIREWALL_APP_ID) {
339            table.appUidIndex = i;
340            continue;
341        }
342        GetParamRuleInfoFormResultSet(columnName, i, table);
343    }
344    return FIREWALL_OK;
345}
346
347int32_t NetFirewallDbHelper::GetResultSetTableInfo(const std::shared_ptr<OHOS::NativeRdb::ResultSet> &resultSet,
348    NetInterceptRecordInfo &table)
349{
350    int32_t rowCount = 0;
351    std::vector<std::string> columnNames;
352    if (resultSet->GetRowCount(rowCount) != E_OK || resultSet->GetAllColumnNames(columnNames) != E_OK) {
353        NETMGR_EXT_LOG_E("get table info failed");
354        return FIREWALL_RDB_EXECUTE_FAILTURE;
355    }
356    int32_t columnNamesCount = static_cast<int32_t>(columnNames.size());
357    for (int32_t i = 0; i < columnNamesCount; i++) {
358        std::string &columnName = columnNames.at(i);
359        if (columnName == NET_FIREWALL_RECORD_TIME) {
360            table.timeIndex = i;
361            continue;
362        }
363        if (columnName == NET_FIREWALL_RECORD_LOCAL_IP) {
364            table.localIpIndex = i;
365            continue;
366        }
367        if (columnName == NET_FIREWALL_RECORD_REMOTE_IP) {
368            table.remoteIpIndex = i;
369            continue;
370        }
371        if (columnName == NET_FIREWALL_RECORD_LOCAL_PORT) {
372            table.localPortIndex = i;
373            continue;
374        }
375        if (columnName == NET_FIREWALL_RECORD_REMOTE_PORT) {
376            table.remotePortIndex = i;
377            continue;
378        }
379        if (columnName == NET_FIREWALL_RECORD_PROTOCOL) {
380            table.protocolIndex = i;
381            continue;
382        }
383        if (columnName == NET_FIREWALL_RECORD_UID) {
384            table.appUidIndex = i;
385            continue;
386        }
387        if (columnName == NET_FIREWALL_DOMAIN) {
388            table.domainIndex = i;
389        }
390    }
391    table.rowCount = rowCount;
392    return FIREWALL_OK;
393}
394
395void NetFirewallDbHelper::GetRuleDataFromResultSet(const std::shared_ptr<OHOS::NativeRdb::ResultSet> &resultSet,
396    const NetFirewallRuleInfo &table, NetFirewallRule &info)
397{
398    resultSet->GetInt(table.userIdIndex, info.userId);
399    resultSet->GetString(table.ruleNameIndex, info.ruleName);
400    resultSet->GetString(table.ruleDescriptionIndex, info.ruleDescription);
401    int ruleDirection = 0;
402    if (resultSet->GetInt(table.ruleDirectionIndex, ruleDirection) == E_OK) {
403        info.ruleDirection = static_cast<NetFirewallRuleDirection>(ruleDirection);
404    }
405    int ruleAction = 0;
406    if (resultSet->GetInt(table.ruleActionIndex, ruleAction) == E_OK) {
407        info.ruleAction = static_cast<FirewallRuleAction>(ruleAction);
408    }
409    int ruleType = 0;
410    if (resultSet->GetInt(table.ruleTypeIndex, ruleType) == E_OK) {
411        info.ruleType = static_cast<NetFirewallRuleType>(ruleType);
412    }
413    int isEnabled = 0;
414    if (resultSet->GetInt(table.isEnabledIndex, isEnabled) == E_OK) {
415        info.isEnabled = static_cast<bool>(isEnabled);
416    }
417    resultSet->GetInt(table.appUidIndex, info.appUid);
418    GetRuleListParamFromResultSet(resultSet, table, info);
419}
420
421void NetFirewallDbHelper::GetRuleListParamFromResultSet(const std::shared_ptr<OHOS::NativeRdb::ResultSet> &resultSet,
422    const NetFirewallRuleInfo &table, NetFirewallRule &info)
423{
424    std::vector<uint8_t> value;
425    std::vector<DataBaseIp> dbIPs;
426    std::vector<DataBasePort> dbPorts;
427    switch (info.ruleType) {
428        case NetFirewallRuleType::RULE_IP: {
429            int protocol = 0;
430            if (resultSet->GetInt(table.protocolIndex, protocol) == E_OK) {
431                info.protocol = static_cast<NetworkProtocol>(protocol);
432            }
433            resultSet->GetBlob(table.localIpsIndex, value);
434            BlobToList(value, dbIPs);
435            DbIpToFirewallIp(dbIPs, info.localIps);
436            value.clear();
437            resultSet->GetBlob(table.remoteIpsIndex, value);
438            BlobToList(value, dbIPs);
439            DbIpToFirewallIp(dbIPs, info.remoteIps);
440            value.clear();
441            resultSet->GetBlob(table.localPortsIndex, value);
442            BlobToList(value, dbPorts);
443            DbPortToFirewallPort(dbPorts, info.localPorts);
444            value.clear();
445            resultSet->GetBlob(table.remotePortsIndex, value);
446            BlobToList(value, dbPorts);
447            DbPortToFirewallPort(dbPorts, info.remotePorts);
448            break;
449        }
450        case NetFirewallRuleType::RULE_DNS: {
451            resultSet->GetString(table.primaryDnsIndex, info.dns.primaryDns);
452            resultSet->GetString(table.standbyDnsIndex, info.dns.standbyDns);
453            break;
454        }
455
456        case NetFirewallRuleType::RULE_DOMAIN: {
457            resultSet->GetBlob(table.domainsIndex, value);
458            BlobToDomainList(value, info.domains);
459            break;
460        }
461        default:
462            break;
463    }
464}
465
466int32_t NetFirewallDbHelper::GetResultRightRecordEx(const std::shared_ptr<OHOS::NativeRdb::ResultSet> &resultSet,
467    std::vector<NetFirewallRule> &rules)
468{
469    NetFirewallRuleInfo table;
470    int32_t ret = GetResultSetTableInfo(resultSet, table);
471    if (ret < FIREWALL_OK) {
472        NETMGR_EXT_LOG_E("GetResultSetTableInfo failed");
473        return ret;
474    }
475
476    bool endFlag = false;
477    NetFirewallRule info;
478    for (int32_t i = 0; (i < table.rowCount) && !endFlag; i++) {
479        if (resultSet->GoToRow(i) != E_OK) {
480            NETMGR_EXT_LOG_E("GoToRow %{public}d", i);
481            break;
482        }
483        resultSet->GetInt(table.ruleIdIndex, info.ruleId);
484        if (info.ruleId > 0) {
485            GetRuleDataFromResultSet(resultSet, table, info);
486            rules.emplace_back(std::move(info));
487        }
488
489        resultSet->IsEnded(endFlag);
490    }
491    resultSet->Close();
492    return rules.size();
493}
494
495int32_t NetFirewallDbHelper::GetResultRightRecordEx(const std::shared_ptr<OHOS::NativeRdb::ResultSet> &resultSet,
496    std::vector<InterceptRecord> &rules)
497{
498    NetInterceptRecordInfo table;
499    int32_t ret = GetResultSetTableInfo(resultSet, table);
500    if (ret < FIREWALL_OK) {
501        NETMGR_EXT_LOG_E("GetResultSetTableInfo failed");
502        return ret;
503    }
504
505    bool endFlag = false;
506    int32_t localPort = 0;
507    int32_t remotePort = 0;
508    int32_t protocol = 0;
509    InterceptRecord info;
510    for (int32_t i = 0; (i < table.rowCount) && !endFlag; i++) {
511        if (resultSet->GoToRow(i) != E_OK) {
512            NETMGR_EXT_LOG_E("GetResultRightRecordEx GoToRow %{public}d", i);
513            break;
514        }
515        resultSet->GetInt(table.timeIndex, info.time);
516        resultSet->GetString(table.localIpIndex, info.localIp);
517        resultSet->GetString(table.remoteIpIndex, info.remoteIp);
518        if (resultSet->GetInt(table.localPortIndex, localPort) == E_OK) {
519            info.localPort = static_cast<uint16_t>(localPort);
520        }
521        if (resultSet->GetInt(table.remotePortIndex, remotePort) == E_OK) {
522            info.remotePort = static_cast<uint16_t>(remotePort);
523        }
524        if (resultSet->GetInt(table.protocolIndex, protocol) == E_OK) {
525            info.protocol = static_cast<uint16_t>(protocol);
526        }
527        resultSet->GetInt(table.appUidIndex, info.appUid);
528        resultSet->GetString(table.domainIndex, info.domain);
529        if (info.time > 0) {
530            rules.emplace_back(std::move(info));
531        }
532        resultSet->IsEnded(endFlag);
533    }
534    int32_t index = 0;
535    resultSet->GetRowIndex(index);
536    resultSet->IsEnded(endFlag);
537    NETMGR_EXT_LOG_I("row=%{public}d pos=%{public}d ret=%{public}zu end=%{public}s", table.rowCount, index,
538        rules.size(), (endFlag ? "yes" : "no"));
539
540    resultSet->Close();
541    return rules.size();
542}
543
544template <typename T>
545int32_t NetFirewallDbHelper::QueryAndGetResult(const NativeRdb::RdbPredicates &rdbPredicates,
546    const std::vector<std::string> &columns, std::vector<T> &rules)
547{
548    auto resultSet = firewallDatabase_->Query(rdbPredicates, columns);
549    if (resultSet == nullptr) {
550        NETMGR_EXT_LOG_E("Query error");
551        return FIREWALL_RDB_EXECUTE_FAILTURE;
552    }
553    return GetResultRightRecordEx(resultSet, rules);
554}
555
556int32_t NetFirewallDbHelper::QueryAllFirewallRuleRecord(std::vector<NetFirewallRule> &rules)
557{
558    std::lock_guard<std::mutex> guard(databaseMutex_);
559    NETMGR_EXT_LOG_I("Query detail: all user");
560    std::vector<std::string> columns;
561    RdbPredicates rdbPredicates(FIREWALL_TABLE_NAME);
562    return QueryFirewallRuleRecord(rdbPredicates, columns, rules);
563}
564
565int32_t NetFirewallDbHelper::QueryAllUserEnabledFirewallRules(std::vector<NetFirewallRule> &rules,
566    NetFirewallRuleType type)
567{
568    std::lock_guard<std::mutex> guard(databaseMutex_);
569    NETMGR_EXT_LOG_I("Query detail: all user");
570    std::vector<std::string> columns;
571    RdbPredicates rdbPredicates(FIREWALL_TABLE_NAME);
572    rdbPredicates.BeginWrap()->EqualTo(NET_FIREWALL_IS_ENABLED, "1");
573    if (type != NetFirewallRuleType::RULE_ALL && type != NetFirewallRuleType::RULE_INVALID) {
574        rdbPredicates.And()->EqualTo(NET_FIREWALL_RULE_TYPE, std::to_string(static_cast<int32_t>(type)));
575    }
576    rdbPredicates.EndWrap();
577    return QueryFirewallRuleRecord(rdbPredicates, columns, rules);
578}
579
580int32_t NetFirewallDbHelper::QueryEnabledFirewallRules(int32_t userId, int32_t appUid,
581    std::vector<NetFirewallRule> &rules)
582{
583    std::lock_guard<std::mutex> guard(databaseMutex_);
584    NETMGR_EXT_LOG_I("QueryEnabledFirewallRules : userId=%{public}d ", userId);
585    std::vector<std::string> columns;
586    RdbPredicates rdbPredicates(FIREWALL_TABLE_NAME);
587    rdbPredicates.BeginWrap()
588        ->EqualTo(NET_FIREWALL_USER_ID, std::to_string(userId))
589        ->And()
590        ->EqualTo(NET_FIREWALL_IS_ENABLED, "1")
591        ->And()
592        ->EqualTo(NET_FIREWALL_APP_ID, appUid)
593        ->EndWrap();
594    return QueryFirewallRuleRecord(rdbPredicates, columns, rules);
595}
596
597int32_t NetFirewallDbHelper::QueryFirewallRuleRecord(int32_t ruleId, int32_t userId,
598    std::vector<NetFirewallRule> &rules)
599{
600    std::lock_guard<std::mutex> guard(databaseMutex_);
601    NETMGR_EXT_LOG_I("Query detail: ruleId=%{public}d userId=%{public}d", ruleId, userId);
602    std::vector<std::string> columns;
603    RdbPredicates rdbPredicates(FIREWALL_TABLE_NAME);
604    rdbPredicates.BeginWrap()
605        ->EqualTo(RULE_ID, std::to_string(ruleId))
606        ->And()
607        ->EqualTo(NET_FIREWALL_USER_ID, std::to_string(userId))
608        ->EndWrap();
609
610    return QueryFirewallRuleRecord(rdbPredicates, columns, rules);
611}
612
613int32_t NetFirewallDbHelper::QueryFirewallRuleRecord(const NativeRdb::RdbPredicates &rdbPredicates,
614    const std::vector<std::string> &columns, std::vector<NetFirewallRule> &rules)
615{
616    int32_t ret = QueryAndGetResult(rdbPredicates, columns, rules);
617    if (ret < 0) {
618        NETMGR_EXT_LOG_E("QueryFirewallRuleRecord error.");
619        return ret;
620    }
621    size_t size = rules.size();
622    if (size == 0) {
623        NETMGR_EXT_LOG_I("QueryFirewallRuleRecord rule empty");
624        return FIREWALL_OK;
625    }
626    NETMGR_EXT_LOG_I("QueryFirewallRuleRecord rule size: %{public}zu", size);
627    return FIREWALL_OK;
628}
629
630int32_t NetFirewallDbHelper::DeleteAndNoOtherOperation(const std::string &whereClause,
631    const std::vector<std::string> &whereArgs)
632{
633    int32_t changedRows = 0;
634    int32_t ret = firewallDatabase_->Delete(FIREWALL_TABLE_NAME, changedRows, whereClause, whereArgs);
635    if (ret < FIREWALL_OK) {
636        (void)firewallDatabase_->RollBack();
637        return FIREWALL_FAILURE;
638    }
639    return ret;
640}
641
642int32_t NetFirewallDbHelper::DeleteFirewallRuleRecord(int32_t userId, int32_t ruleId)
643{
644    std::lock_guard<std::mutex> guard(databaseMutex_);
645    std::string whereClause = { "userId = ? AND ruleId = ?" };
646    std::vector<std::string> whereArgs = { std::to_string(userId), std::to_string(ruleId) };
647    int32_t ret = DeleteAndNoOtherOperation(whereClause, whereArgs);
648    if (ret != FIREWALL_OK) {
649        NETMGR_EXT_LOG_E("failed: detale(ruleId): %{public}d", ret);
650    }
651    return ret;
652}
653
654int32_t NetFirewallDbHelper::DeleteFirewallRuleRecordByUserId(int32_t userId)
655{
656    std::lock_guard<std::mutex> guard(databaseMutex_);
657    std::string whereClause = { "userId = ?" };
658    std::vector<std::string> whereArgs = { std::to_string(userId) };
659    int32_t ret = DeleteAndNoOtherOperation(whereClause, whereArgs);
660    if (ret != FIREWALL_OK) {
661        NETMGR_EXT_LOG_E("failed: detale(ruleId): %{public}d", ret);
662    }
663    return ret;
664}
665
666int32_t NetFirewallDbHelper::DeleteFirewallRuleRecordByAppId(int32_t appUid)
667{
668    std::lock_guard<std::mutex> guard(databaseMutex_);
669    std::string whereClause = { "appUid = ?" };
670    std::vector<std::string> whereArgs = { std::to_string(appUid) };
671    int32_t ret = DeleteAndNoOtherOperation(whereClause, whereArgs);
672    if (ret != FIREWALL_OK) {
673        NETMGR_EXT_LOG_E("failed: detale(ruleId): %{public}d", ret);
674    }
675    return ret;
676}
677
678bool NetFirewallDbHelper::IsFirewallRuleExist(int32_t ruleId, NetFirewallRule &oldRule)
679{
680    std::lock_guard<std::mutex> guard(databaseMutex_);
681    bool isExist = false;
682    int32_t ret = CheckIfNeedUpdateEx(FIREWALL_TABLE_NAME, isExist, ruleId, oldRule);
683    if (ret < FIREWALL_OK) {
684        NETMGR_EXT_LOG_E("check if need update error: %{public}d", ret);
685    }
686    return isExist;
687}
688
689int32_t NetFirewallDbHelper::QueryFirewallRuleByUserIdCount(int32_t userId, int64_t &rowCount)
690{
691    RdbPredicates rdbPredicates(FIREWALL_TABLE_NAME);
692    rdbPredicates.BeginWrap()->EqualTo(NET_FIREWALL_USER_ID, std::to_string(userId))->EndWrap();
693
694    return Count(rowCount, rdbPredicates);
695}
696
697int32_t NetFirewallDbHelper::QueryFirewallRuleAllCount(int64_t &rowCount)
698{
699    RdbPredicates rdbPredicates(FIREWALL_TABLE_NAME);
700    return Count(rowCount, rdbPredicates);
701}
702
703int32_t NetFirewallDbHelper::QueryFirewallRuleAllDomainCount()
704{
705    return QuerySql(SQL_SUM + DOMAIN_NUM + SQL_FROM + FIREWALL_TABLE_NAME);
706}
707
708int32_t NetFirewallDbHelper::QueryFirewallRuleAllFuzzyDomainCount()
709{
710    return QuerySql(SQL_SUM + FUZZY_NUM + SQL_FROM + FIREWALL_TABLE_NAME);
711}
712
713int32_t NetFirewallDbHelper::QueryFirewallRuleDomainByUserIdCount(int32_t userId)
714{
715    return QuerySql(SQL_SUM + DOMAIN_NUM + SQL_FROM + FIREWALL_TABLE_NAME + " WHERE (" + NET_FIREWALL_USER_ID + " = " +
716        std::to_string(userId) + ")");
717}
718
719int32_t NetFirewallDbHelper::QueryFirewallRule(const int32_t userId, const sptr<RequestParam> &requestParam,
720    sptr<FirewallRulePage> &info)
721{
722    std::lock_guard<std::mutex> guard(databaseMutex_);
723    int64_t rowCount = 0;
724    RdbPredicates rdbPredicates(FIREWALL_TABLE_NAME);
725    rdbPredicates.BeginWrap()->EqualTo(NET_FIREWALL_USER_ID, std::to_string(userId))->EndWrap();
726    firewallDatabase_->Count(rowCount, rdbPredicates);
727    info->totalPage = rowCount / requestParam->pageSize;
728    int32_t remainder = rowCount % requestParam->pageSize;
729    if (remainder > 0) {
730        info->totalPage += 1;
731    }
732    NETMGR_EXT_LOG_I("QueryFirewallRule: userId=%{public}d page=%{public}d pageSize=%{public}d total=%{public}d",
733        userId, requestParam->page, requestParam->pageSize, info->totalPage);
734    if (info->totalPage < requestParam->page) {
735        return FIREWALL_FAILURE;
736    }
737    std::vector<std::string> columns;
738    rdbPredicates.Clear();
739    rdbPredicates.BeginWrap()->EqualTo(NET_FIREWALL_USER_ID, std::to_string(userId));
740    if (requestParam->orderType == NetFirewallOrderType::ORDER_ASC) {
741        rdbPredicates.OrderByAsc(NET_FIREWALL_RULE_NAME);
742    } else {
743        rdbPredicates.OrderByDesc(NET_FIREWALL_RULE_NAME);
744    }
745    rdbPredicates.Limit((requestParam->page - 1) * requestParam->pageSize, requestParam->pageSize)->EndWrap();
746    return QueryFirewallRuleRecord(rdbPredicates, columns, info->data);
747}
748
749int32_t NetFirewallDbHelper::Count(int64_t &outValue, const OHOS::NativeRdb::AbsRdbPredicates &predicates)
750{
751    std::lock_guard<std::mutex> guard(databaseMutex_);
752    int32_t ret = firewallDatabase_->Count(outValue, predicates);
753    if (ret < FIREWALL_OK) {
754        NETMGR_EXT_LOG_E("Count error");
755        return -1;
756    }
757    return ret;
758}
759
760int32_t NetFirewallDbHelper::QuerySql(const std::string &sql)
761{
762    std::lock_guard<std::mutex> guard(databaseMutex_);
763    std::vector<std::string> selectionArgs;
764    auto resultSet = firewallDatabase_->QuerySql(sql, selectionArgs);
765    if (resultSet == nullptr) {
766        NETMGR_EXT_LOG_E("QuerySql error");
767        return FIREWALL_RDB_EXECUTE_FAILTURE;
768    }
769    int32_t rowCount = 0;
770    if (resultSet->GetRowCount(rowCount) != E_OK || resultSet->GoToRow(0) != E_OK) {
771        return FIREWALL_RDB_EXECUTE_FAILTURE;
772    }
773    int32_t value = 0;
774    resultSet->GetInt(0, value);
775    return value;
776}
777
778bool NetFirewallDbHelper::IsDnsRuleExist(const sptr<NetFirewallRule> &rule)
779{
780    if (rule->ruleType != NetFirewallRuleType::RULE_DNS) {
781        return false;
782    }
783    std::lock_guard<std::mutex> guard(databaseMutex_);
784    RdbPredicates rdbPredicates(FIREWALL_TABLE_NAME);
785    rdbPredicates.BeginWrap()
786        ->EqualTo(NET_FIREWALL_USER_ID, std::to_string(rule->userId))
787        ->And()
788        ->EqualTo(NET_FIREWALL_RULE_TYPE, std::to_string(static_cast<int32_t>(rule->ruleType)))
789        ->And()
790        ->EqualTo(NET_FIREWALL_APP_ID, std::to_string(rule->appUid))
791        ->And()
792        ->BeginWrap()
793        ->EqualTo(NET_FIREWALL_DNS_PRIMARY, rule->dns.primaryDns)
794        ->Or()
795        ->EqualTo(NET_FIREWALL_DNS_STANDY, rule->dns.standbyDns)
796        ->EndWrap()
797        ->Limit(1)
798        ->EndWrap();
799    std::vector<std::string> columns;
800    auto resultSet = firewallDatabase_->Query(rdbPredicates, columns);
801    if (resultSet == nullptr) {
802        NETMGR_EXT_LOG_E("IsDnsRuleExist Query error");
803        return false;
804    }
805    int32_t rowCount = 0;
806    resultSet->GetRowCount(rowCount);
807    return rowCount > 0;
808}
809
810int32_t NetFirewallDbHelper::AddInterceptRecord(const int32_t userId, std::vector<sptr<InterceptRecord>> &records)
811{
812    std::lock_guard<std::mutex> guard(databaseMutex_);
813    int32_t ret = firewallDatabase_->BeginTransaction();
814    // Aging by date, record up to 8 days of data
815    std::string whereClause = { "userId = ? AND time < ?" };
816    std::vector<std::string> whereArgs = { std::to_string(userId),
817        std::to_string(records.back()->time - RECORD_MAX_SAVE_TIME) };
818    int32_t changedRows = 0;
819    ret = firewallDatabase_->Delete(INTERCEPT_RECORD_TABLE, changedRows, whereClause, whereArgs);
820
821    int64_t currentRows = 0;
822    RdbPredicates rdbPredicates(INTERCEPT_RECORD_TABLE);
823    rdbPredicates.BeginWrap()->EqualTo(NET_FIREWALL_USER_ID, std::to_string(userId))->EndWrap();
824    firewallDatabase_->Count(currentRows, rdbPredicates);
825    // Aging by number, record up to 1000 pieces of data
826    size_t size = records.size();
827    size_t leftRows = static_cast<size_t>(RECORD_MAX_DATA_NUM - currentRows);
828    if (leftRows < size) {
829        std::string whereClause("id in (select id from ");
830        whereClause += INTERCEPT_RECORD_TABLE;
831        whereClause += " where userId = ? order by id limit ? )";
832        std::vector<std::string> whereArgs = { std::to_string(userId), std::to_string(size - leftRows) };
833        ret = firewallDatabase_->Delete(INTERCEPT_RECORD_TABLE, changedRows, whereClause, whereArgs);
834    }
835    // New data written to the database
836    ValuesBucket values;
837    for (size_t i = 0; i < size; i++) {
838        values.Clear();
839        values.PutInt(NET_FIREWALL_USER_ID, userId);
840        values.PutInt(NET_FIREWALL_RECORD_TIME, records[i]->time);
841        values.PutString(NET_FIREWALL_RECORD_LOCAL_IP, records[i]->localIp);
842        values.PutString(NET_FIREWALL_RECORD_REMOTE_IP, records[i]->remoteIp);
843        values.PutInt(NET_FIREWALL_RECORD_LOCAL_PORT, static_cast<int32_t>(records[i]->localPort));
844        values.PutInt(NET_FIREWALL_RECORD_REMOTE_PORT, static_cast<int32_t>(records[i]->remotePort));
845        values.PutInt(NET_FIREWALL_RECORD_PROTOCOL, static_cast<int32_t>(records[i]->protocol));
846        values.PutInt(NET_FIREWALL_RECORD_UID, records[i]->appUid);
847        values.PutString(NET_FIREWALL_DOMAIN, records[i]->domain);
848
849        ret = firewallDatabase_->Insert(values, INTERCEPT_RECORD_TABLE);
850        if (ret < FIREWALL_OK) {
851            NETMGR_EXT_LOG_E("AddInterceptRecord error: %{public}d", ret);
852            firewallDatabase_->Commit();
853            return -1;
854        }
855    }
856    return firewallDatabase_->Commit();
857}
858
859int32_t NetFirewallDbHelper::DeleteInterceptRecord(const int32_t userId)
860{
861    std::lock_guard<std::mutex> guard(databaseMutex_);
862    std::string whereClause = { "userId = ?" };
863    std::vector<std::string> whereArgs = { std::to_string(userId) };
864    int32_t changedRows = 0;
865    int32_t ret = firewallDatabase_->Delete(INTERCEPT_RECORD_TABLE, changedRows, whereClause, whereArgs);
866    if (ret < FIREWALL_OK) {
867        NETMGR_EXT_LOG_E("DeleteInterceptRecord error: %{public}d", ret);
868        return -1;
869    }
870    return ret;
871}
872
873int32_t NetFirewallDbHelper::QueryInterceptRecord(const int32_t userId, const sptr<RequestParam> &requestParam,
874    sptr<InterceptRecordPage> &info)
875{
876    std::lock_guard<std::mutex> guard(databaseMutex_);
877    int64_t rowCount = 0;
878    RdbPredicates rdbPredicates(INTERCEPT_RECORD_TABLE);
879    rdbPredicates.BeginWrap()->EqualTo(NET_FIREWALL_USER_ID, std::to_string(userId))->EndWrap();
880    firewallDatabase_->Count(rowCount, rdbPredicates);
881    info->totalPage = rowCount / requestParam->pageSize;
882    int32_t remainder = rowCount % requestParam->pageSize;
883    if (remainder > 0) {
884        info->totalPage += 1;
885    }
886    NETMGR_EXT_LOG_I("QueryInterceptRecord: userId=%{public}d page=%{public}d pageSize=%{public}d total=%{public}d",
887        userId, requestParam->page, requestParam->pageSize, info->totalPage);
888    if (info->totalPage < requestParam->page) {
889        return FIREWALL_FAILURE;
890    }
891    info->page = requestParam->page;
892    std::vector<std::string> columns;
893    rdbPredicates.Clear();
894    rdbPredicates.BeginWrap()->EqualTo(NET_FIREWALL_USER_ID, std::to_string(userId));
895    if (requestParam->orderType == NetFirewallOrderType::ORDER_ASC) {
896        rdbPredicates.OrderByAsc(NET_FIREWALL_RECORD_TIME);
897    } else {
898        rdbPredicates.OrderByDesc(NET_FIREWALL_RECORD_TIME);
899    }
900    rdbPredicates.Limit((requestParam->page - 1) * requestParam->pageSize, requestParam->pageSize)->EndWrap();
901    return QueryAndGetResult(rdbPredicates, columns, info->data);
902}
903
904void NetFirewallDbHelper::FirewallIpToDbIp(const std::vector<NetFirewallIpParam> &ips, std::vector<DataBaseIp> &dbips)
905{
906    dbips.clear();
907    DataBaseIp dbip;
908    for (const NetFirewallIpParam &param : ips) {
909        dbip.family = param.family;
910        dbip.mask = param.mask;
911        dbip.type = param.type;
912        if (dbip.family == FAMILY_IPV4) {
913            memcpy_s(&dbip.ipv4.startIp, sizeof(uint32_t), &param.ipv4.startIp, sizeof(uint32_t));
914            memcpy_s(&dbip.ipv4.endIp, sizeof(uint32_t), &param.ipv4.endIp, sizeof(uint32_t));
915        } else {
916            memcpy_s(&dbip.ipv6.startIp, sizeof(in6_addr), &param.ipv6.startIp, sizeof(in6_addr));
917            memcpy_s(&dbip.ipv6.endIp, sizeof(in6_addr), &param.ipv6.endIp, sizeof(in6_addr));
918        }
919        dbips.emplace_back(std::move(dbip));
920    }
921}
922void NetFirewallDbHelper::DbIpToFirewallIp(const std::vector<DataBaseIp> &dbips, std::vector<NetFirewallIpParam> &ips)
923{
924    ips.clear();
925    NetFirewallIpParam dbip;
926    for (const DataBaseIp &param : dbips) {
927        dbip.family = param.family;
928        dbip.mask = param.mask;
929        dbip.type = param.type;
930        if (dbip.family == FAMILY_IPV4) {
931            memcpy_s(&dbip.ipv4.startIp, sizeof(uint32_t), &param.ipv4.startIp, sizeof(uint32_t));
932            memcpy_s(&dbip.ipv4.endIp, sizeof(uint32_t), &param.ipv4.endIp, sizeof(uint32_t));
933        } else {
934            memcpy_s(&dbip.ipv6.startIp, sizeof(in6_addr), &param.ipv6.startIp, sizeof(in6_addr));
935            memcpy_s(&dbip.ipv6.endIp, sizeof(in6_addr), &param.ipv6.endIp, sizeof(in6_addr));
936        }
937        ips.emplace_back(std::move(dbip));
938    }
939}
940void NetFirewallDbHelper::FirewallPortToDbPort(const std::vector<NetFirewallPortParam> &ports,
941    std::vector<DataBasePort> &dbports)
942{
943    dbports.clear();
944    DataBasePort dbport;
945    for (const NetFirewallPortParam &param : ports) {
946        dbport.startPort = param.startPort;
947        dbport.endPort = param.endPort;
948        dbports.emplace_back(std::move(dbport));
949    }
950}
951
952void NetFirewallDbHelper::DbPortToFirewallPort(const std::vector<DataBasePort> &dbports,
953    std::vector<NetFirewallPortParam> &ports)
954{
955    ports.clear();
956    NetFirewallPortParam dbport;
957    for (const DataBasePort &param : dbports) {
958        dbport.startPort = param.startPort;
959        dbport.endPort = param.endPort;
960        ports.emplace_back(std::move(dbport));
961    }
962}
963} // namespace NetManagerStandard
964} // namespace OHOS