1// SPDX-License-Identifier: GPL-2.0
2/*
3 * Copyright (c) 2023 Oracle and/or its affiliates.
4 *
5 * KUnit test of the handshake upcall mechanism.
6 */
7
8#include <kunit/test.h>
9#include <kunit/visibility.h>
10
11#include <linux/kernel.h>
12
13#include <net/sock.h>
14#include <net/genetlink.h>
15#include <net/netns/generic.h>
16
17#include <uapi/linux/handshake.h>
18#include "handshake.h"
19
20MODULE_IMPORT_NS(EXPORTED_FOR_KUNIT_TESTING);
21
/*
 * Stub ->hp_accept callback used by the test protocols below.
 * It performs no work and always returns 0.
 */
static int test_accept_func(struct handshake_req *req,
			    struct genl_info *info, int fd)
{
	/* Nothing to set up: report success. */
	return 0;
}

/*
 * Stub ->hp_done callback used by the test protocols below.
 * It deliberately ignores all of its arguments.
 */
static void test_done_func(struct handshake_req *req,
			   unsigned int status, struct genl_info *info)
{
	/* Intentionally empty. */
}
32
33struct handshake_req_alloc_test_param {
34	const char			*desc;
35	struct handshake_proto		*proto;
36	gfp_t				gfp;
37	bool				expect_success;
38};
39
40static struct handshake_proto handshake_req_alloc_proto_2 = {
41	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_NONE,
42};
43
44static struct handshake_proto handshake_req_alloc_proto_3 = {
45	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_MAX,
46};
47
48static struct handshake_proto handshake_req_alloc_proto_4 = {
49	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
50};
51
52static struct handshake_proto handshake_req_alloc_proto_5 = {
53	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
54	.hp_accept		= test_accept_func,
55};
56
57static struct handshake_proto handshake_req_alloc_proto_6 = {
58	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
59	.hp_privsize		= UINT_MAX,
60	.hp_accept		= test_accept_func,
61	.hp_done		= test_done_func,
62};
63
64static struct handshake_proto handshake_req_alloc_proto_good = {
65	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
66	.hp_accept		= test_accept_func,
67	.hp_done		= test_done_func,
68};
69
70static const
71struct handshake_req_alloc_test_param handshake_req_alloc_params[] = {
72	{
73		.desc			= "handshake_req_alloc NULL proto",
74		.proto			= NULL,
75		.gfp			= GFP_KERNEL,
76		.expect_success		= false,
77	},
78	{
79		.desc			= "handshake_req_alloc CLASS_NONE",
80		.proto			= &handshake_req_alloc_proto_2,
81		.gfp			= GFP_KERNEL,
82		.expect_success		= false,
83	},
84	{
85		.desc			= "handshake_req_alloc CLASS_MAX",
86		.proto			= &handshake_req_alloc_proto_3,
87		.gfp			= GFP_KERNEL,
88		.expect_success		= false,
89	},
90	{
91		.desc			= "handshake_req_alloc no callbacks",
92		.proto			= &handshake_req_alloc_proto_4,
93		.gfp			= GFP_KERNEL,
94		.expect_success		= false,
95	},
96	{
97		.desc			= "handshake_req_alloc no done callback",
98		.proto			= &handshake_req_alloc_proto_5,
99		.gfp			= GFP_KERNEL,
100		.expect_success		= false,
101	},
102	{
103		.desc			= "handshake_req_alloc excessive privsize",
104		.proto			= &handshake_req_alloc_proto_6,
105		.gfp			= GFP_KERNEL | __GFP_NOWARN,
106		.expect_success		= false,
107	},
108	{
109		.desc			= "handshake_req_alloc all good",
110		.proto			= &handshake_req_alloc_proto_good,
111		.gfp			= GFP_KERNEL,
112		.expect_success		= true,
113	},
114};
115
116static void
117handshake_req_alloc_get_desc(const struct handshake_req_alloc_test_param *param,
118			     char *desc)
119{
120	strscpy(desc, param->desc, KUNIT_PARAM_DESC_SIZE);
121}
122
123/* Creates the function handshake_req_alloc_gen_params */
124KUNIT_ARRAY_PARAM(handshake_req_alloc, handshake_req_alloc_params,
125		  handshake_req_alloc_get_desc);
126
127static void handshake_req_alloc_case(struct kunit *test)
128{
129	const struct handshake_req_alloc_test_param *param = test->param_value;
130	struct handshake_req *result;
131
132	/* Arrange */
133
134	/* Act */
135	result = handshake_req_alloc(param->proto, param->gfp);
136
137	/* Assert */
138	if (param->expect_success)
139		KUNIT_EXPECT_NOT_NULL(test, result);
140	else
141		KUNIT_EXPECT_NULL(test, result);
142
143	kfree(result);
144}
145
146static void handshake_req_submit_test1(struct kunit *test)
147{
148	struct socket *sock;
149	int err, result;
150
151	/* Arrange */
152	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
153			    &sock, 1);
154	KUNIT_ASSERT_EQ(test, err, 0);
155
156	/* Act */
157	result = handshake_req_submit(sock, NULL, GFP_KERNEL);
158
159	/* Assert */
160	KUNIT_EXPECT_EQ(test, result, -EINVAL);
161
162	sock_release(sock);
163}
164
165static void handshake_req_submit_test2(struct kunit *test)
166{
167	struct handshake_req *req;
168	int result;
169
170	/* Arrange */
171	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
172	KUNIT_ASSERT_NOT_NULL(test, req);
173
174	/* Act */
175	result = handshake_req_submit(NULL, req, GFP_KERNEL);
176
177	/* Assert */
178	KUNIT_EXPECT_EQ(test, result, -EINVAL);
179
180	/* handshake_req_submit() destroys @req on error */
181}
182
183static void handshake_req_submit_test3(struct kunit *test)
184{
185	struct handshake_req *req;
186	struct socket *sock;
187	int err, result;
188
189	/* Arrange */
190	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
191	KUNIT_ASSERT_NOT_NULL(test, req);
192
193	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
194			    &sock, 1);
195	KUNIT_ASSERT_EQ(test, err, 0);
196	sock->file = NULL;
197
198	/* Act */
199	result = handshake_req_submit(sock, req, GFP_KERNEL);
200
201	/* Assert */
202	KUNIT_EXPECT_EQ(test, result, -EINVAL);
203
204	/* handshake_req_submit() destroys @req on error */
205	sock_release(sock);
206}
207
208static void handshake_req_submit_test4(struct kunit *test)
209{
210	struct handshake_req *req, *result;
211	struct socket *sock;
212	struct file *filp;
213	int err;
214
215	/* Arrange */
216	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
217	KUNIT_ASSERT_NOT_NULL(test, req);
218
219	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
220			    &sock, 1);
221	KUNIT_ASSERT_EQ(test, err, 0);
222	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
223	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
224	KUNIT_ASSERT_NOT_NULL(test, sock->sk);
225	sock->file = filp;
226
227	err = handshake_req_submit(sock, req, GFP_KERNEL);
228	KUNIT_ASSERT_EQ(test, err, 0);
229
230	/* Act */
231	result = handshake_req_hash_lookup(sock->sk);
232
233	/* Assert */
234	KUNIT_EXPECT_NOT_NULL(test, result);
235	KUNIT_EXPECT_PTR_EQ(test, req, result);
236
237	handshake_req_cancel(sock->sk);
238	fput(filp);
239}
240
241static void handshake_req_submit_test5(struct kunit *test)
242{
243	struct handshake_req *req;
244	struct handshake_net *hn;
245	struct socket *sock;
246	struct file *filp;
247	struct net *net;
248	int saved, err;
249
250	/* Arrange */
251	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
252	KUNIT_ASSERT_NOT_NULL(test, req);
253
254	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
255			    &sock, 1);
256	KUNIT_ASSERT_EQ(test, err, 0);
257	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
258	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
259	KUNIT_ASSERT_NOT_NULL(test, sock->sk);
260	sock->file = filp;
261
262	net = sock_net(sock->sk);
263	hn = handshake_pernet(net);
264	KUNIT_ASSERT_NOT_NULL(test, hn);
265
266	saved = hn->hn_pending;
267	hn->hn_pending = hn->hn_pending_max + 1;
268
269	/* Act */
270	err = handshake_req_submit(sock, req, GFP_KERNEL);
271
272	/* Assert */
273	KUNIT_EXPECT_EQ(test, err, -EAGAIN);
274
275	fput(filp);
276	hn->hn_pending = saved;
277}
278
279static void handshake_req_submit_test6(struct kunit *test)
280{
281	struct handshake_req *req1, *req2;
282	struct socket *sock;
283	struct file *filp;
284	int err;
285
286	/* Arrange */
287	req1 = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
288	KUNIT_ASSERT_NOT_NULL(test, req1);
289	req2 = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
290	KUNIT_ASSERT_NOT_NULL(test, req2);
291
292	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
293			    &sock, 1);
294	KUNIT_ASSERT_EQ(test, err, 0);
295	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
296	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
297	KUNIT_ASSERT_NOT_NULL(test, sock->sk);
298	sock->file = filp;
299
300	/* Act */
301	err = handshake_req_submit(sock, req1, GFP_KERNEL);
302	KUNIT_ASSERT_EQ(test, err, 0);
303	err = handshake_req_submit(sock, req2, GFP_KERNEL);
304
305	/* Assert */
306	KUNIT_EXPECT_EQ(test, err, -EBUSY);
307
308	handshake_req_cancel(sock->sk);
309	fput(filp);
310}
311
312static void handshake_req_cancel_test1(struct kunit *test)
313{
314	struct handshake_req *req;
315	struct socket *sock;
316	struct file *filp;
317	bool result;
318	int err;
319
320	/* Arrange */
321	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
322	KUNIT_ASSERT_NOT_NULL(test, req);
323
324	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
325			    &sock, 1);
326	KUNIT_ASSERT_EQ(test, err, 0);
327
328	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
329	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
330	sock->file = filp;
331
332	err = handshake_req_submit(sock, req, GFP_KERNEL);
333	KUNIT_ASSERT_EQ(test, err, 0);
334
335	/* NB: handshake_req hasn't been accepted */
336
337	/* Act */
338	result = handshake_req_cancel(sock->sk);
339
340	/* Assert */
341	KUNIT_EXPECT_TRUE(test, result);
342
343	fput(filp);
344}
345
346static void handshake_req_cancel_test2(struct kunit *test)
347{
348	struct handshake_req *req, *next;
349	struct handshake_net *hn;
350	struct socket *sock;
351	struct file *filp;
352	struct net *net;
353	bool result;
354	int err;
355
356	/* Arrange */
357	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
358	KUNIT_ASSERT_NOT_NULL(test, req);
359
360	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
361			    &sock, 1);
362	KUNIT_ASSERT_EQ(test, err, 0);
363
364	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
365	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
366	sock->file = filp;
367
368	err = handshake_req_submit(sock, req, GFP_KERNEL);
369	KUNIT_ASSERT_EQ(test, err, 0);
370
371	net = sock_net(sock->sk);
372	hn = handshake_pernet(net);
373	KUNIT_ASSERT_NOT_NULL(test, hn);
374
375	/* Pretend to accept this request */
376	next = handshake_req_next(hn, HANDSHAKE_HANDLER_CLASS_TLSHD);
377	KUNIT_ASSERT_PTR_EQ(test, req, next);
378
379	/* Act */
380	result = handshake_req_cancel(sock->sk);
381
382	/* Assert */
383	KUNIT_EXPECT_TRUE(test, result);
384
385	fput(filp);
386}
387
388static void handshake_req_cancel_test3(struct kunit *test)
389{
390	struct handshake_req *req, *next;
391	struct handshake_net *hn;
392	struct socket *sock;
393	struct file *filp;
394	struct net *net;
395	bool result;
396	int err;
397
398	/* Arrange */
399	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
400	KUNIT_ASSERT_NOT_NULL(test, req);
401
402	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
403			    &sock, 1);
404	KUNIT_ASSERT_EQ(test, err, 0);
405
406	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
407	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
408	sock->file = filp;
409
410	err = handshake_req_submit(sock, req, GFP_KERNEL);
411	KUNIT_ASSERT_EQ(test, err, 0);
412
413	net = sock_net(sock->sk);
414	hn = handshake_pernet(net);
415	KUNIT_ASSERT_NOT_NULL(test, hn);
416
417	/* Pretend to accept this request */
418	next = handshake_req_next(hn, HANDSHAKE_HANDLER_CLASS_TLSHD);
419	KUNIT_ASSERT_PTR_EQ(test, req, next);
420
421	/* Pretend to complete this request */
422	handshake_complete(next, -ETIMEDOUT, NULL);
423
424	/* Act */
425	result = handshake_req_cancel(sock->sk);
426
427	/* Assert */
428	KUNIT_EXPECT_FALSE(test, result);
429
430	fput(filp);
431}
432
433static struct handshake_req *handshake_req_destroy_test;
434
435static void test_destroy_func(struct handshake_req *req)
436{
437	handshake_req_destroy_test = req;
438}
439
440static struct handshake_proto handshake_req_alloc_proto_destroy = {
441	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
442	.hp_accept		= test_accept_func,
443	.hp_done		= test_done_func,
444	.hp_destroy		= test_destroy_func,
445};
446
447static void handshake_req_destroy_test1(struct kunit *test)
448{
449	struct handshake_req *req;
450	struct socket *sock;
451	struct file *filp;
452	int err;
453
454	/* Arrange */
455	handshake_req_destroy_test = NULL;
456
457	req = handshake_req_alloc(&handshake_req_alloc_proto_destroy, GFP_KERNEL);
458	KUNIT_ASSERT_NOT_NULL(test, req);
459
460	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
461			    &sock, 1);
462	KUNIT_ASSERT_EQ(test, err, 0);
463
464	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
465	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
466	sock->file = filp;
467
468	err = handshake_req_submit(sock, req, GFP_KERNEL);
469	KUNIT_ASSERT_EQ(test, err, 0);
470
471	handshake_req_cancel(sock->sk);
472
473	/* Act */
474	/* Ensure the close/release/put process has run to
475	 * completion before checking the result.
476	 */
477	__fput_sync(filp);
478
479	/* Assert */
480	KUNIT_EXPECT_PTR_EQ(test, handshake_req_destroy_test, req);
481}
482
483static struct kunit_case handshake_api_test_cases[] = {
484	{
485		.name			= "req_alloc API fuzzing",
486		.run_case		= handshake_req_alloc_case,
487		.generate_params	= handshake_req_alloc_gen_params,
488	},
489	{
490		.name			= "req_submit NULL req arg",
491		.run_case		= handshake_req_submit_test1,
492	},
493	{
494		.name			= "req_submit NULL sock arg",
495		.run_case		= handshake_req_submit_test2,
496	},
497	{
498		.name			= "req_submit NULL sock->file",
499		.run_case		= handshake_req_submit_test3,
500	},
501	{
502		.name			= "req_lookup works",
503		.run_case		= handshake_req_submit_test4,
504	},
505	{
506		.name			= "req_submit max pending",
507		.run_case		= handshake_req_submit_test5,
508	},
509	{
510		.name			= "req_submit multiple",
511		.run_case		= handshake_req_submit_test6,
512	},
513	{
514		.name			= "req_cancel before accept",
515		.run_case		= handshake_req_cancel_test1,
516	},
517	{
518		.name			= "req_cancel after accept",
519		.run_case		= handshake_req_cancel_test2,
520	},
521	{
522		.name			= "req_cancel after done",
523		.run_case		= handshake_req_cancel_test3,
524	},
525	{
526		.name			= "req_destroy works",
527		.run_case		= handshake_req_destroy_test1,
528	},
529	{}
530};
531
532static struct kunit_suite handshake_api_suite = {
533       .name                   = "Handshake API tests",
534       .test_cases             = handshake_api_test_cases,
535};
536
537kunit_test_suites(&handshake_api_suite);
538
539MODULE_DESCRIPTION("Test handshake upcall API functions");
540MODULE_LICENSE("GPL");
541