// SPDX-License-Identifier: GPL-2.0+
/*
 * Copyright (c) 2024 Huawei Device Co., Ltd.
 *
 * Operations on the lowpower protocol
 * Authors: yangyanjun
 */
#ifdef CONFIG_LOWPOWER_PROTOCOL
#include <linux/types.h>
#include <linux/kernel.h>
#include <linux/proc_fs.h>
#include <linux/printk.h>
#include <linux/list.h>
#include <linux/slab.h>
#include <linux/string.h>
#include <linux/seq_file.h>
#include <linux/rwlock_types.h>
#include <net/net_namespace.h>
#include <net/sock.h>
#include <net/ip.h>
#include <net/tcp.h>
#include <net/lowpower_protocol.h>

static atomic_t g_foreground_uid = ATOMIC_INIT(FOREGROUND_UID_INIT);
#define OPT_LEN 3
#define TO_DECIMAL 10
#define LIST_MAX 500
#define DECIMAL_CHAR_NUM 10 // u32 decimal characters (4,294,967,295)
static DEFINE_RWLOCK(g_dpa_rwlock);
static u32 g_dpa_uid_list_cnt;
static LIST_HEAD(g_dpa_uid_list); /* statically initialized; shared across net namespaces */
struct dpa_node {
	struct list_head list_node;
	uid_t uid;
};

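/*
 * g_foreground_uid is written through /proc/net/foreground_uid and read
 * from the TCP/netfilter fast paths (see tcp_ack_num() and
 * netfilter_bypass_enable() below), hence the atomic accessors.
 */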
static void foreground_uid_atomic_set(uid_t val)
{
	atomic_set(&g_foreground_uid, val);
}

static uid_t foreground_uid_atomic_read(void)
{
	return (uid_t)atomic_read(&g_foreground_uid);
}

// cat /proc/net/foreground_uid
static int foreground_uid_show(struct seq_file *seq, void *v)
{
	uid_t uid = foreground_uid_atomic_read();

	seq_printf(seq, "%u\n", uid);
	return 0;
}

// echo <uid> > /proc/net/foreground_uid
static int foreground_uid_write(struct file *file, char *buf, size_t size)
{
	char *p = buf;
	uid_t uid = simple_strtoul(p, &p, TO_DECIMAL);

	/* no digits consumed: simple_strtoul() left @p at the start of @buf */
	if (p == buf)
		return -EINVAL;

	foreground_uid_atomic_set(uid);
	return 0;
}

// cat /proc/net/dpa_uid
static int dpa_uid_show(struct seq_file *seq, void *v)
{
	struct dpa_node *node = NULL;

	read_lock(&g_dpa_rwlock);
	seq_printf(seq, "uid list num: %u\n", g_dpa_uid_list_cnt);
	list_for_each_entry(node, &g_dpa_uid_list, list_node)
		seq_printf(seq, "%u\n", node->uid);
	read_unlock(&g_dpa_rwlock);
	return 0;
}

// echo "add xx yy zz" > /proc/net/dpa_uid
// echo "del xx yy zz" > /proc/net/dpa_uid
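// e.g. echo "add 20020079 20020080" > /proc/net/dpa_uid (UIDs here are only illustrative)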
static int dpa_uid_add(uid_t uid);
static int dpa_uid_del(uid_t uid);
static int get_dpa_uids(char *buf, size_t size, u32 *uid_list,
			u32 index_max, u32 *index);
static int dpa_uid_write(struct file *file, char *buf, size_t size)
{
	u32 dpa_list[LIST_MAX];
	u32 index = 0;
	int ret = -EINVAL;
	u32 i;

	if (get_dpa_uids(buf, size, dpa_list, LIST_MAX, &index) != 0) {
		pr_err("[dpa-uid-cfg] fail to parse dpa uids\n");
		return ret;
	}

	if (strncmp(buf, "add", OPT_LEN) == 0) {
		for (i = 0; i < index; i++) {
			ret = dpa_uid_add(dpa_list[i]);
			if (ret != 0) {
				pr_err("[dpa-uid-cfg] add fail, index=%u\n", i);
				return ret;
			}
		}
	} else if (strncmp(buf, "del", OPT_LEN) == 0) {
		for (i = 0; i < index; i++) {
			ret = dpa_uid_del(dpa_list[i]);
			if (ret != 0) {
				pr_err("[dpa-uid-cfg] del fail, index=%u\n", i);
				return ret;
			}
		}
	} else {
		pr_err("[dpa-uid-cfg] cmd unknown\n");
	}
	return ret;
}

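/*
 * Add @uid to the DPA uid list if it is not already present.
 * The list is bounded by LIST_MAX entries; returns -EFBIG when the list
 * is full and -ENOMEM if a new node cannot be allocated.
 */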
static int dpa_uid_add(uid_t uid)
{
	bool exist = false;
	struct dpa_node *node = NULL;

	write_lock(&g_dpa_rwlock);
	if (g_dpa_uid_list_cnt >= LIST_MAX) {
		write_unlock(&g_dpa_rwlock);
		return -EFBIG;
	}

	list_for_each_entry(node, &g_dpa_uid_list, list_node) {
		if (node->uid == uid) {
			exist = true;
			break;
		}
	}

	if (!exist) {
		node = kzalloc(sizeof(*node), GFP_ATOMIC);
		if (!node) {
			write_unlock(&g_dpa_rwlock);
			return -ENOMEM;
		}
		node->uid = uid;
		list_add_tail(&node->list_node, &g_dpa_uid_list);
		g_dpa_uid_list_cnt++;
	}
	write_unlock(&g_dpa_rwlock);
	return 0;
}

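/*
 * Remove @uid from the DPA uid list and free its node.
 * Deleting a uid that is not in the list is not treated as an error.
 */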
static int dpa_uid_del(uid_t uid)
{
	struct dpa_node *node = NULL;
	struct dpa_node *tmp_node = NULL;

	write_lock(&g_dpa_rwlock);
	list_for_each_entry_safe(node, tmp_node, &g_dpa_uid_list, list_node) {
		if (node->uid == uid) {
			list_del(&node->list_node);
			kfree(node);
			if (g_dpa_uid_list_cnt)
				--g_dpa_uid_list_cnt;
			break;
		}
	}
	write_unlock(&g_dpa_rwlock);
	return 0;
}

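/*
 * Parse one decimal uid from [begin, end). Returns 0 on any parse error;
 * uid 0 is never a valid DPA uid (see dpa_uid_match()), so callers treat
 * a return of 0 as failure.
 */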
static uid_t parse_single_uid(char *begin, char *end)
{
	char *cur = NULL;
	uid_t uid = 0;
	u32 len = end - begin;

	// u32 decimal characters (4,294,967,295)
	if (len > DECIMAL_CHAR_NUM) {
		pr_err("[dpa-uid-cfg] single uid len(%u) overflow\n", len);
		return uid;
	}

	cur = begin;
	while (cur < end) {
		if (*cur < '0' || *cur > '9') {
			pr_err("[dpa-uid-cfg] invalid character '%c'\n", *cur);
			return uid;
		}
		cur++;
	}

	uid = simple_strtoul(begin, &begin, TO_DECIMAL);
	if (!uid) {
		pr_err("[dpa-uid-cfg] fail to convert str to uid\n");
		return uid;
	}

	return uid;
}

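/*
 * Parse a list of space-separated decimal uids terminated by '\n' into
 * @uid_list. @index returns the number of uids parsed; at most
 * @index_max entries are stored.
 */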
static int parse_uids(char *args, u32 args_len, u32 *uid_list,
		      u32 index_max, u32 *index)
{
	char *begin = args;
	char *end = strchr(args, ' ');
	uid_t uid = 0;
	u32 len = 0;

	while (end) {
		// cur decimal characters cnt + ' ' or '\n'
		len += end - begin + 1;
		if (len > args_len || *index >= index_max) {
			pr_err("[dpa-uid-cfg] str len(%u) or index(%u) overflow\n",
			       len, *index);
			return -EINVAL;
		}

		uid = parse_single_uid(begin, end);
		if (!uid)
			return -EINVAL;
		uid_list[(*index)++] = uid;
		begin = ++end; // next decimal characters (skip ' ')
		end = strchr(begin, ' ');
	}

	// find last uid characters
	end = strchr(begin, '\n');
	if (!end) {
		pr_err("[dpa-uid-cfg] last character is not '\\n'\n");
		return -EINVAL;
	}

	// cur decimal characters cnt + '\n'
	len += end - begin + 1;
	if (len > args_len || *index >= index_max) {
		pr_err("[dpa-uid-cfg] str len(%u) or last index(%u) overflow\n",
		       len, *index);
		return -EINVAL;
	}
	uid = parse_single_uid(begin, end);
	if (!uid)
		return -EINVAL;
	uid_list[(*index)++] = uid;
	return 0;
}

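/*
 * Split the proc input "add|del <uid> [<uid> ...]\n" into the command
 * and the uid list, then parse the uids into @uid_list.
 */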
static int get_dpa_uids(char *buf, size_t size, u32 *uid_list,
			u32 index_max, u32 *index)
{
	char *args = NULL;
	u32 opt_len;
	u32 data_len;

	// split into cmd and uid list
	args = strchr(buf, ' ');
	if (!args) {
		pr_err("[dpa-uid-cfg] cmd fmt invalid\n");
		return -EINVAL;
	}

	// cmd is add or del, len is 3
	opt_len = args - buf;
	if (opt_len != OPT_LEN) {
		pr_err("[dpa-uid-cfg] cmd len invalid\n");
		return -EINVAL;
	}

	data_len = size - (opt_len + 1);
	return parse_uids(args + 1, data_len, uid_list, index_max, index);
}

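/*
 * Return true if @kuid is in the configured DPA uid list.
 * uid 0 never matches.
 */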
bool dpa_uid_match(uid_t kuid)
{
	bool match = false;
	struct dpa_node *node = NULL;

	if (kuid == 0)
		return match;

	read_lock(&g_dpa_rwlock);
	list_for_each_entry(node, &g_dpa_uid_list, list_node) {
		if (node->uid == kuid) {
			match = true;
			break;
		}
	}
	read_unlock(&g_dpa_rwlock);
	return match;
}
EXPORT_SYMBOL(dpa_uid_match);

// called from inet_init_net() in net/ipv4/af_inet.c
void __net_init lowpower_protocol_net_init(struct net *net)
{
	if (!proc_create_net_single_write("foreground_uid", 0644,
					  net->proc_net,
					  foreground_uid_show,
					  foreground_uid_write,
					  NULL))
		pr_err("fail to create /proc/net/foreground_uid\n");

	if (!proc_create_net_single_write("dpa_uid", 0644,
					  net->proc_net,
					  dpa_uid_show,
					  dpa_uid_write,
					  NULL))
		pr_err("fail to create /proc/net/dpa_uid\n");
}

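/*
 * Return true if the owner uid of @sk's full socket matches the uid
 * configured through /proc/net/foreground_uid.
 */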
static bool foreground_uid_match(struct net *net, struct sock *sk)
{
	uid_t kuid;
	uid_t foreground_uid;
	struct sock *fullsk;

	if (!net || !sk)
		return false;

	fullsk = sk_to_full_sk(sk);
	if (!fullsk || !sk_fullsock(fullsk))
		return false;

	kuid = sock_net_uid(net, fullsk).val;
	foreground_uid = foreground_uid_atomic_read();
	if (kuid != foreground_uid)
		return false;

	return true;
}

/*
 * ACK optimization is enabled only for connections that are receiving
 * large amounts of data and have seen no recent packet loss.
 * Returns TCP_ACK_NUM when the optimization applies, otherwise 1
 * (default ACK behaviour).
 */
int tcp_ack_num(struct sock *sk)
{
	if (!sk)
		return 1;

	if (!foreground_uid_match(sock_net(sk), sk))
		return 1;

	if (tcp_sk(sk)->bytes_received >= BIG_DATA_BYTES &&
	    tcp_sk(sk)->dup_ack_counter < TCP_FASTRETRANS_THRESH)
		return TCP_ACK_NUM;
	return 1;
}

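/*
 * If @skb is TCP traffic belonging to the foreground uid, invoke @fun
 * directly and store its result in *ret; returning true lets the caller
 * skip normal netfilter hook traversal for this packet (assumes @fun is
 * the hook's okfn-style continuation supplied by the caller).
 */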
bool netfilter_bypass_enable(struct net *net, struct sk_buff *skb,
			     int (*fun)(struct net *, struct sock *, struct sk_buff *),
			     int *ret)
{
	if (!net || !skb || !ip_hdr(skb) || ip_hdr(skb)->protocol != IPPROTO_TCP)
		return false;

	if (foreground_uid_match(net, skb->sk)) {
		*ret = fun(net, NULL, skb);
		return true;
	}
	return false;
}
#endif /* CONFIG_LOWPOWER_PROTOCOL */