1// SPDX-License-Identifier: GPL-2.0
2/*
3 * ipsec.c - Check xfrm on veth inside a net-ns.
4 * Copyright (c) 2018 Dmitry Safonov
5 */
6
7#define _GNU_SOURCE
8
9#include <arpa/inet.h>
10#include <asm/types.h>
11#include <errno.h>
12#include <fcntl.h>
13#include <limits.h>
14#include <linux/limits.h>
15#include <linux/netlink.h>
16#include <linux/random.h>
17#include <linux/rtnetlink.h>
18#include <linux/veth.h>
19#include <linux/xfrm.h>
20#include <netinet/in.h>
21#include <net/if.h>
22#include <sched.h>
23#include <stdbool.h>
24#include <stdint.h>
25#include <stdio.h>
26#include <stdlib.h>
27#include <string.h>
28#include <sys/mman.h>
29#include <sys/socket.h>
30#include <sys/stat.h>
31#include <sys/syscall.h>
32#include <sys/types.h>
33#include <sys/wait.h>
34#include <time.h>
35#include <unistd.h>
36
37#include "../kselftest.h"
38
/* Log a message prefixed with the pid and the source line of the call. */
#define printk(fmt, ...)						\
	ksft_print_msg("%d[%u] " fmt "\n", getpid(), __LINE__, ##__VA_ARGS__)

/* Like printk(), but appends the current errno description (glibc %m). */
#define pr_err(fmt, ...)	printk(fmt ": %m", ##__VA_ARGS__)

#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]))
/* Compile-time assertion: a true condition yields a negative array size. */
#define BUILD_BUG_ON(condition) ((void)sizeof(char[1 - 2*!!(condition)]))

#define IPV4_STR_SZ	16	/* xxx.xxx.xxx.xxx is longest + \0 */
#define MAX_PAYLOAD	2048
#define XFRM_ALGO_KEY_BUF_SIZE	512
#define MAX_PROCESSES	(1 << 14) /* /16 mask divided by /30 subnets */
#define INADDR_A	((in_addr_t) 0x0a000000) /* 10.0.0.0 */
#define INADDR_B	((in_addr_t) 0xc0a80000) /* 192.168.0.0 */

/* /30 mask for one veth connection */
#define PREFIX_LEN	30
/*
 * Host numbers of the first two usable addresses (offsets 1 and 2) in the
 * nr-th /30 subnet. The argument is parenthesized so that expressions such
 * as child_ip(i + 1) expand correctly.
 */
#define child_ip(nr)	(4*(nr) + 1)
#define grchild_ip(nr)	(4*(nr) + 2)

#define VETH_FMT	"ktst-%d"
#define VETH_LEN	12
61
/* Network-namespace fds set up by init_namespaces(); -1 until opened. */
static int nsfd_parent	= -1;
static int nsfd_childa	= -1;
static int nsfd_childb	= -1;
static long page_size;

/*
 * ksft_cnt is static in kselftest, so it isn't shared with children.
 * Children therefore send each test result back to the parent, which
 * does the counting. results_fd is the pipe carrying that feedback.
 */
static int results_fd[2];

/*
 * Ping parameters, used only in this file:
 *  - ping_delay_nsec: pause between two pings (nanosleep);
 *  - ping_timeout:    receive timeout of the ping socket, in microseconds;
 *  - ping_count:      pings sent per do_ping() run;
 *  - ping_success:    minimum successful pings for the run to pass.
 */
static const unsigned int ping_delay_nsec	= 50 * 1000 * 1000;
static const unsigned int ping_timeout		= 300;
static const unsigned int ping_count		= 100;
static const unsigned int ping_success		= 80;
78
/*
 * Overwrite the first @buflen bytes at @buf with rand() output: one rand()
 * per whole int, then, for a trailing partial int, the leading bytes of
 * one more rand() value. A zero length leaves @buf untouched.
 */
static void randomize_buffer(void *buf, size_t buflen)
{
	size_t nr_words = buflen / sizeof(int);
	size_t tail = buflen % sizeof(int);
	int *word = buf;
	size_t i;

	if (buflen == 0)
		return;

	for (i = 0; i < nr_words; i++)
		word[i] = rand();

	if (tail != 0) {
		int last = rand();

		memcpy((char *)buf + (buflen - tail), &last, tail);
	}
}
99
100static int unshare_open(void)
101{
102	const char *netns_path = "/proc/self/ns/net";
103	int fd;
104
105	if (unshare(CLONE_NEWNET) != 0) {
106		pr_err("unshare()");
107		return -1;
108	}
109
110	fd = open(netns_path, O_RDONLY);
111	if (fd <= 0) {
112		pr_err("open(%s)", netns_path);
113		return -1;
114	}
115
116	return fd;
117}
118
119static int switch_ns(int fd)
120{
121	if (setns(fd, CLONE_NEWNET)) {
122		pr_err("setns()");
123		return -1;
124	}
125	return 0;
126}
127
128/*
129 * Running the test inside a new parent net namespace to bother less
130 * about cleanup on error-path.
131 */
132static int init_namespaces(void)
133{
134	nsfd_parent = unshare_open();
135	if (nsfd_parent <= 0)
136		return -1;
137
138	nsfd_childa = unshare_open();
139	if (nsfd_childa <= 0)
140		return -1;
141
142	if (switch_ns(nsfd_parent))
143		return -1;
144
145	nsfd_childb = unshare_open();
146	if (nsfd_childb <= 0)
147		return -1;
148
149	if (switch_ns(nsfd_parent))
150		return -1;
151	return 0;
152}
153
/*
 * Make sure *@sock holds an open netlink socket of protocol @proto.
 * If one is already open (*@sock > 0), only advance the caller's sequence
 * number *@seq_nr. Otherwise open a new socket and start *@seq_nr at a
 * random value. Returns 0 on success, -1 on failure.
 */
static int netlink_sock(int *sock, uint32_t *seq_nr, int proto)
{
	if (*sock > 0) {
		/* Advance the caller's counter, not the local pointer copy. */
		(*seq_nr)++;
		return 0;
	}

	*sock = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, proto);
	if (*sock <= 0) {
		pr_err("socket(AF_NETLINK)");
		return -1;
	}

	randomize_buffer(seq_nr, sizeof(*seq_nr));

	return 0;
}
171
/* Where the next attribute of @nh goes: just past its RTA_ALIGN()ed length. */
static inline struct rtattr *rtattr_hdr(struct nlmsghdr *nh)
{
	char *msg_end = (char *)nh + RTA_ALIGN(nh->nlmsg_len);

	return (struct rtattr *)msg_end;
}
176
/*
 * Append one attribute of type @rta_type, carrying @size bytes copied from
 * @payload, at the aligned end of message @nh (a buffer of @req_sz bytes),
 * and grow nlmsg_len to cover it. @payload may be NULL when @size is 0, as
 * rtattr_begin() does for nest headers.
 * Returns 0 on success, -1 if the attribute does not fit.
 */
static int rtattr_pack(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type, const void *payload, size_t size)
{
	/*
	 * NLMSG_ALIGNTO == RTA_ALIGNTO, so RTA_ALIGN(nlmsg_len) is also the
	 * netlink-aligned end of the message (nlmsg_len itself may not be
	 * aligned after an odd-sized attribute).
	 */
	struct rtattr *attr = rtattr_hdr(nh);
	size_t nl_size = RTA_ALIGN(nh->nlmsg_len) + RTA_LENGTH(size);

	if (req_sz < nl_size) {
		printk("req buf is too small: %zu < %zu", req_sz, nl_size);
		return -1;
	}
	nh->nlmsg_len = nl_size;

	attr->rta_len = RTA_LENGTH(size);
	attr->rta_type = rta_type;
	/* memcpy() with a NULL source is undefined even for zero bytes. */
	if (size)
		memcpy(RTA_DATA(attr), payload, size);

	return 0;
}
196
/*
 * Open a nested attribute: pack its header (plus an optional fixed
 * @payload of @size bytes) and return a pointer to it, so rtattr_end()
 * can later set rta_len to span the nested content.
 * Returns NULL if the attribute does not fit in @req_sz bytes.
 */
static struct rtattr *_rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type, const void *payload, size_t size)
{
	struct rtattr *nest = rtattr_hdr(nh);

	return rtattr_pack(nh, req_sz, rta_type, payload, size) ? NULL : nest;
}
207
/* Open a nested attribute of type @rta_type that has no fixed payload. */
static inline struct rtattr *rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type)
{
	const void *no_payload = NULL;

	return _rtattr_begin(nh, req_sz, rta_type, no_payload, 0);
}
213
/* Close nested attribute @attr: its length runs to the current end of @nh. */
static inline void rtattr_end(struct nlmsghdr *nh, struct rtattr *attr)
{
	size_t used = nh->nlmsg_len;
	char *nest_start = (char *)attr;

	attr->rta_len = (char *)nh + used - nest_start;
}
220
/*
 * Append the VETH_INFO_PEER nest that describes the second veth end: an
 * ifinfomsg header, then IFLA_IFNAME (@peer) and IFLA_NET_NS_FD (@ns).
 * Returns 0 on success, -1 if the request buffer is too small.
 */
static int veth_pack_peerb(struct nlmsghdr *nh, size_t req_sz,
		const char *peer, int ns)
{
	struct ifinfomsg peer_info = {
		.ifi_family	= AF_UNSPEC,
		.ifi_change	= 0xFFFFFFFF,
	};
	struct rtattr *nest;

	nest = _rtattr_begin(nh, req_sz, VETH_INFO_PEER, &peer_info, sizeof(peer_info));
	if (nest == NULL)
		return -1;

	if (rtattr_pack(nh, req_sz, IFLA_IFNAME, peer, strlen(peer)) ||
	    rtattr_pack(nh, req_sz, IFLA_NET_NS_FD, &ns, sizeof(ns)))
		return -1;

	rtattr_end(nh, nest);
	return 0;
}
245
/*
 * Read the kernel's reply to a request sent with NLM_F_ACK on @sock.
 * Returns 0 for a positive ACK, the (negative) error code carried by
 * NLMSG_ERROR on failure, or -1 if the reply could not be read or is not
 * an NLMSG_ERROR message.
 */
static int netlink_check_answer(int sock)
{
	struct nlmsgerror {
		struct nlmsghdr hdr;
		int error;
		struct nlmsghdr orig_msg;
	} answer;
	ssize_t len;

	len = recv(sock, &answer, sizeof(answer), 0);
	if (len < 0) {
		pr_err("recv()");
		return -1;
	}
	/* Never interpret fields the kernel did not actually fill in. */
	if ((size_t)len < sizeof(answer.hdr) + sizeof(answer.error)) {
		printk("netlink answer is too short: %zd", len);
		return -1;
	}
	if (answer.hdr.nlmsg_type != NLMSG_ERROR) {
		printk("expected NLMSG_ERROR, got %d", (int)answer.hdr.nlmsg_type);
		return -1;
	}
	if (answer.error) {
		printk("NLMSG_ERROR: %d: %s",
			answer.error, strerror(-answer.error));
		return answer.error;
	}

	return 0;
}
268
269static int veth_add(int sock, uint32_t seq, const char *peera, int ns_a,
270		const char *peerb, int ns_b)
271{
272	uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
273	struct {
274		struct nlmsghdr		nh;
275		struct ifinfomsg	info;
276		char			attrbuf[MAX_PAYLOAD];
277	} req;
278	const char veth_type[] = "veth";
279	struct rtattr *link_info, *info_data;
280
281	memset(&req, 0, sizeof(req));
282	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
283	req.nh.nlmsg_type	= RTM_NEWLINK;
284	req.nh.nlmsg_flags	= flags;
285	req.nh.nlmsg_seq	= seq;
286	req.info.ifi_family	= AF_UNSPEC;
287	req.info.ifi_change	= 0xFFFFFFFF;
288
289	if (rtattr_pack(&req.nh, sizeof(req), IFLA_IFNAME, peera, strlen(peera)))
290		return -1;
291
292	if (rtattr_pack(&req.nh, sizeof(req), IFLA_NET_NS_FD, &ns_a, sizeof(ns_a)))
293		return -1;
294
295	link_info = rtattr_begin(&req.nh, sizeof(req), IFLA_LINKINFO);
296	if (!link_info)
297		return -1;
298
299	if (rtattr_pack(&req.nh, sizeof(req), IFLA_INFO_KIND, veth_type, sizeof(veth_type)))
300		return -1;
301
302	info_data = rtattr_begin(&req.nh, sizeof(req), IFLA_INFO_DATA);
303	if (!info_data)
304		return -1;
305
306	if (veth_pack_peerb(&req.nh, sizeof(req), peerb, ns_b))
307		return -1;
308
309	rtattr_end(&req.nh, info_data);
310	rtattr_end(&req.nh, link_info);
311
312	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
313		pr_err("send()");
314		return -1;
315	}
316	return netlink_check_answer(sock);
317}
318
319static int ip4_addr_set(int sock, uint32_t seq, const char *intf,
320		struct in_addr addr, uint8_t prefix)
321{
322	uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
323	struct {
324		struct nlmsghdr		nh;
325		struct ifaddrmsg	info;
326		char			attrbuf[MAX_PAYLOAD];
327	} req;
328
329	memset(&req, 0, sizeof(req));
330	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
331	req.nh.nlmsg_type	= RTM_NEWADDR;
332	req.nh.nlmsg_flags	= flags;
333	req.nh.nlmsg_seq	= seq;
334	req.info.ifa_family	= AF_INET;
335	req.info.ifa_prefixlen	= prefix;
336	req.info.ifa_index	= if_nametoindex(intf);
337
338#ifdef DEBUG
339	{
340		char addr_str[IPV4_STR_SZ] = {};
341
342		strncpy(addr_str, inet_ntoa(addr), IPV4_STR_SZ - 1);
343
344		printk("ip addr set %s", addr_str);
345	}
346#endif
347
348	if (rtattr_pack(&req.nh, sizeof(req), IFA_LOCAL, &addr, sizeof(addr)))
349		return -1;
350
351	if (rtattr_pack(&req.nh, sizeof(req), IFA_ADDRESS, &addr, sizeof(addr)))
352		return -1;
353
354	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
355		pr_err("send()");
356		return -1;
357	}
358	return netlink_check_answer(sock);
359}
360
361static int link_set_up(int sock, uint32_t seq, const char *intf)
362{
363	struct {
364		struct nlmsghdr		nh;
365		struct ifinfomsg	info;
366		char			attrbuf[MAX_PAYLOAD];
367	} req;
368
369	memset(&req, 0, sizeof(req));
370	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
371	req.nh.nlmsg_type	= RTM_NEWLINK;
372	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
373	req.nh.nlmsg_seq	= seq;
374	req.info.ifi_family	= AF_UNSPEC;
375	req.info.ifi_change	= 0xFFFFFFFF;
376	req.info.ifi_index	= if_nametoindex(intf);
377	req.info.ifi_flags	= IFF_UP;
378	req.info.ifi_change	= IFF_UP;
379
380	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
381		pr_err("send()");
382		return -1;
383	}
384	return netlink_check_answer(sock);
385}
386
387static int ip4_route_set(int sock, uint32_t seq, const char *intf,
388		struct in_addr src, struct in_addr dst)
389{
390	struct {
391		struct nlmsghdr	nh;
392		struct rtmsg	rt;
393		char		attrbuf[MAX_PAYLOAD];
394	} req;
395	unsigned int index = if_nametoindex(intf);
396
397	memset(&req, 0, sizeof(req));
398	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.rt));
399	req.nh.nlmsg_type	= RTM_NEWROUTE;
400	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE;
401	req.nh.nlmsg_seq	= seq;
402	req.rt.rtm_family	= AF_INET;
403	req.rt.rtm_dst_len	= 32;
404	req.rt.rtm_table	= RT_TABLE_MAIN;
405	req.rt.rtm_protocol	= RTPROT_BOOT;
406	req.rt.rtm_scope	= RT_SCOPE_LINK;
407	req.rt.rtm_type		= RTN_UNICAST;
408
409	if (rtattr_pack(&req.nh, sizeof(req), RTA_DST, &dst, sizeof(dst)))
410		return -1;
411
412	if (rtattr_pack(&req.nh, sizeof(req), RTA_PREFSRC, &src, sizeof(src)))
413		return -1;
414
415	if (rtattr_pack(&req.nh, sizeof(req), RTA_OIF, &index, sizeof(index)))
416		return -1;
417
418	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
419		pr_err("send()");
420		return -1;
421	}
422
423	return netlink_check_answer(sock);
424}
425
426static int tunnel_set_route(int route_sock, uint32_t *route_seq, char *veth,
427		struct in_addr tunsrc, struct in_addr tundst)
428{
429	if (ip4_addr_set(route_sock, (*route_seq)++, "lo",
430			tunsrc, PREFIX_LEN)) {
431		printk("Failed to set ipv4 addr");
432		return -1;
433	}
434
435	if (ip4_route_set(route_sock, (*route_seq)++, veth, tunsrc, tundst)) {
436		printk("Failed to set ipv4 route");
437		return -1;
438	}
439
440	return 0;
441}
442
443static int init_child(int nsfd, char *veth, unsigned int src, unsigned int dst)
444{
445	struct in_addr intsrc = inet_makeaddr(INADDR_B, src);
446	struct in_addr tunsrc = inet_makeaddr(INADDR_A, src);
447	struct in_addr tundst = inet_makeaddr(INADDR_A, dst);
448	int route_sock = -1, ret = -1;
449	uint32_t route_seq;
450
451	if (switch_ns(nsfd))
452		return -1;
453
454	if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) {
455		printk("Failed to open netlink route socket in child");
456		return -1;
457	}
458
459	if (ip4_addr_set(route_sock, route_seq++, veth, intsrc, PREFIX_LEN)) {
460		printk("Failed to set ipv4 addr");
461		goto err;
462	}
463
464	if (link_set_up(route_sock, route_seq++, veth)) {
465		printk("Failed to bring up %s", veth);
466		goto err;
467	}
468
469	if (tunnel_set_route(route_sock, &route_seq, veth, tunsrc, tundst)) {
470		printk("Failed to add tunnel route on %s", veth);
471		goto err;
472	}
473	ret = 0;
474
475err:
476	close(route_sock);
477	return ret;
478}
479
#define ALGO_LEN	64
/* Kind of xfrm test case; desc_name[] holds a printable name for each. */
enum desc_type {
	CREATE_TUNNEL	= 0,
	ALLOCATE_SPI,
	MONITOR_ACQUIRE,
	EXPIRE_STATE,
	EXPIRE_POLICY,
};
/* Indexed by enum desc_type; designators keep the two in sync. */
static const char * const desc_name[] = {
	[CREATE_TUNNEL]		= "create tunnel",
	[ALLOCATE_SPI]		= "alloc spi",
	[MONITOR_ACQUIRE]	= "monitor acquire",
	[EXPIRE_STATE]		= "expire state",
	[EXPIRE_POLICY]		= "expire policy",
};
/*
 * One xfrm test case: IPsec protocol and algorithm names (NUL-terminated,
 * at most ALGO_LEN - 1 chars). Which names must be set depends on @proto:
 * AH uses a_algo, COMP uses c_algo, and ESP uses either a_algo + e_algo or
 * ae_algo with icv_len.
 */
struct xfrm_desc {
	enum desc_type	type;
	uint8_t		proto;
	char		a_algo[ALGO_LEN];
	char		e_algo[ALGO_LEN];
	char		c_algo[ALGO_LEN];
	char		ae_algo[ALGO_LEN];
	unsigned int	icv_len;
	/* unsigned key_len; */
};
505
/* Message kinds carried in struct test_desc by write_msg()/read_msg(). */
enum msg_type {
	MSG_ACK		= 0,
	MSG_EXIT,
	MSG_PING,
	MSG_XFRM_PREPARE,
	MSG_XFRM_ADD,
	MSG_XFRM_DEL,
	MSG_XFRM_CLEANUP,
};

/*
 * One control message; @type selects the meaning of @body.
 * For MSG_PING, body.ping carries the address and UDP port the sender's
 * ping socket is bound to (see do_ping()). body.xfrm_desc is presumably
 * used by the MSG_XFRM_* messages — TODO confirm against the senders.
 * write_msg() asserts at build time that this fits in PIPE_BUF, so a pipe
 * write of one message is atomic.
 */
struct test_desc {
	enum msg_type type;
	union {
		struct {
			in_addr_t reply_ip;
			unsigned int port;
		} ping;
		struct xfrm_desc xfrm_desc;
	} body;
};

/* One finished test case and its result value, written to results_fd[1]. */
struct test_result {
	struct xfrm_desc desc;
	unsigned int res;
};
531
532static void write_test_result(unsigned int res, struct xfrm_desc *d)
533{
534	struct test_result tr = {};
535	ssize_t ret;
536
537	tr.desc = *d;
538	tr.res = res;
539
540	ret = write(results_fd[1], &tr, sizeof(tr));
541	if (ret != sizeof(tr))
542		pr_err("Failed to write the result in pipe %zd", ret);
543}
544
545static void write_msg(int fd, struct test_desc *msg, bool exit_of_fail)
546{
547	ssize_t bytes = write(fd, msg, sizeof(*msg));
548
549	/* Make sure that write/read is atomic to a pipe */
550	BUILD_BUG_ON(sizeof(struct test_desc) > PIPE_BUF);
551
552	if (bytes < 0) {
553		pr_err("write()");
554		if (exit_of_fail)
555			exit(KSFT_FAIL);
556	}
557	if (bytes != sizeof(*msg)) {
558		pr_err("sent part of the message %zd/%zu", bytes, sizeof(*msg));
559		if (exit_of_fail)
560			exit(KSFT_FAIL);
561	}
562}
563
564static void read_msg(int fd, struct test_desc *msg, bool exit_of_fail)
565{
566	ssize_t bytes = read(fd, msg, sizeof(*msg));
567
568	if (bytes < 0) {
569		pr_err("read()");
570		if (exit_of_fail)
571			exit(KSFT_FAIL);
572	}
573	if (bytes != sizeof(*msg)) {
574		pr_err("got incomplete message %zd/%zu", bytes, sizeof(*msg));
575		if (exit_of_fail)
576			exit(KSFT_FAIL);
577	}
578}
579
/*
 * Prepare a UDP ping socket pair:
 *  - sock[0] is bound to @listen_ip on an ephemeral port, returned in
 *    *@server_port (host byte order). It gets a receive timeout of
 *    @u_timeout microseconds, which must be below 1000000 because it is
 *    stored in tv_usec.
 *  - sock[1] is an unbound socket used for sending.
 * Returns 0 on success, -1 on failure with no socket left open.
 */
static int udp_ping_init(struct in_addr listen_ip, unsigned int u_timeout,
		unsigned int *server_port, int sock[2])
{
	/* Zeroed so that sin_zero is not stack garbage. */
	struct sockaddr_in server = {};
	struct timeval t = { .tv_sec = 0, .tv_usec = u_timeout };
	socklen_t s_len = sizeof(server);

	sock[0] = socket(AF_INET, SOCK_DGRAM, 0);
	if (sock[0] < 0) {
		pr_err("socket()");
		return -1;
	}

	server.sin_family	= AF_INET;
	server.sin_port		= 0;	/* let the kernel pick a port */
	memcpy(&server.sin_addr.s_addr, &listen_ip, sizeof(struct in_addr));

	if (bind(sock[0], (struct sockaddr *)&server, s_len)) {
		pr_err("bind()");
		goto err_close_server;
	}

	if (getsockname(sock[0], (struct sockaddr *)&server, &s_len)) {
		pr_err("getsockname()");
		goto err_close_server;
	}

	*server_port = ntohs(server.sin_port);

	if (setsockopt(sock[0], SOL_SOCKET, SO_RCVTIMEO, (const char *)&t, sizeof t)) {
		pr_err("setsockopt()");
		goto err_close_server;
	}

	sock[1] = socket(AF_INET, SOCK_DGRAM, 0);
	if (sock[1] < 0) {
		pr_err("socket()");
		goto err_close_server;
	}

	return 0;

err_close_server:
	close(sock[0]);
	return -1;
}
626
/*
 * Send @buf_len bytes of @buf from sock[1] to @dest_ip:@port, then wait on
 * sock[0] (bounded by its receive timeout) for an identical echo.
 * Returns 0 if the echo came back intact, -1 otherwise. A plain timeout
 * (EAGAIN) is not logged.
 */
static int udp_ping_send(int sock[2], in_addr_t dest_ip, unsigned int port,
		char *buf, size_t buf_len)
{
	struct sockaddr_in server = {};
	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
	char sock_buf[buf_len];	/* receive buffer: buf_len bytes */
	ssize_t r_bytes, s_bytes;

	server.sin_family	= AF_INET;
	server.sin_port		= htons(port);
	server.sin_addr.s_addr	= dest_ip;

	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
	if (s_bytes < 0) {
		pr_err("sendto()");
		return -1;
	} else if (s_bytes != buf_len) {
		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
		return -1;
	}

	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
	if (r_bytes < 0) {
		if (errno != EAGAIN)
			pr_err("recv()");
		return -1;
	} else if (r_bytes == 0) { /* EOF */
		printk("EOF on reply to ping");
		return -1;
	} else if (r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
		return -1;
	}

	return 0;
}
663
/*
 * Wait on sock[0] (bounded by its receive timeout) for a ping whose
 * payload equals the @buf_len bytes of @buf, then echo @buf from sock[1]
 * to @dest_ip:@port.
 * Returns 0 on a valid ping that was echoed back, -1 otherwise. A plain
 * timeout (EAGAIN) is not logged.
 */
static int udp_ping_reply(int sock[2], in_addr_t dest_ip, unsigned int port,
		char *buf, size_t buf_len)
{
	struct sockaddr_in server = {};
	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
	char sock_buf[buf_len];	/* receive buffer: buf_len bytes */
	ssize_t r_bytes, s_bytes;

	server.sin_family	= AF_INET;
	server.sin_port		= htons(port);
	server.sin_addr.s_addr	= dest_ip;

	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
	if (r_bytes < 0) {
		if (errno != EAGAIN)
			pr_err("recv()");
		return -1;
	}
	if (r_bytes == 0) { /* EOF */
		printk("EOF on reply to ping");
		return -1;
	}
	if (r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
		return -1;
	}

	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
	if (s_bytes < 0) {
		pr_err("sendto()");
		return -1;
	} else if (s_bytes != buf_len) {
		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
		return -1;
	}

	return 0;
}
702
703typedef int (*ping_f)(int sock[2], in_addr_t dest_ip, unsigned int port,
704		char *buf, size_t buf_len);
705static int do_ping(int cmd_fd, char *buf, size_t buf_len, struct in_addr from,
706		bool init_side, int d_port, in_addr_t to, ping_f func)
707{
708	struct test_desc msg;
709	unsigned int s_port, i, ping_succeeded = 0;
710	int ping_sock[2];
711	char to_str[IPV4_STR_SZ] = {}, from_str[IPV4_STR_SZ] = {};
712
713	if (udp_ping_init(from, ping_timeout, &s_port, ping_sock)) {
714		printk("Failed to init ping");
715		return -1;
716	}
717
718	memset(&msg, 0, sizeof(msg));
719	msg.type		= MSG_PING;
720	msg.body.ping.port	= s_port;
721	memcpy(&msg.body.ping.reply_ip, &from, sizeof(from));
722
723	write_msg(cmd_fd, &msg, 0);
724	if (init_side) {
725		/* The other end sends ip to ping */
726		read_msg(cmd_fd, &msg, 0);
727		if (msg.type != MSG_PING)
728			return -1;
729		to = msg.body.ping.reply_ip;
730		d_port = msg.body.ping.port;
731	}
732
733	for (i = 0; i < ping_count ; i++) {
734		struct timespec sleep_time = {
735			.tv_sec = 0,
736			.tv_nsec = ping_delay_nsec,
737		};
738
739		ping_succeeded += !func(ping_sock, to, d_port, buf, page_size);
740		nanosleep(&sleep_time, 0);
741	}
742
743	close(ping_sock[0]);
744	close(ping_sock[1]);
745
746	strncpy(to_str, inet_ntoa(*(struct in_addr *)&to), IPV4_STR_SZ - 1);
747	strncpy(from_str, inet_ntoa(from), IPV4_STR_SZ - 1);
748
749	if (ping_succeeded < ping_success) {
750		printk("ping (%s) %s->%s failed %u/%u times",
751			init_side ? "send" : "reply", from_str, to_str,
752			ping_count - ping_succeeded, ping_count);
753		return -1;
754	}
755
756#ifdef DEBUG
757	printk("ping (%s) %s->%s succeeded %u/%u times",
758		init_side ? "send" : "reply", from_str, to_str,
759		ping_succeeded, ping_count);
760#endif
761
762	return 0;
763}
764
765static int xfrm_fill_key(char *name, char *buf,
766		size_t buf_len, unsigned int *key_len)
767{
768	/* TODO: use set/map instead */
769	if (strncmp(name, "digest_null", ALGO_LEN) == 0)
770		*key_len = 0;
771	else if (strncmp(name, "ecb(cipher_null)", ALGO_LEN) == 0)
772		*key_len = 0;
773	else if (strncmp(name, "cbc(des)", ALGO_LEN) == 0)
774		*key_len = 64;
775	else if (strncmp(name, "hmac(md5)", ALGO_LEN) == 0)
776		*key_len = 128;
777	else if (strncmp(name, "cmac(aes)", ALGO_LEN) == 0)
778		*key_len = 128;
779	else if (strncmp(name, "xcbc(aes)", ALGO_LEN) == 0)
780		*key_len = 128;
781	else if (strncmp(name, "cbc(cast5)", ALGO_LEN) == 0)
782		*key_len = 128;
783	else if (strncmp(name, "cbc(serpent)", ALGO_LEN) == 0)
784		*key_len = 128;
785	else if (strncmp(name, "hmac(sha1)", ALGO_LEN) == 0)
786		*key_len = 160;
787	else if (strncmp(name, "hmac(rmd160)", ALGO_LEN) == 0)
788		*key_len = 160;
789	else if (strncmp(name, "cbc(des3_ede)", ALGO_LEN) == 0)
790		*key_len = 192;
791	else if (strncmp(name, "hmac(sha256)", ALGO_LEN) == 0)
792		*key_len = 256;
793	else if (strncmp(name, "cbc(aes)", ALGO_LEN) == 0)
794		*key_len = 256;
795	else if (strncmp(name, "cbc(camellia)", ALGO_LEN) == 0)
796		*key_len = 256;
797	else if (strncmp(name, "cbc(twofish)", ALGO_LEN) == 0)
798		*key_len = 256;
799	else if (strncmp(name, "rfc3686(ctr(aes))", ALGO_LEN) == 0)
800		*key_len = 288;
801	else if (strncmp(name, "hmac(sha384)", ALGO_LEN) == 0)
802		*key_len = 384;
803	else if (strncmp(name, "cbc(blowfish)", ALGO_LEN) == 0)
804		*key_len = 448;
805	else if (strncmp(name, "hmac(sha512)", ALGO_LEN) == 0)
806		*key_len = 512;
807	else if (strncmp(name, "rfc4106(gcm(aes))-128", ALGO_LEN) == 0)
808		*key_len = 160;
809	else if (strncmp(name, "rfc4543(gcm(aes))-128", ALGO_LEN) == 0)
810		*key_len = 160;
811	else if (strncmp(name, "rfc4309(ccm(aes))-128", ALGO_LEN) == 0)
812		*key_len = 152;
813	else if (strncmp(name, "rfc4106(gcm(aes))-192", ALGO_LEN) == 0)
814		*key_len = 224;
815	else if (strncmp(name, "rfc4543(gcm(aes))-192", ALGO_LEN) == 0)
816		*key_len = 224;
817	else if (strncmp(name, "rfc4309(ccm(aes))-192", ALGO_LEN) == 0)
818		*key_len = 216;
819	else if (strncmp(name, "rfc4106(gcm(aes))-256", ALGO_LEN) == 0)
820		*key_len = 288;
821	else if (strncmp(name, "rfc4543(gcm(aes))-256", ALGO_LEN) == 0)
822		*key_len = 288;
823	else if (strncmp(name, "rfc4309(ccm(aes))-256", ALGO_LEN) == 0)
824		*key_len = 280;
825	else if (strncmp(name, "rfc7539(chacha20,poly1305)-128", ALGO_LEN) == 0)
826		*key_len = 0;
827
828	if (*key_len > buf_len) {
829		printk("Can't pack a key - too big for buffer");
830		return -1;
831	}
832
833	randomize_buffer(buf, *key_len);
834
835	return 0;
836}
837
838static int xfrm_state_pack_algo(struct nlmsghdr *nh, size_t req_sz,
839		struct xfrm_desc *desc)
840{
841	struct {
842		union {
843			struct xfrm_algo	alg;
844			struct xfrm_algo_aead	aead;
845			struct xfrm_algo_auth	auth;
846		} u;
847		char buf[XFRM_ALGO_KEY_BUF_SIZE];
848	} alg = {};
849	size_t alen, elen, clen, aelen;
850	unsigned short type;
851
852	alen = strlen(desc->a_algo);
853	elen = strlen(desc->e_algo);
854	clen = strlen(desc->c_algo);
855	aelen = strlen(desc->ae_algo);
856
857	/* Verify desc */
858	switch (desc->proto) {
859	case IPPROTO_AH:
860		if (!alen || elen || clen || aelen) {
861			printk("BUG: buggy ah desc");
862			return -1;
863		}
864		strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN - 1);
865		if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
866				sizeof(alg.buf), &alg.u.alg.alg_key_len))
867			return -1;
868		type = XFRMA_ALG_AUTH;
869		break;
870	case IPPROTO_COMP:
871		if (!clen || elen || alen || aelen) {
872			printk("BUG: buggy comp desc");
873			return -1;
874		}
875		strncpy(alg.u.alg.alg_name, desc->c_algo, ALGO_LEN - 1);
876		if (xfrm_fill_key(desc->c_algo, alg.u.alg.alg_key,
877				sizeof(alg.buf), &alg.u.alg.alg_key_len))
878			return -1;
879		type = XFRMA_ALG_COMP;
880		break;
881	case IPPROTO_ESP:
882		if (!((alen && elen) ^ aelen) || clen) {
883			printk("BUG: buggy esp desc");
884			return -1;
885		}
886		if (aelen) {
887			alg.u.aead.alg_icv_len = desc->icv_len;
888			strncpy(alg.u.aead.alg_name, desc->ae_algo, ALGO_LEN - 1);
889			if (xfrm_fill_key(desc->ae_algo, alg.u.aead.alg_key,
890						sizeof(alg.buf), &alg.u.aead.alg_key_len))
891				return -1;
892			type = XFRMA_ALG_AEAD;
893		} else {
894
895			strncpy(alg.u.alg.alg_name, desc->e_algo, ALGO_LEN - 1);
896			type = XFRMA_ALG_CRYPT;
897			if (xfrm_fill_key(desc->e_algo, alg.u.alg.alg_key,
898						sizeof(alg.buf), &alg.u.alg.alg_key_len))
899				return -1;
900			if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
901				return -1;
902
903			strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN);
904			type = XFRMA_ALG_AUTH;
905			if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
906						sizeof(alg.buf), &alg.u.alg.alg_key_len))
907				return -1;
908		}
909		break;
910	default:
911		printk("BUG: unknown proto in desc");
912		return -1;
913	}
914
915	if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
916		return -1;
917
918	return 0;
919}
920
/* SPI derived from the host part of @src (inet_lnaof), in network byte order. */
static inline uint32_t gen_spi(struct in_addr src)
{
	in_addr_t host_part = inet_lnaof(src);

	return htonl(host_part);
}
925
926static int xfrm_state_add(int xfrm_sock, uint32_t seq, uint32_t spi,
927		struct in_addr src, struct in_addr dst,
928		struct xfrm_desc *desc)
929{
930	struct {
931		struct nlmsghdr		nh;
932		struct xfrm_usersa_info	info;
933		char			attrbuf[MAX_PAYLOAD];
934	} req;
935
936	memset(&req, 0, sizeof(req));
937	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
938	req.nh.nlmsg_type	= XFRM_MSG_NEWSA;
939	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
940	req.nh.nlmsg_seq	= seq;
941
942	/* Fill selector. */
943	memcpy(&req.info.sel.daddr, &dst, sizeof(dst));
944	memcpy(&req.info.sel.saddr, &src, sizeof(src));
945	req.info.sel.family		= AF_INET;
946	req.info.sel.prefixlen_d	= PREFIX_LEN;
947	req.info.sel.prefixlen_s	= PREFIX_LEN;
948
949	/* Fill id */
950	memcpy(&req.info.id.daddr, &dst, sizeof(dst));
951	/* Note: zero-spi cannot be deleted */
952	req.info.id.spi = spi;
953	req.info.id.proto	= desc->proto;
954
955	memcpy(&req.info.saddr, &src, sizeof(src));
956
957	/* Fill lifteme_cfg */
958	req.info.lft.soft_byte_limit	= XFRM_INF;
959	req.info.lft.hard_byte_limit	= XFRM_INF;
960	req.info.lft.soft_packet_limit	= XFRM_INF;
961	req.info.lft.hard_packet_limit	= XFRM_INF;
962
963	req.info.family		= AF_INET;
964	req.info.mode		= XFRM_MODE_TUNNEL;
965
966	if (xfrm_state_pack_algo(&req.nh, sizeof(req), desc))
967		return -1;
968
969	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
970		pr_err("send()");
971		return -1;
972	}
973
974	return netlink_check_answer(xfrm_sock);
975}
976
977static bool xfrm_usersa_found(struct xfrm_usersa_info *info, uint32_t spi,
978		struct in_addr src, struct in_addr dst,
979		struct xfrm_desc *desc)
980{
981	if (memcmp(&info->sel.daddr, &dst, sizeof(dst)))
982		return false;
983
984	if (memcmp(&info->sel.saddr, &src, sizeof(src)))
985		return false;
986
987	if (info->sel.family != AF_INET					||
988			info->sel.prefixlen_d != PREFIX_LEN		||
989			info->sel.prefixlen_s != PREFIX_LEN)
990		return false;
991
992	if (info->id.spi != spi || info->id.proto != desc->proto)
993		return false;
994
995	if (memcmp(&info->id.daddr, &dst, sizeof(dst)))
996		return false;
997
998	if (memcmp(&info->saddr, &src, sizeof(src)))
999		return false;
1000
1001	if (info->lft.soft_byte_limit != XFRM_INF			||
1002			info->lft.hard_byte_limit != XFRM_INF		||
1003			info->lft.soft_packet_limit != XFRM_INF		||
1004			info->lft.hard_packet_limit != XFRM_INF)
1005		return false;
1006
1007	if (info->family != AF_INET || info->mode != XFRM_MODE_TUNNEL)
1008		return false;
1009
1010	/* XXX: check xfrm algo, see xfrm_state_pack_algo(). */
1011
1012	return true;
1013}
1014
1015static int xfrm_state_check(int xfrm_sock, uint32_t seq, uint32_t spi,
1016		struct in_addr src, struct in_addr dst,
1017		struct xfrm_desc *desc)
1018{
1019	struct {
1020		struct nlmsghdr		nh;
1021		char			attrbuf[MAX_PAYLOAD];
1022	} req;
1023	struct {
1024		struct nlmsghdr		nh;
1025		union {
1026			struct xfrm_usersa_info	info;
1027			int error;
1028		};
1029		char			attrbuf[MAX_PAYLOAD];
1030	} answer;
1031	struct xfrm_address_filter filter = {};
1032	bool found = false;
1033
1034
1035	memset(&req, 0, sizeof(req));
1036	req.nh.nlmsg_len	= NLMSG_LENGTH(0);
1037	req.nh.nlmsg_type	= XFRM_MSG_GETSA;
1038	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_DUMP;
1039	req.nh.nlmsg_seq	= seq;
1040
1041	/*
1042	 * Add dump filter by source address as there may be other tunnels
1043	 * in this netns (if tests run in parallel).
1044	 */
1045	filter.family = AF_INET;
1046	filter.splen = 0x1f;	/* 0xffffffff mask see addr_match() */
1047	memcpy(&filter.saddr, &src, sizeof(src));
1048	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_ADDRESS_FILTER,
1049				&filter, sizeof(filter)))
1050		return -1;
1051
1052	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1053		pr_err("send()");
1054		return -1;
1055	}
1056
1057	while (1) {
1058		if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1059			pr_err("recv()");
1060			return -1;
1061		}
1062		if (answer.nh.nlmsg_type == NLMSG_ERROR) {
1063			printk("NLMSG_ERROR: %d: %s",
1064				answer.error, strerror(-answer.error));
1065			return -1;
1066		} else if (answer.nh.nlmsg_type == NLMSG_DONE) {
1067			if (found)
1068				return 0;
1069			printk("didn't find allocated xfrm state in dump");
1070			return -1;
1071		} else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1072			if (xfrm_usersa_found(&answer.info, spi, src, dst, desc))
1073				found = true;
1074		}
1075	}
1076}
1077
/*
 * Install SAs in both directions (@src -> @dst and @dst -> @src) with the
 * SPI derived from @src, then check that each appears in an SA dump.
 * @tunsrc and @tundst are not used here. Each request takes one sequence
 * number from *@seq. Returns 0 on success, -1 on failure.
 */
static int xfrm_set(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst,
		struct xfrm_desc *desc)
{
	uint32_t spi = gen_spi(src);
	int check_err;

	if (xfrm_state_add(xfrm_sock, (*seq)++, spi, src, dst, desc) ||
	    xfrm_state_add(xfrm_sock, (*seq)++, spi, dst, src, desc)) {
		printk("Failed to add xfrm state");
		return -1;
	}

	/* Check dumps for XFRM_MSG_GETSA; both checks always run. */
	check_err = xfrm_state_check(xfrm_sock, (*seq)++, spi, src, dst, desc);
	check_err |= xfrm_state_check(xfrm_sock, (*seq)++, spi, dst, src, desc);
	if (check_err) {
		printk("Failed to check xfrm state");
		return -1;
	}

	return 0;
}
1107
1108static int xfrm_policy_add(int xfrm_sock, uint32_t seq, uint32_t spi,
1109		struct in_addr src, struct in_addr dst, uint8_t dir,
1110		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
1111{
1112	struct {
1113		struct nlmsghdr			nh;
1114		struct xfrm_userpolicy_info	info;
1115		char				attrbuf[MAX_PAYLOAD];
1116	} req;
1117	struct xfrm_user_tmpl tmpl;
1118
1119	memset(&req, 0, sizeof(req));
1120	memset(&tmpl, 0, sizeof(tmpl));
1121	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
1122	req.nh.nlmsg_type	= XFRM_MSG_NEWPOLICY;
1123	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1124	req.nh.nlmsg_seq	= seq;
1125
1126	/* Fill selector. */
1127	memcpy(&req.info.sel.daddr, &dst, sizeof(tundst));
1128	memcpy(&req.info.sel.saddr, &src, sizeof(tunsrc));
1129	req.info.sel.family		= AF_INET;
1130	req.info.sel.prefixlen_d	= PREFIX_LEN;
1131	req.info.sel.prefixlen_s	= PREFIX_LEN;
1132
1133	/* Fill lifteme_cfg */
1134	req.info.lft.soft_byte_limit	= XFRM_INF;
1135	req.info.lft.hard_byte_limit	= XFRM_INF;
1136	req.info.lft.soft_packet_limit	= XFRM_INF;
1137	req.info.lft.hard_packet_limit	= XFRM_INF;
1138
1139	req.info.dir = dir;
1140
1141	/* Fill tmpl */
1142	memcpy(&tmpl.id.daddr, &dst, sizeof(dst));
1143	/* Note: zero-spi cannot be deleted */
1144	tmpl.id.spi = spi;
1145	tmpl.id.proto	= proto;
1146	tmpl.family	= AF_INET;
1147	memcpy(&tmpl.saddr, &src, sizeof(src));
1148	tmpl.mode	= XFRM_MODE_TUNNEL;
1149	tmpl.aalgos = (~(uint32_t)0);
1150	tmpl.ealgos = (~(uint32_t)0);
1151	tmpl.calgos = (~(uint32_t)0);
1152
1153	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &tmpl, sizeof(tmpl)))
1154		return -1;
1155
1156	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1157		pr_err("send()");
1158		return -1;
1159	}
1160
1161	return netlink_check_answer(xfrm_sock);
1162}
1163
/*
 * xfrm_prepare - install both tunnel policies for one endpoint.
 *
 * Adds the outbound policy (@src -> @dst) and then, only if that worked,
 * the inbound one (@dst -> @src). Both use the SPI derived from @src and
 * consume one sequence number each from *@seq.
 * Returns 0 on success, -1 if either policy could not be added.
 */
static int xfrm_prepare(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	int failed;

	failed = xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
				 XFRM_POLICY_OUT, tunsrc, tundst, proto);
	if (!failed)
		failed = xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src),
					 dst, src, XFRM_POLICY_IN,
					 tunsrc, tundst, proto);

	if (failed) {
		printk("Failed to add xfrm policy");
		return -1;
	}

	return 0;
}
1182
1183static int xfrm_policy_del(int xfrm_sock, uint32_t seq,
1184		struct in_addr src, struct in_addr dst, uint8_t dir,
1185		struct in_addr tunsrc, struct in_addr tundst)
1186{
1187	struct {
1188		struct nlmsghdr			nh;
1189		struct xfrm_userpolicy_id	id;
1190		char				attrbuf[MAX_PAYLOAD];
1191	} req;
1192
1193	memset(&req, 0, sizeof(req));
1194	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.id));
1195	req.nh.nlmsg_type	= XFRM_MSG_DELPOLICY;
1196	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1197	req.nh.nlmsg_seq	= seq;
1198
1199	/* Fill id */
1200	memcpy(&req.id.sel.daddr, &dst, sizeof(tundst));
1201	memcpy(&req.id.sel.saddr, &src, sizeof(tunsrc));
1202	req.id.sel.family		= AF_INET;
1203	req.id.sel.prefixlen_d		= PREFIX_LEN;
1204	req.id.sel.prefixlen_s		= PREFIX_LEN;
1205	req.id.dir = dir;
1206
1207	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1208		pr_err("send()");
1209		return -1;
1210	}
1211
1212	return netlink_check_answer(xfrm_sock);
1213}
1214
/*
 * xfrm_cleanup - remove both tunnel policies installed by xfrm_prepare().
 *
 * Deletes the outbound (@src -> @dst) and the inbound (@dst -> @src)
 * policy. Removal is best-effort: the inbound policy is still attempted
 * when the outbound one could not be deleted, so a half-configured tunnel
 * does not leave a stale policy behind. Each attempt consumes one sequence
 * number from *@seq.
 *
 * Returns 0 if both deletions succeeded, -1 otherwise.
 */
static int xfrm_cleanup(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst)
{
	int ret = 0;

	if (xfrm_policy_del(xfrm_sock, (*seq)++, src, dst,
				XFRM_POLICY_OUT, tunsrc, tundst)) {
		printk("Failed to remove outbound xfrm policy");
		ret = -1;
	}

	if (xfrm_policy_del(xfrm_sock, (*seq)++, dst, src,
				XFRM_POLICY_IN, tunsrc, tundst)) {
		printk("Failed to remove inbound xfrm policy");
		ret = -1;
	}

	return ret;
}
1233
1234static int xfrm_state_del(int xfrm_sock, uint32_t seq, uint32_t spi,
1235		struct in_addr src, struct in_addr dst, uint8_t proto)
1236{
1237	struct {
1238		struct nlmsghdr		nh;
1239		struct xfrm_usersa_id	id;
1240		char			attrbuf[MAX_PAYLOAD];
1241	} req;
1242	xfrm_address_t saddr = {};
1243
1244	memset(&req, 0, sizeof(req));
1245	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.id));
1246	req.nh.nlmsg_type	= XFRM_MSG_DELSA;
1247	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1248	req.nh.nlmsg_seq	= seq;
1249
1250	memcpy(&req.id.daddr, &dst, sizeof(dst));
1251	req.id.family		= AF_INET;
1252	req.id.proto		= proto;
1253	/* Note: zero-spi cannot be deleted */
1254	req.id.spi = spi;
1255
1256	memcpy(&saddr, &src, sizeof(src));
1257	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SRCADDR, &saddr, sizeof(saddr)))
1258		return -1;
1259
1260	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1261		pr_err("send()");
1262		return -1;
1263	}
1264
1265	return netlink_check_answer(xfrm_sock);
1266}
1267
/*
 * xfrm_delete - remove the pair of SAs for one endpoint.
 *
 * Deletes the @src -> @dst SA and then, only if that worked, the
 * @dst -> @src one; both carry the SPI derived from @src. @tunsrc and
 * @tundst are unused. Returns 0 on success, -1 on the first failure.
 */
static int xfrm_delete(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	int failed;

	failed = xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src),
				src, dst, proto);
	if (!failed)
		failed = xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src),
					dst, src, proto);

	if (failed) {
		printk("Failed to remove xfrm state");
		return -1;
	}

	return 0;
}
1284
1285static int xfrm_state_allocspi(int xfrm_sock, uint32_t *seq,
1286		uint32_t spi, uint8_t proto)
1287{
1288	struct {
1289		struct nlmsghdr			nh;
1290		struct xfrm_userspi_info	spi;
1291	} req;
1292	struct {
1293		struct nlmsghdr			nh;
1294		union {
1295			struct xfrm_usersa_info	info;
1296			int error;
1297		};
1298	} answer;
1299
1300	memset(&req, 0, sizeof(req));
1301	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.spi));
1302	req.nh.nlmsg_type	= XFRM_MSG_ALLOCSPI;
1303	req.nh.nlmsg_flags	= NLM_F_REQUEST;
1304	req.nh.nlmsg_seq	= (*seq)++;
1305
1306	req.spi.info.family	= AF_INET;
1307	req.spi.min		= spi;
1308	req.spi.max		= spi;
1309	req.spi.info.id.proto	= proto;
1310
1311	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1312		pr_err("send()");
1313		return KSFT_FAIL;
1314	}
1315
1316	if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1317		pr_err("recv()");
1318		return KSFT_FAIL;
1319	} else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1320		uint32_t new_spi = htonl(answer.info.id.spi);
1321
1322		if (new_spi != spi) {
1323			printk("allocated spi is different from requested: %#x != %#x",
1324					new_spi, spi);
1325			return KSFT_FAIL;
1326		}
1327		return KSFT_PASS;
1328	} else if (answer.nh.nlmsg_type != NLMSG_ERROR) {
1329		printk("expected NLMSG_ERROR, got %d", (int)answer.nh.nlmsg_type);
1330		return KSFT_FAIL;
1331	}
1332
1333	printk("NLMSG_ERROR: %d: %s", answer.error, strerror(-answer.error));
1334	return (answer.error) ? KSFT_FAIL : KSFT_PASS;
1335}
1336
/*
 * netlink_sock_bind - open a netlink socket and bind it to multicast @groups.
 * @sock:	out: socket fd on success; set to -1 if bind/validation fails
 * @seq:	passed through to netlink_sock()
 * @proto:	netlink protocol (e.g. NETLINK_XFRM)
 * @groups:	value stored in sockaddr_nl.nl_groups
 *
 * After bind(), getsockname() is used to check that the kernel reports a
 * complete AF_NETLINK address. Returns 0 on success, -1 on failure; when
 * the failure happens after the socket was opened, the socket is closed.
 *
 * NOTE: nl_groups is a bitmask in which group N is bit (N - 1). The callers
 * pass XFRMNLGRP_ACQUIRE (1) and XFRMNLGRP_EXPIRE (2), which happen to equal
 * their bitmask values; other XFRMNLGRP_* numbers would need converting
 * (or joining via NETLINK_ADD_MEMBERSHIP).
 */
static int netlink_sock_bind(int *sock, uint32_t *seq, int proto, uint32_t groups)
{
	struct sockaddr_nl snl = {};
	socklen_t addr_len;

	snl.nl_family = AF_NETLINK;
	snl.nl_groups = groups;

	if (netlink_sock(sock, seq, proto)) {
		printk("Failed to open xfrm netlink socket");
		return -1;
	}

	if (bind(*sock, (struct sockaddr *)&snl, sizeof(snl)) < 0) {
		pr_err("bind()");
		goto out_close;
	}

	addr_len = sizeof(snl);
	if (getsockname(*sock, (struct sockaddr *)&snl, &addr_len) < 0) {
		pr_err("getsockname()");
		goto out_close;
	}
	if (addr_len != sizeof(snl)) {
		printk("Wrong address length %u", (unsigned int)addr_len);
		goto out_close;
	}
	if (snl.nl_family != AF_NETLINK) {
		printk("Wrong address family %d", snl.nl_family);
		goto out_close;
	}
	return 0;

out_close:
	close(*sock);
	/* Don't leave the caller holding a descriptor number that is closed. */
	*sock = -1;
	return -1;
}
1375
1376static int xfrm_monitor_acquire(int xfrm_sock, uint32_t *seq, unsigned int nr)
1377{
1378	struct {
1379		struct nlmsghdr nh;
1380		union {
1381			struct xfrm_user_acquire acq;
1382			int error;
1383		};
1384		char attrbuf[MAX_PAYLOAD];
1385	} req;
1386	struct xfrm_user_tmpl xfrm_tmpl = {};
1387	int xfrm_listen = -1, ret = KSFT_FAIL;
1388	uint32_t seq_listen;
1389
1390	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_ACQUIRE))
1391		return KSFT_FAIL;
1392
1393	memset(&req, 0, sizeof(req));
1394	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.acq));
1395	req.nh.nlmsg_type	= XFRM_MSG_ACQUIRE;
1396	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1397	req.nh.nlmsg_seq	= (*seq)++;
1398
1399	req.acq.policy.sel.family	= AF_INET;
1400	req.acq.aalgos	= 0xfeed;
1401	req.acq.ealgos	= 0xbaad;
1402	req.acq.calgos	= 0xbabe;
1403
1404	xfrm_tmpl.family = AF_INET;
1405	xfrm_tmpl.id.proto = IPPROTO_ESP;
1406	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &xfrm_tmpl, sizeof(xfrm_tmpl)))
1407		goto out_close;
1408
1409	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1410		pr_err("send()");
1411		goto out_close;
1412	}
1413
1414	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1415		pr_err("recv()");
1416		goto out_close;
1417	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1418		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1419		goto out_close;
1420	}
1421
1422	if (req.error) {
1423		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1424		ret = req.error;
1425		goto out_close;
1426	}
1427
1428	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1429		pr_err("recv()");
1430		goto out_close;
1431	}
1432
1433	if (req.acq.aalgos != 0xfeed || req.acq.ealgos != 0xbaad
1434			|| req.acq.calgos != 0xbabe) {
1435		printk("xfrm_user_acquire has changed  %x %x %x",
1436				req.acq.aalgos, req.acq.ealgos, req.acq.calgos);
1437		goto out_close;
1438	}
1439
1440	ret = KSFT_PASS;
1441out_close:
1442	close(xfrm_listen);
1443	return ret;
1444}
1445
1446static int xfrm_expire_state(int xfrm_sock, uint32_t *seq,
1447		unsigned int nr, struct xfrm_desc *desc)
1448{
1449	struct {
1450		struct nlmsghdr nh;
1451		union {
1452			struct xfrm_user_expire expire;
1453			int error;
1454		};
1455	} req;
1456	struct in_addr src, dst;
1457	int xfrm_listen = -1, ret = KSFT_FAIL;
1458	uint32_t seq_listen;
1459
1460	src = inet_makeaddr(INADDR_B, child_ip(nr));
1461	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1462
1463	if (xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc)) {
1464		printk("Failed to add xfrm state");
1465		return KSFT_FAIL;
1466	}
1467
1468	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1469		return KSFT_FAIL;
1470
1471	memset(&req, 0, sizeof(req));
1472	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.expire));
1473	req.nh.nlmsg_type	= XFRM_MSG_EXPIRE;
1474	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1475	req.nh.nlmsg_seq	= (*seq)++;
1476
1477	memcpy(&req.expire.state.id.daddr, &dst, sizeof(dst));
1478	req.expire.state.id.spi		= gen_spi(src);
1479	req.expire.state.id.proto	= desc->proto;
1480	req.expire.state.family		= AF_INET;
1481	req.expire.hard			= 0xff;
1482
1483	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1484		pr_err("send()");
1485		goto out_close;
1486	}
1487
1488	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1489		pr_err("recv()");
1490		goto out_close;
1491	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1492		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1493		goto out_close;
1494	}
1495
1496	if (req.error) {
1497		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1498		ret = req.error;
1499		goto out_close;
1500	}
1501
1502	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1503		pr_err("recv()");
1504		goto out_close;
1505	}
1506
1507	if (req.expire.hard != 0x1) {
1508		printk("expire.hard is not set: %x", req.expire.hard);
1509		goto out_close;
1510	}
1511
1512	ret = KSFT_PASS;
1513out_close:
1514	close(xfrm_listen);
1515	return ret;
1516}
1517
1518static int xfrm_expire_policy(int xfrm_sock, uint32_t *seq,
1519		unsigned int nr, struct xfrm_desc *desc)
1520{
1521	struct {
1522		struct nlmsghdr nh;
1523		union {
1524			struct xfrm_user_polexpire expire;
1525			int error;
1526		};
1527	} req;
1528	struct in_addr src, dst, tunsrc, tundst;
1529	int xfrm_listen = -1, ret = KSFT_FAIL;
1530	uint32_t seq_listen;
1531
1532	src = inet_makeaddr(INADDR_B, child_ip(nr));
1533	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1534	tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1535	tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1536
1537	if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
1538				XFRM_POLICY_OUT, tunsrc, tundst, desc->proto)) {
1539		printk("Failed to add xfrm policy");
1540		return KSFT_FAIL;
1541	}
1542
1543	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1544		return KSFT_FAIL;
1545
1546	memset(&req, 0, sizeof(req));
1547	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.expire));
1548	req.nh.nlmsg_type	= XFRM_MSG_POLEXPIRE;
1549	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1550	req.nh.nlmsg_seq	= (*seq)++;
1551
1552	/* Fill selector. */
1553	memcpy(&req.expire.pol.sel.daddr, &dst, sizeof(tundst));
1554	memcpy(&req.expire.pol.sel.saddr, &src, sizeof(tunsrc));
1555	req.expire.pol.sel.family	= AF_INET;
1556	req.expire.pol.sel.prefixlen_d	= PREFIX_LEN;
1557	req.expire.pol.sel.prefixlen_s	= PREFIX_LEN;
1558	req.expire.pol.dir		= XFRM_POLICY_OUT;
1559	req.expire.hard			= 0xff;
1560
1561	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1562		pr_err("send()");
1563		goto out_close;
1564	}
1565
1566	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1567		pr_err("recv()");
1568		goto out_close;
1569	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1570		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1571		goto out_close;
1572	}
1573
1574	if (req.error) {
1575		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1576		ret = req.error;
1577		goto out_close;
1578	}
1579
1580	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1581		pr_err("recv()");
1582		goto out_close;
1583	}
1584
1585	if (req.expire.hard != 0x1) {
1586		printk("expire.hard is not set: %x", req.expire.hard);
1587		goto out_close;
1588	}
1589
1590	ret = KSFT_PASS;
1591out_close:
1592	close(xfrm_listen);
1593	return ret;
1594}
1595
1596static int child_serv(int xfrm_sock, uint32_t *seq,
1597		unsigned int nr, int cmd_fd, void *buf, struct xfrm_desc *desc)
1598{
1599	struct in_addr src, dst, tunsrc, tundst;
1600	struct test_desc msg;
1601	int ret = KSFT_FAIL;
1602
1603	src = inet_makeaddr(INADDR_B, child_ip(nr));
1604	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1605	tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1606	tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1607
1608	/* UDP pinging without xfrm */
1609	if (do_ping(cmd_fd, buf, page_size, src, true, 0, 0, udp_ping_send)) {
1610		printk("ping failed before setting xfrm");
1611		return KSFT_FAIL;
1612	}
1613
1614	memset(&msg, 0, sizeof(msg));
1615	msg.type = MSG_XFRM_PREPARE;
1616	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1617	write_msg(cmd_fd, &msg, 1);
1618
1619	if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1620		printk("failed to prepare xfrm");
1621		goto cleanup;
1622	}
1623
1624	memset(&msg, 0, sizeof(msg));
1625	msg.type = MSG_XFRM_ADD;
1626	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1627	write_msg(cmd_fd, &msg, 1);
1628	if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1629		printk("failed to set xfrm");
1630		goto delete;
1631	}
1632
1633	/* UDP pinging with xfrm tunnel */
1634	if (do_ping(cmd_fd, buf, page_size, tunsrc,
1635				true, 0, 0, udp_ping_send)) {
1636		printk("ping failed for xfrm");
1637		goto delete;
1638	}
1639
1640	ret = KSFT_PASS;
1641delete:
1642	/* xfrm delete */
1643	memset(&msg, 0, sizeof(msg));
1644	msg.type = MSG_XFRM_DEL;
1645	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1646	write_msg(cmd_fd, &msg, 1);
1647
1648	if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1649		printk("failed ping to remove xfrm");
1650		ret = KSFT_FAIL;
1651	}
1652
1653cleanup:
1654	memset(&msg, 0, sizeof(msg));
1655	msg.type = MSG_XFRM_CLEANUP;
1656	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1657	write_msg(cmd_fd, &msg, 1);
1658	if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1659		printk("failed ping to cleanup xfrm");
1660		ret = KSFT_FAIL;
1661	}
1662	return ret;
1663}
1664
1665static int child_f(unsigned int nr, int test_desc_fd, int cmd_fd, void *buf)
1666{
1667	struct xfrm_desc desc;
1668	struct test_desc msg;
1669	int xfrm_sock = -1;
1670	uint32_t seq;
1671
1672	if (switch_ns(nsfd_childa))
1673		exit(KSFT_FAIL);
1674
1675	if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1676		printk("Failed to open xfrm netlink socket");
1677		exit(KSFT_FAIL);
1678	}
1679
1680	/* Check that seq sock is ready, just for sure. */
1681	memset(&msg, 0, sizeof(msg));
1682	msg.type = MSG_ACK;
1683	write_msg(cmd_fd, &msg, 1);
1684	read_msg(cmd_fd, &msg, 1);
1685	if (msg.type != MSG_ACK) {
1686		printk("Ack failed");
1687		exit(KSFT_FAIL);
1688	}
1689
1690	for (;;) {
1691		ssize_t received = read(test_desc_fd, &desc, sizeof(desc));
1692		int ret;
1693
1694		if (received == 0) /* EOF */
1695			break;
1696
1697		if (received != sizeof(desc)) {
1698			pr_err("read() returned %zd", received);
1699			exit(KSFT_FAIL);
1700		}
1701
1702		switch (desc.type) {
1703		case CREATE_TUNNEL:
1704			ret = child_serv(xfrm_sock, &seq, nr,
1705					 cmd_fd, buf, &desc);
1706			break;
1707		case ALLOCATE_SPI:
1708			ret = xfrm_state_allocspi(xfrm_sock, &seq,
1709						  -1, desc.proto);
1710			break;
1711		case MONITOR_ACQUIRE:
1712			ret = xfrm_monitor_acquire(xfrm_sock, &seq, nr);
1713			break;
1714		case EXPIRE_STATE:
1715			ret = xfrm_expire_state(xfrm_sock, &seq, nr, &desc);
1716			break;
1717		case EXPIRE_POLICY:
1718			ret = xfrm_expire_policy(xfrm_sock, &seq, nr, &desc);
1719			break;
1720		default:
1721			printk("Unknown desc type %d", desc.type);
1722			exit(KSFT_FAIL);
1723		}
1724		write_test_result(ret, &desc);
1725	}
1726
1727	close(xfrm_sock);
1728
1729	msg.type = MSG_EXIT;
1730	write_msg(cmd_fd, &msg, 1);
1731	exit(KSFT_PASS);
1732}
1733
1734static void grand_child_serv(unsigned int nr, int cmd_fd, void *buf,
1735		struct test_desc *msg, int xfrm_sock, uint32_t *seq)
1736{
1737	struct in_addr src, dst, tunsrc, tundst;
1738	bool tun_reply;
1739	struct xfrm_desc *desc = &msg->body.xfrm_desc;
1740
1741	src = inet_makeaddr(INADDR_B, grchild_ip(nr));
1742	dst = inet_makeaddr(INADDR_B, child_ip(nr));
1743	tunsrc = inet_makeaddr(INADDR_A, grchild_ip(nr));
1744	tundst = inet_makeaddr(INADDR_A, child_ip(nr));
1745
1746	switch (msg->type) {
1747	case MSG_EXIT:
1748		exit(KSFT_PASS);
1749	case MSG_ACK:
1750		write_msg(cmd_fd, msg, 1);
1751		break;
1752	case MSG_PING:
1753		tun_reply = memcmp(&dst, &msg->body.ping.reply_ip, sizeof(in_addr_t));
1754		/* UDP pinging without xfrm */
1755		if (do_ping(cmd_fd, buf, page_size, tun_reply ? tunsrc : src,
1756				false, msg->body.ping.port,
1757				msg->body.ping.reply_ip, udp_ping_reply)) {
1758			printk("ping failed before setting xfrm");
1759		}
1760		break;
1761	case MSG_XFRM_PREPARE:
1762		if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst,
1763					desc->proto)) {
1764			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1765			printk("failed to prepare xfrm");
1766		}
1767		break;
1768	case MSG_XFRM_ADD:
1769		if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1770			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1771			printk("failed to set xfrm");
1772		}
1773		break;
1774	case MSG_XFRM_DEL:
1775		if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst,
1776					desc->proto)) {
1777			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1778			printk("failed to remove xfrm");
1779		}
1780		break;
1781	case MSG_XFRM_CLEANUP:
1782		if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1783			printk("failed to cleanup xfrm");
1784		}
1785		break;
1786	default:
1787		printk("got unknown msg type %d", msg->type);
1788	};
1789}
1790
1791static int grand_child_f(unsigned int nr, int cmd_fd, void *buf)
1792{
1793	struct test_desc msg;
1794	int xfrm_sock = -1;
1795	uint32_t seq;
1796
1797	if (switch_ns(nsfd_childb))
1798		exit(KSFT_FAIL);
1799
1800	if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1801		printk("Failed to open xfrm netlink socket");
1802		exit(KSFT_FAIL);
1803	}
1804
1805	do {
1806		read_msg(cmd_fd, &msg, 1);
1807		grand_child_serv(nr, cmd_fd, buf, &msg, xfrm_sock, &seq);
1808	} while (1);
1809
1810	close(xfrm_sock);
1811	exit(KSFT_FAIL);
1812}
1813
/*
 * start_child - configure test pair @nr and fork its two worker processes.
 * @nr:			pair index; selects addresses via child_ip(nr)/grchild_ip(nr)
 * @veth:		name of the veth device created for this pair
 * @test_desc_fd:	test-description pipe; the child keeps only the read end
 *
 * Resulting process tree:
 *
 *	selftest (returns from here after switching back to nsfd_parent)
 *	  `- child      -> child_f(), talks to the grandchild over a
 *	       |           SOCK_SEQPACKET socketpair
 *	       `- grandchild -> grand_child_f()
 *
 * The child and grandchild share one MAP_SHARED anonymous page, filled with
 * random data and passed down as the @buf argument of child_f()/grand_child_f().
 *
 * In the selftest process the return value is that of switch_ns(nsfd_parent).
 * child_f()/grand_child_f() exit rather than return.
 * NOTE(review): every "return -1" after the first fork() runs in a forked
 * process, so it returns into the caller's loop in that process, not in the
 * selftest parent.
 */
static int start_child(unsigned int nr, char *veth, int test_desc_fd[2])
{
	int cmd_sock[2];
	void *data_map;
	pid_t child;

	/* Configure this pair's veth end in each child namespace (see init_child()). */
	if (init_child(nsfd_childa, veth, child_ip(nr), grchild_ip(nr)))
		return -1;

	if (init_child(nsfd_childb, veth, grchild_ip(nr), child_ip(nr)))
		return -1;

	child = fork();
	if (child < 0) {
		pr_err("fork()");
		return -1;
	} else if (child) {
		/* in parent - selftest; presumably init_child() left us in a child ns */
		return switch_ns(nsfd_parent);
	}

	/*
	 * Child: only reads test descriptions. Dropping the write end lets it
	 * see EOF once every writer has closed the pipe.
	 */
	if (close(test_desc_fd[1])) {
		pr_err("close()");
		return -1;
	}

	/* child */
	data_map = mmap(0, page_size, PROT_READ | PROT_WRITE,
			MAP_SHARED | MAP_ANONYMOUS, -1, 0);
	if (data_map == MAP_FAILED) {
		pr_err("mmap()");
		return -1;
	}

	randomize_buffer(data_map, page_size);

	/* Command channel: cmd_sock[1] for the child, cmd_sock[0] for the grandchild. */
	if (socketpair(PF_LOCAL, SOCK_SEQPACKET, 0, cmd_sock)) {
		pr_err("socketpair()");
		return -1;
	}

	child = fork();
	if (child < 0) {
		pr_err("fork()");
		return -1;
	} else if (child) {
		/* Still the child (parent of the grandchild). */
		if (close(cmd_sock[0])) {
			pr_err("close()");
			return -1;
		}
		return child_f(nr, test_desc_fd[0], cmd_sock[1], data_map);
	}
	/* Grandchild. */
	if (close(cmd_sock[1])) {
		pr_err("close()");
		return -1;
	}
	return grand_child_f(nr, cmd_sock[0], data_map);
}
1872
1873static void exit_usage(char **argv)
1874{
1875	printk("Usage: %s [nr_process]", argv[0]);
1876	exit(KSFT_FAIL);
1877}
1878
1879static int __write_desc(int test_desc_fd, struct xfrm_desc *desc)
1880{
1881	ssize_t ret;
1882
1883	ret = write(test_desc_fd, desc, sizeof(*desc));
1884
1885	if (ret == sizeof(*desc))
1886		return 0;
1887
1888	pr_err("Writing test's desc failed %ld", ret);
1889
1890	return -1;
1891}
1892
1893static int write_desc(int proto, int test_desc_fd,
1894		char *a, char *e, char *c, char *ae)
1895{
1896	struct xfrm_desc desc = {};
1897
1898	desc.type = CREATE_TUNNEL;
1899	desc.proto = proto;
1900
1901	if (a)
1902		strncpy(desc.a_algo, a, ALGO_LEN - 1);
1903	if (e)
1904		strncpy(desc.e_algo, e, ALGO_LEN - 1);
1905	if (c)
1906		strncpy(desc.c_algo, c, ALGO_LEN - 1);
1907	if (ae)
1908		strncpy(desc.ae_algo, ae, ALGO_LEN - 1);
1909
1910	return __write_desc(test_desc_fd, &desc);
1911}
1912
1913int proto_list[] = { IPPROTO_AH, IPPROTO_COMP, IPPROTO_ESP };
1914char *ah_list[] = {
1915	"digest_null", "hmac(md5)", "hmac(sha1)", "hmac(sha256)",
1916	"hmac(sha384)", "hmac(sha512)", "hmac(rmd160)",
1917	"xcbc(aes)", "cmac(aes)"
1918};
1919char *comp_list[] = {
1920	"deflate",
1921#if 0
1922	/* No compression backend realization */
1923	"lzs", "lzjh"
1924#endif
1925};
1926char *e_list[] = {
1927	"ecb(cipher_null)", "cbc(des)", "cbc(des3_ede)", "cbc(cast5)",
1928	"cbc(blowfish)", "cbc(aes)", "cbc(serpent)", "cbc(camellia)",
1929	"cbc(twofish)", "rfc3686(ctr(aes))"
1930};
1931char *ae_list[] = {
1932#if 0
1933	/* not implemented */
1934	"rfc4106(gcm(aes))", "rfc4309(ccm(aes))", "rfc4543(gcm(aes))",
1935	"rfc7539esp(chacha20,poly1305)"
1936#endif
1937};
1938
1939const unsigned int proto_plan = ARRAY_SIZE(ah_list) + ARRAY_SIZE(comp_list) \
1940				+ (ARRAY_SIZE(ah_list) * ARRAY_SIZE(e_list)) \
1941				+ ARRAY_SIZE(ae_list);
1942
1943static int write_proto_plan(int fd, int proto)
1944{
1945	unsigned int i;
1946
1947	switch (proto) {
1948	case IPPROTO_AH:
1949		for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
1950			if (write_desc(proto, fd, ah_list[i], 0, 0, 0))
1951				return -1;
1952		}
1953		break;
1954	case IPPROTO_COMP:
1955		for (i = 0; i < ARRAY_SIZE(comp_list); i++) {
1956			if (write_desc(proto, fd, 0, 0, comp_list[i], 0))
1957				return -1;
1958		}
1959		break;
1960	case IPPROTO_ESP:
1961		for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
1962			int j;
1963
1964			for (j = 0; j < ARRAY_SIZE(e_list); j++) {
1965				if (write_desc(proto, fd, ah_list[i],
1966							e_list[j], 0, 0))
1967					return -1;
1968			}
1969		}
1970		for (i = 0; i < ARRAY_SIZE(ae_list); i++) {
1971			if (write_desc(proto, fd, 0, 0, 0, ae_list[i]))
1972				return -1;
1973		}
1974		break;
1975	default:
1976		printk("BUG: Specified unknown proto %d", proto);
1977		return -1;
1978	}
1979
1980	return 0;
1981}
1982
1983/*
1984 * Some structures in xfrm uapi header differ in size between
1985 * 64-bit and 32-bit ABI:
1986 *
1987 *             32-bit UABI               |            64-bit UABI
1988 *  -------------------------------------|-------------------------------------
1989 *   sizeof(xfrm_usersa_info)     = 220  |  sizeof(xfrm_usersa_info)     = 224
1990 *   sizeof(xfrm_userpolicy_info) = 164  |  sizeof(xfrm_userpolicy_info) = 168
1991 *   sizeof(xfrm_userspi_info)    = 228  |  sizeof(xfrm_userspi_info)    = 232
1992 *   sizeof(xfrm_user_acquire)    = 276  |  sizeof(xfrm_user_acquire)    = 280
1993 *   sizeof(xfrm_user_expire)     = 224  |  sizeof(xfrm_user_expire)     = 232
1994 *   sizeof(xfrm_user_polexpire)  = 168  |  sizeof(xfrm_user_polexpire)  = 176
1995 *
1996 * Check the affected by the UABI difference structures.
1997 */
1998const unsigned int compat_plan = 4;
1999static int write_compat_struct_tests(int test_desc_fd)
2000{
2001	struct xfrm_desc desc = {};
2002
2003	desc.type = ALLOCATE_SPI;
2004	desc.proto = IPPROTO_AH;
2005	strncpy(desc.a_algo, ah_list[0], ALGO_LEN - 1);
2006
2007	if (__write_desc(test_desc_fd, &desc))
2008		return -1;
2009
2010	desc.type = MONITOR_ACQUIRE;
2011	if (__write_desc(test_desc_fd, &desc))
2012		return -1;
2013
2014	desc.type = EXPIRE_STATE;
2015	if (__write_desc(test_desc_fd, &desc))
2016		return -1;
2017
2018	desc.type = EXPIRE_POLICY;
2019	if (__write_desc(test_desc_fd, &desc))
2020		return -1;
2021
2022	return 0;
2023}
2024
2025static int write_test_plan(int test_desc_fd)
2026{
2027	unsigned int i;
2028	pid_t child;
2029
2030	child = fork();
2031	if (child < 0) {
2032		pr_err("fork()");
2033		return -1;
2034	}
2035	if (child) {
2036		if (close(test_desc_fd))
2037			printk("close(): %m");
2038		return 0;
2039	}
2040
2041	if (write_compat_struct_tests(test_desc_fd))
2042		exit(KSFT_FAIL);
2043
2044	for (i = 0; i < ARRAY_SIZE(proto_list); i++) {
2045		if (write_proto_plan(test_desc_fd, proto_list[i]))
2046			exit(KSFT_FAIL);
2047	}
2048
2049	exit(KSFT_PASS);
2050}
2051
2052static int children_cleanup(void)
2053{
2054	unsigned ret = KSFT_PASS;
2055
2056	while (1) {
2057		int status;
2058		pid_t p = wait(&status);
2059
2060		if ((p < 0) && errno == ECHILD)
2061			break;
2062
2063		if (p < 0) {
2064			pr_err("wait()");
2065			return KSFT_FAIL;
2066		}
2067
2068		if (!WIFEXITED(status)) {
2069			ret = KSFT_FAIL;
2070			continue;
2071		}
2072
2073		if (WEXITSTATUS(status) == KSFT_FAIL)
2074			ret = KSFT_FAIL;
2075	}
2076
2077	return ret;
2078}
2079
2080typedef void (*print_res)(const char *, ...);
2081
2082static int check_results(void)
2083{
2084	struct test_result tr = {};
2085	struct xfrm_desc *d = &tr.desc;
2086	int ret = KSFT_PASS;
2087
2088	while (1) {
2089		ssize_t received = read(results_fd[0], &tr, sizeof(tr));
2090		print_res result;
2091
2092		if (received == 0) /* EOF */
2093			break;
2094
2095		if (received != sizeof(tr)) {
2096			pr_err("read() returned %zd", received);
2097			return KSFT_FAIL;
2098		}
2099
2100		switch (tr.res) {
2101		case KSFT_PASS:
2102			result = ksft_test_result_pass;
2103			break;
2104		case KSFT_FAIL:
2105		default:
2106			result = ksft_test_result_fail;
2107			ret = KSFT_FAIL;
2108		}
2109
2110		result(" %s: [%u, '%s', '%s', '%s', '%s', %u]\n",
2111		       desc_name[d->type], (unsigned int)d->proto, d->a_algo,
2112		       d->e_algo, d->c_algo, d->ae_algo, d->icv_len);
2113	}
2114
2115	return ret;
2116}
2117
/*
 * Selftest entry point.
 *
 * Usage: <prog> [nr_process], where nr_process (1..MAX_PROCESSES, default 1)
 * is the number of child/grandchild pairs. Each pair gets its own veth
 * device with one end in nsfd_childa and the other in nsfd_childb.
 *
 * Flow:
 * 1. Parse the arguments.
 * 2. Create the test-description and result pipes, the namespaces and a
 *    route netlink socket.
 * 3. For every pair, create the veth device and fork the child and
 *    grandchild (start_child()).
 * 4. Close the unused pipe ends and announce the plan.
 * 5. Stream the test descriptions (write_test_plan()).
 * 6. Print the results (check_results()), then reap all children.
 */
int main(int argc, char **argv)
{
	long nr_process = 1;
	int route_sock = -1, ret = KSFT_SKIP;
	int test_desc_fd[2];
	uint32_t route_seq;
	unsigned int i;

	if (argc > 2)
		exit_usage(argv);

	if (argc > 1) {
		char *endptr;

		/* Reject overflow, empty input and trailing garbage. */
		errno = 0;
		nr_process = strtol(argv[1], &endptr, 10);
		if ((errno == ERANGE && (nr_process == LONG_MAX || nr_process == LONG_MIN))
				|| (errno != 0 && nr_process == 0)
				|| (endptr == argv[1]) || (*endptr != '\0')) {
			printk("Failed to parse [nr_process]");
			exit_usage(argv);
		}

		if (nr_process > MAX_PROCESSES || nr_process < 1) {
			printk("nr_process should be between [1; %u]",
					MAX_PROCESSES);
			exit_usage(argv);
		}
	}

	srand(time(NULL));
	page_size = sysconf(_SC_PAGESIZE);
	if (page_size < 1)
		ksft_exit_skip("sysconf(): %m\n");

	/*
	 * O_DIRECT makes the pipes packet-oriented: each write() of a struct
	 * is read back as one record.
	 */
	if (pipe2(test_desc_fd, O_DIRECT) < 0)
		ksft_exit_skip("pipe(): %m\n");

	if (pipe2(results_fd, O_DIRECT) < 0)
		ksft_exit_skip("pipe(): %m\n");

	if (init_namespaces())
		ksft_exit_skip("Failed to create namespaces\n");

	if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE))
		ksft_exit_skip("Failed to open netlink route socket\n");

	for (i = 0; i < nr_process; i++) {
		char veth[VETH_LEN];

		/* NOTE(review): VETH_FMT uses %d with unsigned i; fine since i < MAX_PROCESSES. */
		snprintf(veth, VETH_LEN, VETH_FMT, i);

		if (veth_add(route_sock, route_seq++, veth, nsfd_childa, veth, nsfd_childb)) {
			close(route_sock);
			ksft_exit_fail_msg("Failed to create veth device");
		}

		/*
		 * start_child() forks. If it fails inside a forked process, that
		 * process also lands here and exits via ksft_exit_fail_msg().
		 */
		if (start_child(i, veth, test_desc_fd)) {
			close(route_sock);
			ksft_exit_fail_msg("Child %u failed to start", i);
		}
	}

	/*
	 * Drop the parent's unused ends. The children then hold the only read
	 * ends of the description pipe and the only write ends of the results
	 * pipe, so EOF propagates once the writers are done.
	 */
	if (close(route_sock) || close(test_desc_fd[0]) || close(results_fd[1]))
		ksft_exit_fail_msg("close(): %m");

	/* Must equal the number of descriptions write_test_plan() emits. */
	ksft_set_plan(proto_plan + compat_plan);

	if (write_test_plan(test_desc_fd[1]))
		ksft_exit_fail_msg("Failed to write test plan to pipe");

	ret = check_results();

	if (children_cleanup() == KSFT_FAIL)
		exit(KSFT_FAIL);

	exit(ret);
}
2196