1/*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#ifdef OHOS_SOCKET_HOOK_ENABLE
17#include <unistd.h>
18#include <signal.h>
19#include <stdlib.h>
20#include <limits.h>
21#include <dlfcn.h>
22#include <errno.h>
23#include <ctype.h>
24#include <assert.h>
25#include <string.h>
26#include <stdio.h>
27#include "musl_socket_preinit_common.h"
28#include "musl_log.h"
29
30static char *__socket_hook_shared_lib = "libfwmark_client.z.so";
31static char *__socket_hook_function_prefix = "ohos_socket_hook";
32void* shared_lib_func[LAST_FUNC];
33long long __ohos_socket_hook_shared_library;
34typedef bool (*init_func_type)(const struct SocketDispatchType*, bool*, const char*);
35typedef void (*finalize_func_type)();
36#define MAX_SYMBOL_SIZE 1000
37
38static bool init_socket_function(void* shared_library_handler, SocketSocketType* func)
39{
40	char symbol[MAX_SYMBOL_SIZE];
41	(void)snprintf(symbol, sizeof(symbol), "%s_%s", __socket_hook_function_prefix, "socket");
42	*func = (SocketSocketType)(dlsym(shared_library_handler, symbol));
43	if (*func == NULL) {
44		return false;
45	}
46	return true;
47}
48
49static void clear_socket_function()
50{
51	memset(shared_lib_func, 0, sizeof(shared_lib_func));
52}
53
54static void socket_finalize()
55{
56	((finalize_func_type)shared_lib_func[FINALIZE_FUNC])();
57	__current_dispatch = 0;
58	__socket_hook_begin_flag = false;
59	// Don't dlclose because hidumper crash
60}
61
62static bool finish_install_ohos_socket_hooks(const char* options)
63{
64	init_func_type init_func = (init_func_type)(shared_lib_func[INITIALIZE_FUNC]);
65	if (!init_func(&__libc_socket_default_dispatch, NULL, options)) {
66		MUSL_LOGI("Netsys, init_func failed.");
67		clear_socket_function();
68		return false;
69	}
70
71	int ret_value = atexit(socket_finalize);
72	if (ret_value != 0) {
73		MUSL_LOGI("Netsys, set atexit failed.");
74	}
75	return true;
76}
77
78static bool init_socket_hook_shared_library(void* shared_library_handle)
79{
80	static const char* names[] = {
81		"initialize",
82		"finalize",
83		"get_hook_flag",
84		"set_hook_flag",
85	};
86
87	for (int i = 0; i < LAST_FUNC; i++) {
88		char symbol[MAX_SYMBOL_SIZE];
89		(void)snprintf(symbol, sizeof(symbol), "%s_%s", __socket_hook_function_prefix, names[i]);
90		shared_lib_func[i] = dlsym(shared_library_handle, symbol);
91		if (shared_lib_func[i] == NULL) {
92			clear_socket_function();
93			return false;
94		}
95	}
96
97	if (!init_socket_function(shared_library_handle, &(__musl_libc_socket_dispatch.socket))) {
98		MUSL_LOGI("Netsys, set socket function failed.");
99		clear_socket_function();
100		return false;
101	}
102
103	return true;
104}
105
106static void* load_socket_hook_shared_library()
107{
108	void* shared_library_handle = NULL;
109
110	shared_library_handle = dlopen(__socket_hook_shared_lib, RTLD_NOW | RTLD_LOCAL);
111
112	if (shared_library_handle == NULL) {
113		MUSL_LOGI("Netsys, Unable to open shared library %s: %s.\n", __socket_hook_shared_lib, dlerror());
114		return NULL;
115	}
116
117	if (!init_socket_hook_shared_library(shared_library_handle)) {
118		dlclose(shared_library_handle);
119		shared_library_handle = NULL;
120	}
121	return shared_library_handle;
122}
123
124static void install_ohos_socket_hook()
125{
126	void* shared_library_handle = (void *)__ohos_socket_hook_shared_library;
127	if (shared_library_handle != NULL && shared_library_handle != (void*)-1) {
128		MUSL_LOGI("Netsys, ohos_socket_hook_shared_library has had.");
129		return;
130	}
131
132	__current_dispatch = 0;
133	shared_library_handle = load_socket_hook_shared_library();
134	if (shared_library_handle == NULL) {
135		MUSL_LOGI("Netsys, load_socket_hook_shared_library failed.");
136		return;
137	}
138	MUSL_LOGI("Netsys, load_socket_hook_shared_library success.");
139
140	if (finish_install_ohos_socket_hooks(NULL)) {
141		MUSL_LOGI("Netsys, finish_install_ohos_socket_hooks success.");
142		__ohos_socket_hook_shared_library = (long long)shared_library_handle;
143		__current_dispatch = (long long)(&__musl_libc_socket_dispatch);
144	} else {
145		MUSL_LOGI("Netsys, finish_install_ohos_socket_hooks failed.");
146		__ohos_socket_hook_shared_library = 0;
147		dlclose((void *)shared_library_handle);
148	}
149}
150
/* Entry point used by the constructor; currently it only installs the socket hook. */
static void init_ohos_socket_hook()
{
	install_ohos_socket_hook();
}
155
156__attribute__((constructor())) static void __musl_socket_initialize()
157{
158	bool begin_flag = __get_socket_hook_begin_flag();
159	MUSL_LOGI("Netsys, %d begin musl_socket_initialize, flag %d.\n", getpid(), begin_flag);
160	if (!begin_flag) {
161		__socket_hook_begin_flag = true;
162		init_ohos_socket_hook();
163	}
164}
165#endif
166