1 // SPDX-License-Identifier: GPL-2.0+
2 /*
3  * Copyright (c) 2024 Huawei Device Co., Ltd.
4  *
5  * Operations on the lowpower protocol
6  * Authors: yangyanjun
7  */
8 #ifdef CONFIG_LOWPOWER_PROTOCOL
9 #include <linux/types.h>
10 #include <linux/kernel.h>
11 #include <linux/proc_fs.h>
12 #include <linux/printk.h>
13 #include <linux/list.h>
14 #include <linux/rwlock_types.h>
15 #include <linux/net_namespace.h>
16 #include <net/sock.h>
17 #include <net/ip.h>
18 #include <net/tcp.h>
19 #include <net/lowpower_protocol.h>
20 
21 static atomic_t g_foreground_uid = ATOMIC_INIT(FOREGROUND_UID_INIT);
22 #define OPT_LEN 3
23 #define TO_DECIMAL 10
24 #define LIST_MAX 500
25 #define DECIMAL_CHAR_NUM 10 // u32 decimal characters (4,294,967,295)
26 static DEFINE_RWLOCK(g_dpa_rwlock);
27 static u32 g_dpa_uid_list_cnt;
28 static struct list_head g_dpa_uid_list;
29 struct dpa_node {
30 	struct list_head list_node;
31 	uid_t uid;
32 };
33 
/* Atomically store @val as the current foreground uid. */
static void foreground_uid_atomic_set(uid_t val)
{
	atomic_t *slot = &g_foreground_uid;

	atomic_set(slot, val);
}
38 
foreground_uid_atomic_read(void)39 static uid_t foreground_uid_atomic_read(void)
40 {
41 	return (uid_t)atomic_read(&g_foreground_uid);
42 }
43 
/*
 * cat /proc/net/foreground_uid
 *
 * seq_file show callback: emits the stored foreground uid as a single
 * decimal line. Always returns 0.
 */
static int foreground_uid_show(struct seq_file *seq, void *v)
{
	seq_printf(seq, "%u\n", foreground_uid_atomic_read());
	return 0;
}
52 
53 // echo xx > /proc/net/foreground_uid
foreground_uid_write(struct file *file, char *buf, size_t size)54 static int foreground_uid_write(struct file *file, char *buf, size_t size)
55 {
56 	char *p = buf;
57 	uid_t uid = simple_strtoul(p, &p, TO_DECIMAL);
58 
59 	if (!p)
60 		return -EINVAL;
61 
62 	foreground_uid_atomic_set(uid);
63 	return 0;
64 }
65 
66 // cat /proc/net/dpa_uid
dpa_uid_show(struct seq_file *seq, void *v)67 static int dpa_uid_show(struct seq_file *seq, void *v)
68 {
69 	struct dpa_node *node = NULL;
70 	struct dpa_node *tmp_node = NULL;
71 
72 	read_lock(&g_dpa_rwlock);
73 	seq_printf(seq, "uid list num: %u\n", g_dpa_uid_list_cnt);
74 	list_for_each_entry_safe(node, tmp_node, &g_dpa_uid_list, list_node)
75 		seq_printf(seq, "%u\n", node->uid);
76 	read_unlock(&g_dpa_rwlock);
77 	return 0;
78 }
79 
80 // echo "add xx yy zz" > /proc/net/dpa_uid
81 // echo "del xx yy zz" > /proc/net/dpa_uid
82 static int dpa_uid_add(uid_t uid);
83 static int dpa_uid_del(uid_t uid);
84 static int get_dpa_uids(char *buf, size_t size, u32 *uid_list,
85 			u32 index_max, u32 *index);
dpa_uid_write(struct file *file, char *buf, size_t size)86 static int dpa_uid_write(struct file *file, char *buf, size_t size)
87 {
88 	u32 dpa_list[LIST_MAX];
89 	u32 index = 0;
90 	int ret = -EINVAL;
91 	int i;
92 
93 	if (get_dpa_uids(buf, size, dpa_list, LIST_MAX, &index) != 0) {
94 		pr_err("[dpa-uid-cfg] fail to parse dpa uids\n");
95 		return ret;
96 	}
97 
98 	if (strncmp(buf, "add", OPT_LEN) == 0) {
99 		for (i = 0; i < index; i++) {
100 			ret = dpa_uid_add(dpa_list[i]);
101 			if (ret != 0) {
102 				pr_err("[dpa-uid-cfg] add fail, index=%u\n", i);
103 				return ret;
104 			}
105 		}
106 	} else if (strncmp(buf, "del", OPT_LEN) == 0) {
107 		for (i = 0; i < index; i++) {
108 			ret = dpa_uid_del(dpa_list[i]);
109 			if (ret != 0) {
110 				pr_err("[dpa-uid-cfg] del fail, index=%u\n", i);
111 				return ret;
112 			}
113 		}
114 	} else {
115 		pr_err("[dpa-uid-cfg] cmd unknown\n");
116 	}
117 	return ret;
118 }
119 
/*
 * Append @uid to g_dpa_uid_list unless it is already present.
 *
 * The duplicate check runs before the capacity check, so re-adding a uid
 * that is already stored succeeds even when the list is full. An allocation
 * failure is reported to the caller instead of being treated as success.
 *
 * Return: 0 if @uid is in the list on return, -EFBIG if the list already
 * holds LIST_MAX entries, -ENOMEM if the node cannot be allocated.
 */
static int dpa_uid_add(uid_t uid)
{
	struct dpa_node *node;
	int ret = 0;

	write_lock(&g_dpa_rwlock);

	list_for_each_entry(node, &g_dpa_uid_list, list_node) {
		if (node->uid == uid)
			goto unlock;	/* already present, nothing to do */
	}

	if (g_dpa_uid_list_cnt >= LIST_MAX) {
		ret = -EFBIG;
		goto unlock;
	}

	/* allocated while holding the rwlock, hence GFP_ATOMIC */
	node = kzalloc(sizeof(*node), GFP_ATOMIC);
	if (!node) {
		ret = -ENOMEM;
		goto unlock;
	}

	node->uid = uid;
	list_add_tail(&node->list_node, &g_dpa_uid_list);
	g_dpa_uid_list_cnt++;

unlock:
	write_unlock(&g_dpa_rwlock);
	return ret;
}
150 
/*
 * Remove @uid from g_dpa_uid_list if present and free its node.
 *
 * The node is unlinked under the write lock and freed after the lock is
 * dropped; previously it was only unlinked, leaking the allocation made in
 * dpa_uid_add().
 *
 * Return: always 0 (deleting an absent uid is not an error).
 */
static int dpa_uid_del(uid_t uid)
{
	struct dpa_node *node;
	struct dpa_node *victim = NULL;

	write_lock(&g_dpa_rwlock);
	list_for_each_entry(node, &g_dpa_uid_list, list_node) {
		if (node->uid == uid) {
			list_del(&node->list_node);
			if (g_dpa_uid_list_cnt)
				--g_dpa_uid_list_cnt;
			/* iteration stops right here, so the plain iterator is safe */
			victim = node;
			break;
		}
	}
	write_unlock(&g_dpa_rwlock);

	/* kfree(NULL) is a no-op when the uid was not found */
	kfree(victim);
	return 0;
}
168 
parse_single_uid(char *begin, char *end)169 static uid_t parse_single_uid(char *begin, char *end)
170 {
171 	char *cur = NULL;
172 	uid_t uid = 0;
173 	u32 len = end - begin;
174 
175 	// u32 decimal characters (4,294,967,295)
176 	if (len > DECIMAL_CHAR_NUM) {
177 		pr_err("[dpa-uid-cfg] single uid len(%u) overflow\n", len);
178 		return uid;
179 	}
180 
181 	cur = begin;
182 	while (cur < end) {
183 		if (*cur < '0' || *cur > '9') {
184 			pr_err("[dpa-uid-cfg] invalid character '%c'\n", *cur);
185 			return uid;
186 		}
187 		cur++;
188 	}
189 
190 	uid = simple_strtoul(begin, &begin, TO_DECIMAL);
191 	if (!begin || !uid) {
192 		pr_err("[dpa-uid-cfg] fail to change str to data");
193 		return uid;
194 	}
195 
196 	return uid;
197 }
198 
/*
 * Parse a space separated, newline terminated list of decimal uids.
 *
 * @args:      start of the uid list (the text after "<cmd> ")
 * @args_len:  number of bytes available in @args
 * @uid_list:  output array
 * @index_max: capacity of @uid_list in elements
 * @index:     in/out, next free slot in @uid_list
 *
 * Every token is consumed together with its trailing ' ' or '\n'; the
 * running length is checked against @args_len. A slot is only written while
 * *@index < @index_max: the previous '>' comparison let element
 * uid_list[index_max] be written, one past the end of the array.
 *
 * Return: 0 on success, -EINVAL on malformed input, a zero/invalid uid,
 * a missing final '\n', or when more than @index_max uids are supplied.
 */
static int parse_uids(char *args, u32 args_len, u32 *uid_list,
		      u32 index_max, u32 *index)
{
	char *begin = args;
	char *end = strchr(args, ' ');
	uid_t uid = 0;
	u32 len = 0;

	while (end) {
		// cur decimal characters cnt + ' ' or '\n'
		len += end - begin + 1;
		if (len > args_len || *index >= index_max) {
			pr_err("[dpa-uid-cfg] str len(%u) or index(%u) overflow\n",
			       len, *index);
			return -EINVAL;
		}

		uid = parse_single_uid(begin, end);
		if (!uid)
			return -EINVAL;
		uid_list[(*index)++] = uid;
		begin = ++end; // next decimal characters (skip ' ' or '\n')
		end = strchr(begin, ' ');
	}

	// find last uid characters
	end = strchr(begin, '\n');
	if (!end) {
		pr_err("[dpa-uid-cfg] last character is not '\\n'");
		return -EINVAL;
	}

	// cur decimal characters cnt + ' ' or '\n'
	len += end - begin + 1;
	if (len > args_len || *index >= index_max) {
		pr_err("[dpa-uid-cfg] str len(%u) or last index(%u) overflow\n",
			len, *index);
		return -EINVAL;
	}
	uid = parse_single_uid(begin, end);
	if (!uid)
		return -EINVAL;
	uid_list[(*index)++] = uid;
	return 0;
}
244 
/*
 * Split "<cmd> <uid list>" at the first space and hand the uid list to
 * parse_uids().
 *
 * Only the length of the command keyword is validated here (it must be
 * OPT_LEN characters); which keyword it is gets decided by the caller.
 *
 * Return: -EINVAL if there is no space or the keyword length is wrong,
 * otherwise whatever parse_uids() returns.
 */
static int get_dpa_uids(char *buf, size_t size, u32 *uid_list,
			u32 index_max, u32 *index)
{
	char *sep = strchr(buf, ' ');
	u32 cmd_len;
	u32 args_len;

	if (!sep) {
		pr_err("[dpa-uid-cfg] cmd fmt invalid\n");
		return -EINVAL;
	}

	// cmd is add or del, len is 3
	cmd_len = sep - buf;
	if (cmd_len != OPT_LEN) {
		pr_err("[dpa-uid-cfg] cmd len invalid\n");
		return -EINVAL;
	}

	/* bytes left after the keyword and its separating space */
	args_len = size - (cmd_len + 1);
	return parse_uids(sep + 1, args_len, uid_list, index_max, index);
}
269 
dpa_uid_match(uid_t kuid)270 bool dpa_uid_match(uid_t kuid)
271 {
272 	bool match = false;
273 	struct dpa_node *node = NULL;
274 	struct dpa_node *tmp_node = NULL;
275 
276 	if (kuid == 0)
277 		return match;
278 
279 	read_lock(&g_dpa_rwlock);
280 	list_for_each_entry_safe(node, tmp_node, &g_dpa_uid_list, list_node) {
281 		if (node->uid == kuid) {
282 			match = true;
283 			break;
284 		}
285 	}
286 	read_unlock(&g_dpa_rwlock);
287 	return match;
288 }
289 EXPORT_SYMBOL(dpa_uid_match);
290 
291 // call this fun in net/ipv4/af_inet.c inet_init_net()
lowpower_protocol_net_init(struct net *net)292 void __net_init lowpower_protocol_net_init(struct net *net)
293 {
294 	if (!proc_create_net_single_write("foreground_uid", 0644,
295 					  net->proc_net,
296 					  foreground_uid_show,
297 					  foreground_uid_write,
298 					  NULL))
299 		pr_err("fail to create /proc/net/foreground_uid");
300 
301 	INIT_LIST_HEAD(&g_dpa_uid_list);
302 	if (!proc_create_net_single_write("dpa_uid", 0644,
303 					  net->proc_net,
304 					  dpa_uid_show,
305 					  dpa_uid_write,
306 					  NULL))
307 		pr_err("fail to create /proc/net/dpa_uid");
308 }
309 
foreground_uid_match(struct net *net, struct sock *sk)310 static bool foreground_uid_match(struct net *net, struct sock *sk)
311 {
312 	uid_t kuid;
313 	uid_t foreground_uid;
314 	struct sock *fullsk;
315 
316 	if (!net || !sk)
317 		return false;
318 
319 	fullsk = sk_to_full_sk(sk);
320 	if (!fullsk || !sk_fullsock(fullsk))
321 		return false;
322 
323 	kuid = sock_net_uid(net, fullsk).val;
324 	foreground_uid = foreground_uid_atomic_read();
325 	if (kuid != foreground_uid)
326 		return false;
327 
328 	return true;
329 }
330 
331 /*
332  * ack optimization is only enable for large data receiving tasks and
333  * there is no packet loss scenario
334  */
tcp_ack_num(struct sock *sk)335 int tcp_ack_num(struct sock *sk)
336 {
337 	if (!sk)
338 		return 1;
339 
340 	if (foreground_uid_match(sock_net(sk), sk) == false)
341 		return 1;
342 
343 	if (tcp_sk(sk)->bytes_received >= BIG_DATA_BYTES &&
344 	    tcp_sk(sk)->dup_ack_counter < TCP_FASTRETRANS_THRESH)
345 		return TCP_ACK_NUM;
346 	return 1;
347 }
348 
netfilter_bypass_enable(struct net *net, struct sk_buff *skb, int (*fun)(struct net *, struct sock *, struct sk_buff *), int *ret)349 bool netfilter_bypass_enable(struct net *net, struct sk_buff *skb,
350 			     int (*fun)(struct net *, struct sock *, struct sk_buff *),
351 			     int *ret)
352 {
353 	if (!net || !skb || !ip_hdr(skb) || ip_hdr(skb)->protocol != IPPROTO_TCP)
354 		return false;
355 
356 	if (foreground_uid_match(net, skb->sk)) {
357 		*ret = fun(net, NULL, skb);
358 		return true;
359 	}
360 	return false;
361 }
362 #endif /* CONFIG_LOWPOWER_PROTOCOL */
363