1 /*
2 * Copyright (c) 2022-2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "netsys_sock_client.h"
17
18 #include <cerrno>
19 #include <atomic>
20 #include <mutex>
21 #include <sys/socket.h>
22 #include <sys/un.h>
23 #include <unistd.h>
24
25 #include "fwmark_client.h"
26 #include "net_manager_constants.h"
27 #include "netnative_log_wrapper.h"
28
29 namespace {
30 SocketDispatchType defaultSocketDispatchType;
31 std::atomic_int g_netIdForApp(0);
32 std::atomic<const SocketDispatchType*> g_dispatch(&defaultSocketDispatchType);
33 std::atomic_bool g_hookFlag(false);
34 std::once_flag g_onceFlag;
GetDispatch()35 const SocketDispatchType* GetDispatch()
36 {
37 return g_dispatch.load(std::memory_order_relaxed);
38 }
39 } // namespace
40
HookSocket(int (*fn)(int, int, int), int domain, int type, int protocol)41 int HookSocket(int (*fn)(int, int, int), int domain, int type, int protocol)
42 {
43 int fd = -1;
44 if (fn) {
45 fd = fn(domain, type, protocol);
46 }
47
48 if (fd < 0) {
49 return fd;
50 }
51
52 if (g_netIdForApp > 0 && (domain == AF_INET || domain == AF_INET6)) {
53 if (OHOS::nmd::FwmarkClient().BindSocket(fd, g_netIdForApp) != OHOS::NetManagerStandard::NETMANAGER_SUCCESS) {
54 NETNATIVE_LOGE("BindSocket [%{public}d] to netid [%{public}d] failed",
55 fd, g_netIdForApp.load(std::memory_order_relaxed));
56 return -1;
57 }
58 }
59
60 return fd;
61 }
62
ohos_socket_hook_initialize(const SocketDispatchType* disptch, bool*, const char*)63 bool ohos_socket_hook_initialize(const SocketDispatchType* disptch, bool*, const char*)
64 {
65 std::call_once(g_onceFlag, [&]() {
66 g_dispatch.store(disptch);
67 g_hookFlag = true;
68 });
69 return true;
70 }
71
ohos_socket_hook_finalize(void)72 void ohos_socket_hook_finalize(void)
73 {
74 g_hookFlag = false;
75 }
76
ohos_socket_hook_socket(int domain, int type, int protocol)77 int ohos_socket_hook_socket(int domain, int type, int protocol)
78 {
79 return HookSocket(GetDispatch()->socket, domain, type, protocol);
80 }
81
ohos_socket_hook_get_hook_flag(void)82 bool ohos_socket_hook_get_hook_flag(void)
83 {
84 return g_hookFlag;
85 }
86
ohos_socket_hook_set_hook_flag(bool flag)87 bool ohos_socket_hook_set_hook_flag(bool flag)
88 {
89 g_hookFlag = flag;
90 return true;
91 }
92
SetNetForApp(int netId)93 void SetNetForApp(int netId)
94 {
95 g_netIdForApp = netId;
96 }
97
GetNetForApp()98 int GetNetForApp()
99 {
100 return g_netIdForApp;
101 }
102