1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef INCLUDE_NETLINK_SOCK_DIAG_H
17 #define INCLUDE_NETLINK_SOCK_DIAG_H
18 
#include <linux/netlink.h>
#include <linux/sock_diag.h>
#include <linux/inet_diag.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <cstdint>
#include <functional>
#include <string>
26 
27 namespace OHOS {
28 namespace nmd {
/**
 * Selects which sockets NetLinkSocketDiag targets when destroying sockets.
 *
 * Declared directly in the enclosing namespace (not in an anonymous namespace)
 * because NetLinkSocketDiag stores a member of this type: an anonymous
 * namespace in a header would give every translation unit a different type,
 * making the class definition differ between TUs (ODR violation).
 *
 * NOTE(review): the exact meaning of each enumerator is decided by
 * NetLinkSocketDiag::SetSocketDestroyType and the destroy logic in the .cpp;
 * the names suggest default-cellular / special-cellular / other networks.
 */
enum class SocketDestroyType {
    DESTROY_DEFAULT_CELLULAR,
    DESTROY_SPECIAL_CELLULAR,
    DESTROY_DEFAULT,
};
36 class NetLinkSocketDiag final {
37 public:
38     NetLinkSocketDiag() = default;
39     ~NetLinkSocketDiag();
40     typedef std::function<bool(const inet_diag_msg *)> DestroyFilter;
41 
42     /**
43      * Destroy all 'active' TCP sockets that no longer exist.
44      *
45      * @param ipAddr Network IP address
46      * @param excludeLoopback “true” to exclude loopback.
47      */
48     void DestroyLiveSockets(const char *ipAddr, bool excludeLoopback);
49 
50     /**
51      * This method set the socketDestroyType_, which used to choose the correct socket.
52      * to destroy.
53      * @param netCapabilities Net capabilities in string format.
54      * @return The result of the method is returned.
55      */
56     int32_t SetSocketDestroyType(const std::string &netCapabilities);
57     void DestroyLiveSocketsWithUid(const std::string &ipAddr, uint32_t uid);
58 private:
59     static bool InLookBack(uint32_t a);
60 
61     bool CreateNetlinkSocket();
62     void CloseNetlinkSocket();
63     int32_t ExecuteDestroySocket(uint8_t proto, const inet_diag_msg *msg);
64     int32_t GetErrorFromKernel(int32_t fd);
65     bool IsLoopbackSocket(const inet_diag_msg *msg);
66     bool IsMatchNetwork(const inet_diag_msg *msg, const std::string &ipAddr);
67     int32_t ProcessSockDiagDumpResponse(uint8_t proto, const std::string &ipAddr, bool excludeLoopback);
68     int32_t SendSockDiagDumpRequest(uint8_t proto, uint8_t family, uint32_t states);
69     void SockDiagDumpCallback(uint8_t proto, const inet_diag_msg *msg, const std::string &ipAddr, bool excludeLoopback);
70     void SockDiagUidDumpCallback(uint8_t proto, const inet_diag_msg *msg, const DestroyFilter& destroy);
71     int32_t ProcessSockDiagUidDumpResponse(uint8_t proto, const DestroyFilter& destroy);
72 private:
73     struct SockDiagRequest {
74         nlmsghdr nlh_;
75         inet_diag_req_v2 req_;
76     };
77     struct MarkMatch {
78         inet_diag_bc_op op_;
79         uint32_t mark_;
80         uint32_t mask_;
81     };
82     struct ByteCode {
83         MarkMatch netIdMatch_;
84         MarkMatch controlMatch_;
85         inet_diag_bc_op controlJump_;
86     };
87     struct Ack {
88         nlmsghdr hdr_;
89         nlmsgerr err_;
90     };
91 
92     int32_t dumpSock_ = -1;
93     int32_t destroySock_ = -1;
94     int32_t socketsDestroyed_ = 0;
95     SocketDestroyType socketDestroyType_ = SocketDestroyType::DESTROY_DEFAULT;
96 };
97 } // namespace nmd
98 } // namespace OHOS
99 #endif // INCLUDE_NETLINK_SOCK_DIAG_H