1 /*
2  * Copyright (c) 2022-2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "netsys_sock_client.h"
17 
18 #include <cerrno>
19 #include <atomic>
20 #include <mutex>
21 #include <sys/socket.h>
22 #include <sys/un.h>
23 #include <unistd.h>
24 
25 #include "fwmark_client.h"
26 #include "net_manager_constants.h"
27 #include "netnative_log_wrapper.h"
28 
29 namespace {
30 SocketDispatchType defaultSocketDispatchType;
31 std::atomic_int g_netIdForApp(0);
32 std::atomic<const SocketDispatchType*> g_dispatch(&defaultSocketDispatchType);
33 std::atomic_bool g_hookFlag(false);
34 std::once_flag g_onceFlag;
GetDispatch()35 const SocketDispatchType* GetDispatch()
36 {
37     return g_dispatch.load(std::memory_order_relaxed);
38 }
39 } // namespace
40 
HookSocket(int (*fn)(int, int, int), int domain, int type, int protocol)41 int HookSocket(int (*fn)(int, int, int), int domain, int type, int protocol)
42 {
43     int fd = -1;
44     if (fn) {
45         fd = fn(domain, type, protocol);
46     }
47 
48     if (fd < 0) {
49         return fd;
50     }
51 
52     if (g_netIdForApp > 0 && (domain == AF_INET || domain == AF_INET6)) {
53         if (OHOS::nmd::FwmarkClient().BindSocket(fd, g_netIdForApp) != OHOS::NetManagerStandard::NETMANAGER_SUCCESS) {
54             NETNATIVE_LOGE("BindSocket [%{public}d] to netid [%{public}d] failed",
55                 fd, g_netIdForApp.load(std::memory_order_relaxed));
56             return -1;
57         }
58     }
59 
60     return fd;
61 }
62 
ohos_socket_hook_initialize(const SocketDispatchType* disptch, bool*, const char*)63 bool ohos_socket_hook_initialize(const SocketDispatchType* disptch, bool*, const char*)
64 {
65     std::call_once(g_onceFlag, [&]() {
66         g_dispatch.store(disptch);
67         g_hookFlag = true;
68     });
69     return true;
70 }
71 
ohos_socket_hook_finalize(void)72 void ohos_socket_hook_finalize(void)
73 {
74     g_hookFlag = false;
75 }
76 
ohos_socket_hook_socket(int domain, int type, int protocol)77 int ohos_socket_hook_socket(int domain, int type, int protocol)
78 {
79     return HookSocket(GetDispatch()->socket, domain, type, protocol);
80 }
81 
ohos_socket_hook_get_hook_flag(void)82 bool ohos_socket_hook_get_hook_flag(void)
83 {
84     return g_hookFlag;
85 }
86 
ohos_socket_hook_set_hook_flag(bool flag)87 bool ohos_socket_hook_set_hook_flag(bool flag)
88 {
89     g_hookFlag = flag;
90     return true;
91 }
92 
SetNetForApp(int netId)93 void SetNetForApp(int netId)
94 {
95     g_netIdForApp = netId;
96 }
97 
GetNetForApp()98 int GetNetForApp()
99 {
100     return g_netIdForApp;
101 }
102