1// SPDX-License-Identifier: GPL-2.0
2// Copyright (c) 2018 Facebook
3
4#define _GNU_SOURCE
5
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/types.h>
#include <sys/select.h>
#include <sys/socket.h>

#include <linux/filter.h>

#include <bpf/bpf.h>
#include <bpf/libbpf.h>

#include "cgroup_helpers.h"
#include "bpf_rlimit.h"
#include "bpf_util.h"
24
/* ENOTSUPP may be missing from userspace headers; fall back to 524. */
#ifndef ENOTSUPP
# define ENOTSUPP 524
#endif

#define CG_PATH	"/foo"

/* Object files holding the BPF programs written in C. */
#define CONNECT4_PROG_PATH	"./connect4_prog.o"
#define CONNECT6_PROG_PATH	"./connect6_prog.o"
#define SENDMSG4_PROG_PATH	"./sendmsg4_prog.o"
#define SENDMSG6_PROG_PATH	"./sendmsg6_prog.o"

/*
 * IPv4 endpoints. SERV4_IP/SERV4_PORT are what the test asks for; the
 * *_REWRITE_* values are what the BPF programs substitute.
 */
#define SERV4_IP		"192.168.1.254"
#define SERV4_REWRITE_IP	"127.0.0.1"
#define SRC4_IP			"172.16.0.1"
#define SRC4_REWRITE_IP		"127.0.0.4"
#define SERV4_PORT		4040
#define SERV4_REWRITE_PORT	4444

/* IPv6 endpoints, same naming scheme as the IPv4 ones above. */
#define SERV6_IP		"face:b00c:1234:5678::abcd"
#define SERV6_REWRITE_IP	"::1"
#define SERV6_V4MAPPED_IP	"::ffff:192.168.0.4"
#define SRC6_IP			"::1"
#define SRC6_REWRITE_IP		"::6"
#define WILDCARD6_IP		"::"
#define SERV6_PORT		6060
#define SERV6_REWRITE_PORT	6666

/*
 * Buffer size for inet_ntop() output. INET6_ADDRSTRLEN is the POSIX size
 * guaranteed to hold any IPv6 (and therefore any IPv4) textual address,
 * instead of relying on a hand-counted magic number.
 */
#define INET_NTOP_BUF	INET6_ADDRSTRLEN
52
53struct sock_addr_test;
54
55typedef int (*load_fn)(const struct sock_addr_test *test);
56typedef int (*info_fn)(int, struct sockaddr *, socklen_t *);
57
58char bpf_log_buf[BPF_LOG_BUF_SIZE];
59
60struct sock_addr_test {
61	const char *descr;
62	/* BPF prog properties */
63	load_fn loadfn;
64	enum bpf_attach_type expected_attach_type;
65	enum bpf_attach_type attach_type;
66	/* Socket properties */
67	int domain;
68	int type;
69	/* IP:port pairs for BPF prog to override */
70	const char *requested_ip;
71	unsigned short requested_port;
72	const char *expected_ip;
73	unsigned short expected_port;
74	const char *expected_src_ip;
75	/* Expected test result */
76	enum {
77		LOAD_REJECT,
78		ATTACH_REJECT,
79		ATTACH_OKAY,
80		SYSCALL_EPERM,
81		SYSCALL_ENOTSUPP,
82		SUCCESS,
83	} expected_result;
84};
85
/* Loaders that assemble their programs by hand (BPF instruction arrays). */
static int bind4_prog_load(const struct sock_addr_test *test);
static int bind6_prog_load(const struct sock_addr_test *test);
static int sendmsg4_rw_asm_prog_load(const struct sock_addr_test *test);
static int sendmsg6_rw_asm_prog_load(const struct sock_addr_test *test);
static int sendmsg6_rw_v4mapped_prog_load(const struct sock_addr_test *test);
static int sendmsg6_rw_wildcard_prog_load(const struct sock_addr_test *test);
static int recvmsg4_rw_asm_prog_load(const struct sock_addr_test *test);
static int recvmsg6_rw_asm_prog_load(const struct sock_addr_test *test);

/* Loaders for programs that only return a fixed verdict. */
static int sendmsg_allow_prog_load(const struct sock_addr_test *test);
static int sendmsg_deny_prog_load(const struct sock_addr_test *test);
static int recvmsg_allow_prog_load(const struct sock_addr_test *test);
static int recvmsg_deny_prog_load(const struct sock_addr_test *test);

/* Loaders for programs compiled from C into object files. */
static int connect4_prog_load(const struct sock_addr_test *test);
static int connect6_prog_load(const struct sock_addr_test *test);
static int sendmsg4_rw_c_prog_load(const struct sock_addr_test *test);
static int sendmsg6_rw_c_prog_load(const struct sock_addr_test *test);
102
103static struct sock_addr_test tests[] = {
104	/* bind */
105	{
106		"bind4: load prog with wrong expected attach type",
107		bind4_prog_load,
108		BPF_CGROUP_INET6_BIND,
109		BPF_CGROUP_INET4_BIND,
110		AF_INET,
111		SOCK_STREAM,
112		NULL,
113		0,
114		NULL,
115		0,
116		NULL,
117		LOAD_REJECT,
118	},
119	{
120		"bind4: attach prog with wrong attach type",
121		bind4_prog_load,
122		BPF_CGROUP_INET4_BIND,
123		BPF_CGROUP_INET6_BIND,
124		AF_INET,
125		SOCK_STREAM,
126		NULL,
127		0,
128		NULL,
129		0,
130		NULL,
131		ATTACH_REJECT,
132	},
133	{
134		"bind4: rewrite IP & TCP port in",
135		bind4_prog_load,
136		BPF_CGROUP_INET4_BIND,
137		BPF_CGROUP_INET4_BIND,
138		AF_INET,
139		SOCK_STREAM,
140		SERV4_IP,
141		SERV4_PORT,
142		SERV4_REWRITE_IP,
143		SERV4_REWRITE_PORT,
144		NULL,
145		SUCCESS,
146	},
147	{
148		"bind4: rewrite IP & UDP port in",
149		bind4_prog_load,
150		BPF_CGROUP_INET4_BIND,
151		BPF_CGROUP_INET4_BIND,
152		AF_INET,
153		SOCK_DGRAM,
154		SERV4_IP,
155		SERV4_PORT,
156		SERV4_REWRITE_IP,
157		SERV4_REWRITE_PORT,
158		NULL,
159		SUCCESS,
160	},
161	{
162		"bind6: load prog with wrong expected attach type",
163		bind6_prog_load,
164		BPF_CGROUP_INET4_BIND,
165		BPF_CGROUP_INET6_BIND,
166		AF_INET6,
167		SOCK_STREAM,
168		NULL,
169		0,
170		NULL,
171		0,
172		NULL,
173		LOAD_REJECT,
174	},
175	{
176		"bind6: attach prog with wrong attach type",
177		bind6_prog_load,
178		BPF_CGROUP_INET6_BIND,
179		BPF_CGROUP_INET4_BIND,
180		AF_INET,
181		SOCK_STREAM,
182		NULL,
183		0,
184		NULL,
185		0,
186		NULL,
187		ATTACH_REJECT,
188	},
189	{
190		"bind6: rewrite IP & TCP port in",
191		bind6_prog_load,
192		BPF_CGROUP_INET6_BIND,
193		BPF_CGROUP_INET6_BIND,
194		AF_INET6,
195		SOCK_STREAM,
196		SERV6_IP,
197		SERV6_PORT,
198		SERV6_REWRITE_IP,
199		SERV6_REWRITE_PORT,
200		NULL,
201		SUCCESS,
202	},
203	{
204		"bind6: rewrite IP & UDP port in",
205		bind6_prog_load,
206		BPF_CGROUP_INET6_BIND,
207		BPF_CGROUP_INET6_BIND,
208		AF_INET6,
209		SOCK_DGRAM,
210		SERV6_IP,
211		SERV6_PORT,
212		SERV6_REWRITE_IP,
213		SERV6_REWRITE_PORT,
214		NULL,
215		SUCCESS,
216	},
217
218	/* connect */
219	{
220		"connect4: load prog with wrong expected attach type",
221		connect4_prog_load,
222		BPF_CGROUP_INET6_CONNECT,
223		BPF_CGROUP_INET4_CONNECT,
224		AF_INET,
225		SOCK_STREAM,
226		NULL,
227		0,
228		NULL,
229		0,
230		NULL,
231		LOAD_REJECT,
232	},
233	{
234		"connect4: attach prog with wrong attach type",
235		connect4_prog_load,
236		BPF_CGROUP_INET4_CONNECT,
237		BPF_CGROUP_INET6_CONNECT,
238		AF_INET,
239		SOCK_STREAM,
240		NULL,
241		0,
242		NULL,
243		0,
244		NULL,
245		ATTACH_REJECT,
246	},
247	{
248		"connect4: rewrite IP & TCP port",
249		connect4_prog_load,
250		BPF_CGROUP_INET4_CONNECT,
251		BPF_CGROUP_INET4_CONNECT,
252		AF_INET,
253		SOCK_STREAM,
254		SERV4_IP,
255		SERV4_PORT,
256		SERV4_REWRITE_IP,
257		SERV4_REWRITE_PORT,
258		SRC4_REWRITE_IP,
259		SUCCESS,
260	},
261	{
262		"connect4: rewrite IP & UDP port",
263		connect4_prog_load,
264		BPF_CGROUP_INET4_CONNECT,
265		BPF_CGROUP_INET4_CONNECT,
266		AF_INET,
267		SOCK_DGRAM,
268		SERV4_IP,
269		SERV4_PORT,
270		SERV4_REWRITE_IP,
271		SERV4_REWRITE_PORT,
272		SRC4_REWRITE_IP,
273		SUCCESS,
274	},
275	{
276		"connect6: load prog with wrong expected attach type",
277		connect6_prog_load,
278		BPF_CGROUP_INET4_CONNECT,
279		BPF_CGROUP_INET6_CONNECT,
280		AF_INET6,
281		SOCK_STREAM,
282		NULL,
283		0,
284		NULL,
285		0,
286		NULL,
287		LOAD_REJECT,
288	},
289	{
290		"connect6: attach prog with wrong attach type",
291		connect6_prog_load,
292		BPF_CGROUP_INET6_CONNECT,
293		BPF_CGROUP_INET4_CONNECT,
294		AF_INET,
295		SOCK_STREAM,
296		NULL,
297		0,
298		NULL,
299		0,
300		NULL,
301		ATTACH_REJECT,
302	},
303	{
304		"connect6: rewrite IP & TCP port",
305		connect6_prog_load,
306		BPF_CGROUP_INET6_CONNECT,
307		BPF_CGROUP_INET6_CONNECT,
308		AF_INET6,
309		SOCK_STREAM,
310		SERV6_IP,
311		SERV6_PORT,
312		SERV6_REWRITE_IP,
313		SERV6_REWRITE_PORT,
314		SRC6_REWRITE_IP,
315		SUCCESS,
316	},
317	{
318		"connect6: rewrite IP & UDP port",
319		connect6_prog_load,
320		BPF_CGROUP_INET6_CONNECT,
321		BPF_CGROUP_INET6_CONNECT,
322		AF_INET6,
323		SOCK_DGRAM,
324		SERV6_IP,
325		SERV6_PORT,
326		SERV6_REWRITE_IP,
327		SERV6_REWRITE_PORT,
328		SRC6_REWRITE_IP,
329		SUCCESS,
330	},
331
332	/* sendmsg */
333	{
334		"sendmsg4: load prog with wrong expected attach type",
335		sendmsg4_rw_asm_prog_load,
336		BPF_CGROUP_UDP6_SENDMSG,
337		BPF_CGROUP_UDP4_SENDMSG,
338		AF_INET,
339		SOCK_DGRAM,
340		NULL,
341		0,
342		NULL,
343		0,
344		NULL,
345		LOAD_REJECT,
346	},
347	{
348		"sendmsg4: attach prog with wrong attach type",
349		sendmsg4_rw_asm_prog_load,
350		BPF_CGROUP_UDP4_SENDMSG,
351		BPF_CGROUP_UDP6_SENDMSG,
352		AF_INET,
353		SOCK_DGRAM,
354		NULL,
355		0,
356		NULL,
357		0,
358		NULL,
359		ATTACH_REJECT,
360	},
361	{
362		"sendmsg4: rewrite IP & port (asm)",
363		sendmsg4_rw_asm_prog_load,
364		BPF_CGROUP_UDP4_SENDMSG,
365		BPF_CGROUP_UDP4_SENDMSG,
366		AF_INET,
367		SOCK_DGRAM,
368		SERV4_IP,
369		SERV4_PORT,
370		SERV4_REWRITE_IP,
371		SERV4_REWRITE_PORT,
372		SRC4_REWRITE_IP,
373		SUCCESS,
374	},
375	{
376		"sendmsg4: rewrite IP & port (C)",
377		sendmsg4_rw_c_prog_load,
378		BPF_CGROUP_UDP4_SENDMSG,
379		BPF_CGROUP_UDP4_SENDMSG,
380		AF_INET,
381		SOCK_DGRAM,
382		SERV4_IP,
383		SERV4_PORT,
384		SERV4_REWRITE_IP,
385		SERV4_REWRITE_PORT,
386		SRC4_REWRITE_IP,
387		SUCCESS,
388	},
389	{
390		"sendmsg4: deny call",
391		sendmsg_deny_prog_load,
392		BPF_CGROUP_UDP4_SENDMSG,
393		BPF_CGROUP_UDP4_SENDMSG,
394		AF_INET,
395		SOCK_DGRAM,
396		SERV4_IP,
397		SERV4_PORT,
398		SERV4_REWRITE_IP,
399		SERV4_REWRITE_PORT,
400		SRC4_REWRITE_IP,
401		SYSCALL_EPERM,
402	},
403	{
404		"sendmsg6: load prog with wrong expected attach type",
405		sendmsg6_rw_asm_prog_load,
406		BPF_CGROUP_UDP4_SENDMSG,
407		BPF_CGROUP_UDP6_SENDMSG,
408		AF_INET6,
409		SOCK_DGRAM,
410		NULL,
411		0,
412		NULL,
413		0,
414		NULL,
415		LOAD_REJECT,
416	},
417	{
418		"sendmsg6: attach prog with wrong attach type",
419		sendmsg6_rw_asm_prog_load,
420		BPF_CGROUP_UDP6_SENDMSG,
421		BPF_CGROUP_UDP4_SENDMSG,
422		AF_INET6,
423		SOCK_DGRAM,
424		NULL,
425		0,
426		NULL,
427		0,
428		NULL,
429		ATTACH_REJECT,
430	},
431	{
432		"sendmsg6: rewrite IP & port (asm)",
433		sendmsg6_rw_asm_prog_load,
434		BPF_CGROUP_UDP6_SENDMSG,
435		BPF_CGROUP_UDP6_SENDMSG,
436		AF_INET6,
437		SOCK_DGRAM,
438		SERV6_IP,
439		SERV6_PORT,
440		SERV6_REWRITE_IP,
441		SERV6_REWRITE_PORT,
442		SRC6_REWRITE_IP,
443		SUCCESS,
444	},
445	{
446		"sendmsg6: rewrite IP & port (C)",
447		sendmsg6_rw_c_prog_load,
448		BPF_CGROUP_UDP6_SENDMSG,
449		BPF_CGROUP_UDP6_SENDMSG,
450		AF_INET6,
451		SOCK_DGRAM,
452		SERV6_IP,
453		SERV6_PORT,
454		SERV6_REWRITE_IP,
455		SERV6_REWRITE_PORT,
456		SRC6_REWRITE_IP,
457		SUCCESS,
458	},
459	{
460		"sendmsg6: IPv4-mapped IPv6",
461		sendmsg6_rw_v4mapped_prog_load,
462		BPF_CGROUP_UDP6_SENDMSG,
463		BPF_CGROUP_UDP6_SENDMSG,
464		AF_INET6,
465		SOCK_DGRAM,
466		SERV6_IP,
467		SERV6_PORT,
468		SERV6_REWRITE_IP,
469		SERV6_REWRITE_PORT,
470		SRC6_REWRITE_IP,
471		SYSCALL_ENOTSUPP,
472	},
473	{
474		"sendmsg6: set dst IP = [::] (BSD'ism)",
475		sendmsg6_rw_wildcard_prog_load,
476		BPF_CGROUP_UDP6_SENDMSG,
477		BPF_CGROUP_UDP6_SENDMSG,
478		AF_INET6,
479		SOCK_DGRAM,
480		SERV6_IP,
481		SERV6_PORT,
482		SERV6_REWRITE_IP,
483		SERV6_REWRITE_PORT,
484		SRC6_REWRITE_IP,
485		SUCCESS,
486	},
487	{
488		"sendmsg6: preserve dst IP = [::] (BSD'ism)",
489		sendmsg_allow_prog_load,
490		BPF_CGROUP_UDP6_SENDMSG,
491		BPF_CGROUP_UDP6_SENDMSG,
492		AF_INET6,
493		SOCK_DGRAM,
494		WILDCARD6_IP,
495		SERV6_PORT,
496		SERV6_REWRITE_IP,
497		SERV6_PORT,
498		SRC6_IP,
499		SUCCESS,
500	},
501	{
502		"sendmsg6: deny call",
503		sendmsg_deny_prog_load,
504		BPF_CGROUP_UDP6_SENDMSG,
505		BPF_CGROUP_UDP6_SENDMSG,
506		AF_INET6,
507		SOCK_DGRAM,
508		SERV6_IP,
509		SERV6_PORT,
510		SERV6_REWRITE_IP,
511		SERV6_REWRITE_PORT,
512		SRC6_REWRITE_IP,
513		SYSCALL_EPERM,
514	},
515
516	/* recvmsg */
517	{
518		"recvmsg4: return code ok",
519		recvmsg_allow_prog_load,
520		BPF_CGROUP_UDP4_RECVMSG,
521		BPF_CGROUP_UDP4_RECVMSG,
522		AF_INET,
523		SOCK_DGRAM,
524		NULL,
525		0,
526		NULL,
527		0,
528		NULL,
529		ATTACH_OKAY,
530	},
531	{
532		"recvmsg4: return code !ok",
533		recvmsg_deny_prog_load,
534		BPF_CGROUP_UDP4_RECVMSG,
535		BPF_CGROUP_UDP4_RECVMSG,
536		AF_INET,
537		SOCK_DGRAM,
538		NULL,
539		0,
540		NULL,
541		0,
542		NULL,
543		LOAD_REJECT,
544	},
545	{
546		"recvmsg6: return code ok",
547		recvmsg_allow_prog_load,
548		BPF_CGROUP_UDP6_RECVMSG,
549		BPF_CGROUP_UDP6_RECVMSG,
550		AF_INET6,
551		SOCK_DGRAM,
552		NULL,
553		0,
554		NULL,
555		0,
556		NULL,
557		ATTACH_OKAY,
558	},
559	{
560		"recvmsg6: return code !ok",
561		recvmsg_deny_prog_load,
562		BPF_CGROUP_UDP6_RECVMSG,
563		BPF_CGROUP_UDP6_RECVMSG,
564		AF_INET6,
565		SOCK_DGRAM,
566		NULL,
567		0,
568		NULL,
569		0,
570		NULL,
571		LOAD_REJECT,
572	},
573	{
574		"recvmsg4: rewrite IP & port (asm)",
575		recvmsg4_rw_asm_prog_load,
576		BPF_CGROUP_UDP4_RECVMSG,
577		BPF_CGROUP_UDP4_RECVMSG,
578		AF_INET,
579		SOCK_DGRAM,
580		SERV4_REWRITE_IP,
581		SERV4_REWRITE_PORT,
582		SERV4_REWRITE_IP,
583		SERV4_REWRITE_PORT,
584		SERV4_IP,
585		SUCCESS,
586	},
587	{
588		"recvmsg6: rewrite IP & port (asm)",
589		recvmsg6_rw_asm_prog_load,
590		BPF_CGROUP_UDP6_RECVMSG,
591		BPF_CGROUP_UDP6_RECVMSG,
592		AF_INET6,
593		SOCK_DGRAM,
594		SERV6_REWRITE_IP,
595		SERV6_REWRITE_PORT,
596		SERV6_REWRITE_IP,
597		SERV6_REWRITE_PORT,
598		SERV6_IP,
599		SUCCESS,
600	},
601};
602
/*
 * Fill @addr (a buffer of @addr_len bytes) with a sockaddr_in or
 * sockaddr_in6 for @ip:@port in family @domain. The port is stored in
 * network byte order. Returns 0 on success, -1 on any error (every error
 * path is logged).
 */
static int mk_sockaddr(int domain, const char *ip, unsigned short port,
		       struct sockaddr *addr, socklen_t addr_len)
{
	struct sockaddr_in6 *addr6;
	struct sockaddr_in *addr4;

	if (domain != AF_INET && domain != AF_INET6) {
		log_err("Unsupported address family");
		return -1;
	}

	/*
	 * Some entries of the test table carry a NULL IP; never let one
	 * reach inet_pton(), which would dereference it.
	 */
	if (!ip) {
		log_err("Missing IP address");
		return -1;
	}

	memset(addr, 0, addr_len);

	if (domain == AF_INET) {
		if (addr_len < sizeof(struct sockaddr_in)) {
			log_err("Address buffer too small for IPv4");
			return -1;
		}
		addr4 = (struct sockaddr_in *)addr;
		addr4->sin_family = domain;
		addr4->sin_port = htons(port);
		if (inet_pton(domain, ip, (void *)&addr4->sin_addr) != 1) {
			log_err("Invalid IPv4: %s", ip);
			return -1;
		}
	} else {
		if (addr_len < sizeof(struct sockaddr_in6)) {
			log_err("Address buffer too small for IPv6");
			return -1;
		}
		addr6 = (struct sockaddr_in6 *)addr;
		addr6->sin6_family = domain;
		addr6->sin6_port = htons(port);
		if (inet_pton(domain, ip, (void *)&addr6->sin6_addr) != 1) {
			log_err("Invalid IPv6: %s", ip);
			return -1;
		}
	}

	return 0;
}
640
641static int load_insns(const struct sock_addr_test *test,
642		      const struct bpf_insn *insns, size_t insns_cnt)
643{
644	struct bpf_load_program_attr load_attr;
645	int ret;
646
647	memset(&load_attr, 0, sizeof(struct bpf_load_program_attr));
648	load_attr.prog_type = BPF_PROG_TYPE_CGROUP_SOCK_ADDR;
649	load_attr.expected_attach_type = test->expected_attach_type;
650	load_attr.insns = insns;
651	load_attr.insns_cnt = insns_cnt;
652	load_attr.license = "GPL";
653
654	ret = bpf_load_program_xattr(&load_attr, bpf_log_buf, BPF_LOG_BUF_SIZE);
655	if (ret < 0 && test->expected_result != LOAD_REJECT) {
656		log_err(">>> Loading program error.\n"
657			">>> Verifier output:\n%s\n-------\n", bpf_log_buf);
658	}
659
660	return ret;
661}
662
663/* [1] These testing programs try to read different context fields, including
664 * narrow loads of different sizes from user_ip4 and user_ip6, and write to
665 * those allowed to be overridden.
666 *
 * [2] BPF_LD_IMM64 & BPF_JMP_REG are used below whenever there is a need to
 * compare a register with an unsigned 32-bit integer. BPF_JMP_IMM can't be
 * used in such cases since it accepts only a _signed_ 32-bit integer as its
 * IMM argument. Also note that BPF_LD_IMM64 occupies 2 instruction slots,
 * which matters when counting jump offsets.
672 */
673
674static int bind4_prog_load(const struct sock_addr_test *test)
675{
676	union {
677		uint8_t u4_addr8[4];
678		uint16_t u4_addr16[2];
679		uint32_t u4_addr32;
680	} ip4, port;
681	struct sockaddr_in addr4_rw;
682
683	if (inet_pton(AF_INET, SERV4_IP, (void *)&ip4) != 1) {
684		log_err("Invalid IPv4: %s", SERV4_IP);
685		return -1;
686	}
687
688	port.u4_addr32 = htons(SERV4_PORT);
689
690	if (mk_sockaddr(AF_INET, SERV4_REWRITE_IP, SERV4_REWRITE_PORT,
691			(struct sockaddr *)&addr4_rw, sizeof(addr4_rw)) == -1)
692		return -1;
693
694	/* See [1]. */
695	struct bpf_insn insns[] = {
696		BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
697
698		/* if (sk.family == AF_INET && */
699		BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
700			    offsetof(struct bpf_sock_addr, family)),
701		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, AF_INET, 32),
702
703		/*     (sk.type == SOCK_DGRAM || sk.type == SOCK_STREAM) && */
704		BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
705			    offsetof(struct bpf_sock_addr, type)),
706		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, SOCK_DGRAM, 1),
707		BPF_JMP_A(1),
708		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, SOCK_STREAM, 28),
709
710		/*     1st_byte_of_user_ip4 == expected && */
711		BPF_LDX_MEM(BPF_B, BPF_REG_7, BPF_REG_6,
712			    offsetof(struct bpf_sock_addr, user_ip4)),
713		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip4.u4_addr8[0], 26),
714
715		/*     2nd_byte_of_user_ip4 == expected && */
716		BPF_LDX_MEM(BPF_B, BPF_REG_7, BPF_REG_6,
717			    offsetof(struct bpf_sock_addr, user_ip4) + 1),
718		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip4.u4_addr8[1], 24),
719
720		/*     3rd_byte_of_user_ip4 == expected && */
721		BPF_LDX_MEM(BPF_B, BPF_REG_7, BPF_REG_6,
722			    offsetof(struct bpf_sock_addr, user_ip4) + 2),
723		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip4.u4_addr8[2], 22),
724
725		/*     4th_byte_of_user_ip4 == expected && */
726		BPF_LDX_MEM(BPF_B, BPF_REG_7, BPF_REG_6,
727			    offsetof(struct bpf_sock_addr, user_ip4) + 3),
728		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip4.u4_addr8[3], 20),
729
730		/*     1st_half_of_user_ip4 == expected && */
731		BPF_LDX_MEM(BPF_H, BPF_REG_7, BPF_REG_6,
732			    offsetof(struct bpf_sock_addr, user_ip4)),
733		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip4.u4_addr16[0], 18),
734
735		/*     2nd_half_of_user_ip4 == expected && */
736		BPF_LDX_MEM(BPF_H, BPF_REG_7, BPF_REG_6,
737			    offsetof(struct bpf_sock_addr, user_ip4) + 2),
738		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip4.u4_addr16[1], 16),
739
740		/*     whole_user_ip4 == expected && */
741		BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
742			    offsetof(struct bpf_sock_addr, user_ip4)),
743		BPF_LD_IMM64(BPF_REG_8, ip4.u4_addr32), /* See [2]. */
744		BPF_JMP_REG(BPF_JNE, BPF_REG_7, BPF_REG_8, 12),
745
746		/*     1st_byte_of_user_port == expected && */
747		BPF_LDX_MEM(BPF_B, BPF_REG_7, BPF_REG_6,
748			    offsetof(struct bpf_sock_addr, user_port)),
749		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, port.u4_addr8[0], 10),
750
751		/*     1st_half_of_user_port == expected && */
752		BPF_LDX_MEM(BPF_H, BPF_REG_7, BPF_REG_6,
753			    offsetof(struct bpf_sock_addr, user_port)),
754		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, port.u4_addr16[0], 8),
755
756		/*     user_port == expected) { */
757		BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
758			    offsetof(struct bpf_sock_addr, user_port)),
759		BPF_LD_IMM64(BPF_REG_8, port.u4_addr32), /* See [2]. */
760		BPF_JMP_REG(BPF_JNE, BPF_REG_7, BPF_REG_8, 4),
761
762		/*      user_ip4 = addr4_rw.sin_addr */
763		BPF_MOV32_IMM(BPF_REG_7, addr4_rw.sin_addr.s_addr),
764		BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
765			    offsetof(struct bpf_sock_addr, user_ip4)),
766
767		/*      user_port = addr4_rw.sin_port */
768		BPF_MOV32_IMM(BPF_REG_7, addr4_rw.sin_port),
769		BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
770			    offsetof(struct bpf_sock_addr, user_port)),
771		/* } */
772
773		/* return 1 */
774		BPF_MOV64_IMM(BPF_REG_0, 1),
775		BPF_EXIT_INSN(),
776	};
777
778	return load_insns(test, insns, sizeof(insns) / sizeof(struct bpf_insn));
779}
780
781static int bind6_prog_load(const struct sock_addr_test *test)
782{
783	struct sockaddr_in6 addr6_rw;
784	struct in6_addr ip6;
785
786	if (inet_pton(AF_INET6, SERV6_IP, (void *)&ip6) != 1) {
787		log_err("Invalid IPv6: %s", SERV6_IP);
788		return -1;
789	}
790
791	if (mk_sockaddr(AF_INET6, SERV6_REWRITE_IP, SERV6_REWRITE_PORT,
792			(struct sockaddr *)&addr6_rw, sizeof(addr6_rw)) == -1)
793		return -1;
794
795	/* See [1]. */
796	struct bpf_insn insns[] = {
797		BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
798
799		/* if (sk.family == AF_INET6 && */
800		BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
801			    offsetof(struct bpf_sock_addr, family)),
802		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, AF_INET6, 18),
803
804		/*            5th_byte_of_user_ip6 == expected && */
805		BPF_LDX_MEM(BPF_B, BPF_REG_7, BPF_REG_6,
806			    offsetof(struct bpf_sock_addr, user_ip6[1])),
807		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip6.s6_addr[4], 16),
808
809		/*            3rd_half_of_user_ip6 == expected && */
810		BPF_LDX_MEM(BPF_H, BPF_REG_7, BPF_REG_6,
811			    offsetof(struct bpf_sock_addr, user_ip6[1])),
812		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip6.s6_addr16[2], 14),
813
814		/*            last_word_of_user_ip6 == expected) { */
815		BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
816			    offsetof(struct bpf_sock_addr, user_ip6[3])),
817		BPF_LD_IMM64(BPF_REG_8, ip6.s6_addr32[3]),  /* See [2]. */
818		BPF_JMP_REG(BPF_JNE, BPF_REG_7, BPF_REG_8, 10),
819
820
821#define STORE_IPV6_WORD(N)						       \
822		BPF_MOV32_IMM(BPF_REG_7, addr6_rw.sin6_addr.s6_addr32[N]),     \
823		BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,		       \
824			    offsetof(struct bpf_sock_addr, user_ip6[N]))
825
826		/*      user_ip6 = addr6_rw.sin6_addr */
827		STORE_IPV6_WORD(0),
828		STORE_IPV6_WORD(1),
829		STORE_IPV6_WORD(2),
830		STORE_IPV6_WORD(3),
831
832		/*      user_port = addr6_rw.sin6_port */
833		BPF_MOV32_IMM(BPF_REG_7, addr6_rw.sin6_port),
834		BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
835			    offsetof(struct bpf_sock_addr, user_port)),
836
837		/* } */
838
839		/* return 1 */
840		BPF_MOV64_IMM(BPF_REG_0, 1),
841		BPF_EXIT_INSN(),
842	};
843
844	return load_insns(test, insns, sizeof(insns) / sizeof(struct bpf_insn));
845}
846
847static int load_path(const struct sock_addr_test *test, const char *path)
848{
849	struct bpf_prog_load_attr attr;
850	struct bpf_object *obj;
851	int prog_fd;
852
853	memset(&attr, 0, sizeof(struct bpf_prog_load_attr));
854	attr.file = path;
855	attr.prog_type = BPF_PROG_TYPE_CGROUP_SOCK_ADDR;
856	attr.expected_attach_type = test->expected_attach_type;
857	attr.prog_flags = BPF_F_TEST_RND_HI32;
858
859	if (bpf_prog_load_xattr(&attr, &obj, &prog_fd)) {
860		if (test->expected_result != LOAD_REJECT)
861			log_err(">>> Loading program (%s) error.\n", path);
862		return -1;
863	}
864
865	return prog_fd;
866}
867
868static int connect4_prog_load(const struct sock_addr_test *test)
869{
870	return load_path(test, CONNECT4_PROG_PATH);
871}
872
873static int connect6_prog_load(const struct sock_addr_test *test)
874{
875	return load_path(test, CONNECT6_PROG_PATH);
876}
877
878static int xmsg_ret_only_prog_load(const struct sock_addr_test *test,
879				   int32_t rc)
880{
881	struct bpf_insn insns[] = {
882		/* return rc */
883		BPF_MOV64_IMM(BPF_REG_0, rc),
884		BPF_EXIT_INSN(),
885	};
886	return load_insns(test, insns, sizeof(insns) / sizeof(struct bpf_insn));
887}
888
/* sendmsg program that lets every call through (returns 1). */
static int sendmsg_allow_prog_load(const struct sock_addr_test *test)
{
	const int32_t verdict = 1;

	return xmsg_ret_only_prog_load(test, verdict);
}

/* sendmsg program that rejects every call (returns 0). */
static int sendmsg_deny_prog_load(const struct sock_addr_test *test)
{
	const int32_t verdict = 0;

	return xmsg_ret_only_prog_load(test, verdict);
}
898
/* recvmsg program returning 1. */
static int recvmsg_allow_prog_load(const struct sock_addr_test *test)
{
	const int32_t verdict = 1;

	return xmsg_ret_only_prog_load(test, verdict);
}

/* recvmsg program returning 0 (the test table expects its load to fail). */
static int recvmsg_deny_prog_load(const struct sock_addr_test *test)
{
	const int32_t verdict = 0;

	return xmsg_ret_only_prog_load(test, verdict);
}
908
909static int sendmsg4_rw_asm_prog_load(const struct sock_addr_test *test)
910{
911	struct sockaddr_in dst4_rw_addr;
912	struct in_addr src4_rw_ip;
913
914	if (inet_pton(AF_INET, SRC4_REWRITE_IP, (void *)&src4_rw_ip) != 1) {
915		log_err("Invalid IPv4: %s", SRC4_REWRITE_IP);
916		return -1;
917	}
918
919	if (mk_sockaddr(AF_INET, SERV4_REWRITE_IP, SERV4_REWRITE_PORT,
920			(struct sockaddr *)&dst4_rw_addr,
921			sizeof(dst4_rw_addr)) == -1)
922		return -1;
923
924	struct bpf_insn insns[] = {
925		BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
926
927		/* if (sk.family == AF_INET && */
928		BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
929			    offsetof(struct bpf_sock_addr, family)),
930		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, AF_INET, 8),
931
932		/*     sk.type == SOCK_DGRAM)  { */
933		BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
934			    offsetof(struct bpf_sock_addr, type)),
935		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, SOCK_DGRAM, 6),
936
937		/*      msg_src_ip4 = src4_rw_ip */
938		BPF_MOV32_IMM(BPF_REG_7, src4_rw_ip.s_addr),
939		BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
940			    offsetof(struct bpf_sock_addr, msg_src_ip4)),
941
942		/*      user_ip4 = dst4_rw_addr.sin_addr */
943		BPF_MOV32_IMM(BPF_REG_7, dst4_rw_addr.sin_addr.s_addr),
944		BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
945			    offsetof(struct bpf_sock_addr, user_ip4)),
946
947		/*      user_port = dst4_rw_addr.sin_port */
948		BPF_MOV32_IMM(BPF_REG_7, dst4_rw_addr.sin_port),
949		BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
950			    offsetof(struct bpf_sock_addr, user_port)),
951		/* } */
952
953		/* return 1 */
954		BPF_MOV64_IMM(BPF_REG_0, 1),
955		BPF_EXIT_INSN(),
956	};
957
958	return load_insns(test, insns, sizeof(insns) / sizeof(struct bpf_insn));
959}
960
961static int recvmsg4_rw_asm_prog_load(const struct sock_addr_test *test)
962{
963	struct sockaddr_in src4_rw_addr;
964
965	if (mk_sockaddr(AF_INET, SERV4_IP, SERV4_PORT,
966			(struct sockaddr *)&src4_rw_addr,
967			sizeof(src4_rw_addr)) == -1)
968		return -1;
969
970	struct bpf_insn insns[] = {
971		BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
972
973		/* if (sk.family == AF_INET && */
974		BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
975			    offsetof(struct bpf_sock_addr, family)),
976		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, AF_INET, 6),
977
978		/*     sk.type == SOCK_DGRAM)  { */
979		BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
980			    offsetof(struct bpf_sock_addr, type)),
981		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, SOCK_DGRAM, 4),
982
983		/*      user_ip4 = src4_rw_addr.sin_addr */
984		BPF_MOV32_IMM(BPF_REG_7, src4_rw_addr.sin_addr.s_addr),
985		BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
986			    offsetof(struct bpf_sock_addr, user_ip4)),
987
988		/*      user_port = src4_rw_addr.sin_port */
989		BPF_MOV32_IMM(BPF_REG_7, src4_rw_addr.sin_port),
990		BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
991			    offsetof(struct bpf_sock_addr, user_port)),
992		/* } */
993
994		/* return 1 */
995		BPF_MOV64_IMM(BPF_REG_0, 1),
996		BPF_EXIT_INSN(),
997	};
998
999	return load_insns(test, insns, sizeof(insns) / sizeof(struct bpf_insn));
1000}
1001
1002static int sendmsg4_rw_c_prog_load(const struct sock_addr_test *test)
1003{
1004	return load_path(test, SENDMSG4_PROG_PATH);
1005}
1006
1007static int sendmsg6_rw_dst_asm_prog_load(const struct sock_addr_test *test,
1008					 const char *rw_dst_ip)
1009{
1010	struct sockaddr_in6 dst6_rw_addr;
1011	struct in6_addr src6_rw_ip;
1012
1013	if (inet_pton(AF_INET6, SRC6_REWRITE_IP, (void *)&src6_rw_ip) != 1) {
1014		log_err("Invalid IPv6: %s", SRC6_REWRITE_IP);
1015		return -1;
1016	}
1017
1018	if (mk_sockaddr(AF_INET6, rw_dst_ip, SERV6_REWRITE_PORT,
1019			(struct sockaddr *)&dst6_rw_addr,
1020			sizeof(dst6_rw_addr)) == -1)
1021		return -1;
1022
1023	struct bpf_insn insns[] = {
1024		BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
1025
1026		/* if (sk.family == AF_INET6) { */
1027		BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
1028			    offsetof(struct bpf_sock_addr, family)),
1029		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, AF_INET6, 18),
1030
1031#define STORE_IPV6_WORD_N(DST, SRC, N)					       \
1032		BPF_MOV32_IMM(BPF_REG_7, SRC[N]),			       \
1033		BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,		       \
1034			    offsetof(struct bpf_sock_addr, DST[N]))
1035
1036#define STORE_IPV6(DST, SRC)						       \
1037		STORE_IPV6_WORD_N(DST, SRC, 0),				       \
1038		STORE_IPV6_WORD_N(DST, SRC, 1),				       \
1039		STORE_IPV6_WORD_N(DST, SRC, 2),				       \
1040		STORE_IPV6_WORD_N(DST, SRC, 3)
1041
1042		STORE_IPV6(msg_src_ip6, src6_rw_ip.s6_addr32),
1043		STORE_IPV6(user_ip6, dst6_rw_addr.sin6_addr.s6_addr32),
1044
1045		/*      user_port = dst6_rw_addr.sin6_port */
1046		BPF_MOV32_IMM(BPF_REG_7, dst6_rw_addr.sin6_port),
1047		BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
1048			    offsetof(struct bpf_sock_addr, user_port)),
1049
1050		/* } */
1051
1052		/* return 1 */
1053		BPF_MOV64_IMM(BPF_REG_0, 1),
1054		BPF_EXIT_INSN(),
1055	};
1056
1057	return load_insns(test, insns, sizeof(insns) / sizeof(struct bpf_insn));
1058}
1059
1060static int sendmsg6_rw_asm_prog_load(const struct sock_addr_test *test)
1061{
1062	return sendmsg6_rw_dst_asm_prog_load(test, SERV6_REWRITE_IP);
1063}
1064
1065static int recvmsg6_rw_asm_prog_load(const struct sock_addr_test *test)
1066{
1067	struct sockaddr_in6 src6_rw_addr;
1068
1069	if (mk_sockaddr(AF_INET6, SERV6_IP, SERV6_PORT,
1070			(struct sockaddr *)&src6_rw_addr,
1071			sizeof(src6_rw_addr)) == -1)
1072		return -1;
1073
1074	struct bpf_insn insns[] = {
1075		BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
1076
1077		/* if (sk.family == AF_INET6) { */
1078		BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
1079			    offsetof(struct bpf_sock_addr, family)),
1080		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, AF_INET6, 10),
1081
1082		STORE_IPV6(user_ip6, src6_rw_addr.sin6_addr.s6_addr32),
1083
1084		/*      user_port = dst6_rw_addr.sin6_port */
1085		BPF_MOV32_IMM(BPF_REG_7, src6_rw_addr.sin6_port),
1086		BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
1087			    offsetof(struct bpf_sock_addr, user_port)),
1088		/* } */
1089
1090		/* return 1 */
1091		BPF_MOV64_IMM(BPF_REG_0, 1),
1092		BPF_EXIT_INSN(),
1093	};
1094
1095	return load_insns(test, insns, sizeof(insns) / sizeof(struct bpf_insn));
1096}
1097
1098static int sendmsg6_rw_v4mapped_prog_load(const struct sock_addr_test *test)
1099{
1100	return sendmsg6_rw_dst_asm_prog_load(test, SERV6_V4MAPPED_IP);
1101}
1102
1103static int sendmsg6_rw_wildcard_prog_load(const struct sock_addr_test *test)
1104{
1105	return sendmsg6_rw_dst_asm_prog_load(test, WILDCARD6_IP);
1106}
1107
1108static int sendmsg6_rw_c_prog_load(const struct sock_addr_test *test)
1109{
1110	return load_path(test, SENDMSG6_PROG_PATH);
1111}
1112
/*
 * Compare two socket addresses of the same family.
 * Returns 0 when the IPs (and, if @cmp_port is non-zero, the ports) are
 * equal, 1 when they differ, and -1 when the families differ or are
 * neither AF_INET nor AF_INET6.
 */
static int cmp_addr(const struct sockaddr_storage *addr1,
		    const struct sockaddr_storage *addr2, int cmp_port)
{
	int port_ok, ip_ok;

	if (addr1->ss_family != addr2->ss_family)
		return -1;

	switch (addr1->ss_family) {
	case AF_INET: {
		const struct sockaddr_in *a = (const struct sockaddr_in *)addr1;
		const struct sockaddr_in *b = (const struct sockaddr_in *)addr2;

		port_ok = !cmp_port || a->sin_port == b->sin_port;
		ip_ok = a->sin_addr.s_addr == b->sin_addr.s_addr;
		break;
	}
	case AF_INET6: {
		const struct sockaddr_in6 *a = (const struct sockaddr_in6 *)addr1;
		const struct sockaddr_in6 *b = (const struct sockaddr_in6 *)addr2;

		port_ok = !cmp_port || a->sin6_port == b->sin6_port;
		ip_ok = memcmp(&a->sin6_addr, &b->sin6_addr,
			       sizeof(struct in6_addr)) == 0;
		break;
	}
	default:
		return -1;
	}

	return (port_ok && ip_ok) ? 0 : 1;
}
1137
1138static int cmp_sock_addr(info_fn fn, int sock1,
1139			 const struct sockaddr_storage *addr2, int cmp_port)
1140{
1141	struct sockaddr_storage addr1;
1142	socklen_t len1 = sizeof(addr1);
1143
1144	memset(&addr1, 0, len1);
1145	if (fn(sock1, (struct sockaddr *)&addr1, (socklen_t *)&len1) != 0)
1146		return -1;
1147
1148	return cmp_addr(&addr1, addr2, cmp_port);
1149}
1150
/* Compare only the local IP of @sock1 with @addr2 (port ignored). */
static int cmp_local_ip(int sock1, const struct sockaddr_storage *addr2)
{
	const int with_port = 0;

	return cmp_sock_addr(getsockname, sock1, addr2, with_port);
}

/* Compare the local IP and port of @sock1 with @addr2. */
static int cmp_local_addr(int sock1, const struct sockaddr_storage *addr2)
{
	const int with_port = 1;

	return cmp_sock_addr(getsockname, sock1, addr2, with_port);
}

/* Compare the peer IP and port of @sock1 with @addr2. */
static int cmp_peer_addr(int sock1, const struct sockaddr_storage *addr2)
{
	const int with_port = 1;

	return cmp_sock_addr(getpeername, sock1, addr2, with_port);
}
1165
/*
 * Create a socket of @type in @addr's family, bind it to @addr and, for
 * SOCK_STREAM, start listening (backlog 128). Returns the socket fd, or -1
 * on failure (the socket is closed on any failure after creation).
 */
static int start_server(int type, const struct sockaddr_storage *addr,
			socklen_t addr_len)
{
	int fd = socket(addr->ss_family, type, 0);

	if (fd == -1) {
		log_err("Failed to create server socket");
		return -1;
	}

	if (bind(fd, (const struct sockaddr *)addr, addr_len) == -1) {
		log_err("Failed to bind server socket");
		goto fail;
	}

	if (type == SOCK_STREAM && listen(fd, 128) == -1) {
		log_err("Failed to listen on server socket");
		goto fail;
	}

	return fd;

fail:
	close(fd);
	return -1;
}
1196
/*
 * Create a client socket of @type in @addr's family and connect(2) it to
 * @addr. Only AF_INET and AF_INET6 addresses are accepted.
 *
 * Returns the connected fd, or -1 on failure. Only a descriptor that was
 * actually created is closed on the error path (never close(-1)).
 */
static int connect_to_server(int type, const struct sockaddr_storage *addr,
			     socklen_t addr_len)
{
	int domain = addr->ss_family;
	int fd;

	if (domain != AF_INET && domain != AF_INET6) {
		log_err("Unsupported address family");
		return -1;
	}

	fd = socket(domain, type, 0);
	if (fd == -1) {
		log_err("Failed to create client socket");
		return -1;
	}

	if (connect(fd, (const struct sockaddr *)addr, addr_len) == -1) {
		log_err("Fail to connect to server");
		close(fd);
		return -1;
	}

	return fd;
}
1228
1229int init_pktinfo(int domain, struct cmsghdr *cmsg)
1230{
1231	struct in6_pktinfo *pktinfo6;
1232	struct in_pktinfo *pktinfo4;
1233
1234	if (domain == AF_INET) {
1235		cmsg->cmsg_level = SOL_IP;
1236		cmsg->cmsg_type = IP_PKTINFO;
1237		cmsg->cmsg_len = CMSG_LEN(sizeof(struct in_pktinfo));
1238		pktinfo4 = (struct in_pktinfo *)CMSG_DATA(cmsg);
1239		memset(pktinfo4, 0, sizeof(struct in_pktinfo));
1240		if (inet_pton(domain, SRC4_IP,
1241			      (void *)&pktinfo4->ipi_spec_dst) != 1)
1242			return -1;
1243	} else if (domain == AF_INET6) {
1244		cmsg->cmsg_level = SOL_IPV6;
1245		cmsg->cmsg_type = IPV6_PKTINFO;
1246		cmsg->cmsg_len = CMSG_LEN(sizeof(struct in6_pktinfo));
1247		pktinfo6 = (struct in6_pktinfo *)CMSG_DATA(cmsg);
1248		memset(pktinfo6, 0, sizeof(struct in6_pktinfo));
1249		if (inet_pton(domain, SRC6_IP,
1250			      (void *)&pktinfo6->ipi6_addr) != 1)
1251			return -1;
1252	} else {
1253		return -1;
1254	}
1255
1256	return 0;
1257}
1258
1259static int sendmsg_to_server(int type, const struct sockaddr_storage *addr,
1260			     socklen_t addr_len, int set_cmsg, int flags,
1261			     int *syscall_err)
1262{
1263	union {
1264		char buf[CMSG_SPACE(sizeof(struct in6_pktinfo))];
1265		struct cmsghdr align;
1266	} control6;
1267	union {
1268		char buf[CMSG_SPACE(sizeof(struct in_pktinfo))];
1269		struct cmsghdr align;
1270	} control4;
1271	struct msghdr hdr;
1272	struct iovec iov;
1273	char data = 'a';
1274	int domain;
1275	int fd = -1;
1276
1277	domain = addr->ss_family;
1278
1279	if (domain != AF_INET && domain != AF_INET6) {
1280		log_err("Unsupported address family");
1281		goto err;
1282	}
1283
1284	fd = socket(domain, type, 0);
1285	if (fd == -1) {
1286		log_err("Failed to create client socket");
1287		goto err;
1288	}
1289
1290	memset(&iov, 0, sizeof(iov));
1291	iov.iov_base = &data;
1292	iov.iov_len = sizeof(data);
1293
1294	memset(&hdr, 0, sizeof(hdr));
1295	hdr.msg_name = (void *)addr;
1296	hdr.msg_namelen = addr_len;
1297	hdr.msg_iov = &iov;
1298	hdr.msg_iovlen = 1;
1299
1300	if (set_cmsg) {
1301		if (domain == AF_INET) {
1302			hdr.msg_control = &control4;
1303			hdr.msg_controllen = sizeof(control4.buf);
1304		} else if (domain == AF_INET6) {
1305			hdr.msg_control = &control6;
1306			hdr.msg_controllen = sizeof(control6.buf);
1307		}
1308		if (init_pktinfo(domain, CMSG_FIRSTHDR(&hdr))) {
1309			log_err("Fail to init pktinfo");
1310			goto err;
1311		}
1312	}
1313
1314	if (sendmsg(fd, &hdr, flags) != sizeof(data)) {
1315		log_err("Fail to send message to server");
1316		*syscall_err = errno;
1317		goto err;
1318	}
1319
1320	goto out;
1321err:
1322	close(fd);
1323	fd = -1;
1324out:
1325	return fd;
1326}
1327
/*
 * Open a SOCK_STREAM client to @addr via sendmsg(2) with MSG_FASTOPEN
 * (TCP Fast Open) and no control message. Any sendmsg(2) errno is discarded.
 *
 * Returns the client fd, or -1 on failure.
 */
static int fastconnect_to_server(const struct sockaddr_storage *addr,
				 socklen_t addr_len)
{
	const int no_cmsg = 0;
	int ignored_err;

	return sendmsg_to_server(SOCK_STREAM, addr, addr_len, no_cmsg,
				 MSG_FASTOPEN, &ignored_err);
}
1336
/*
 * Wait up to 2 seconds for @sockfd to become readable, then receive one
 * message (at most 64 bytes; the payload is discarded) and store the
 * sender's address in @src_addr.
 *
 * Returns the recvmsg(2) result, or -1 on timeout or select(2) failure.
 */
static int recvmsg_from_client(int sockfd, struct sockaddr_storage *src_addr)
{
	struct timeval timeout = { .tv_sec = 2, .tv_usec = 0 };
	struct msghdr msg;
	struct iovec vec;
	char payload[64];
	fd_set readable;
	int ready;

	FD_ZERO(&readable);
	FD_SET(sockfd, &readable);

	ready = select(sockfd + 1, &readable, NULL, NULL, &timeout);
	if (ready <= 0 || !FD_ISSET(sockfd, &readable))
		return -1;

	memset(&vec, 0, sizeof(vec));
	vec.iov_base = payload;
	vec.iov_len = sizeof(payload);

	memset(&msg, 0, sizeof(msg));
	msg.msg_name = src_addr;
	msg.msg_namelen = sizeof(*src_addr);
	msg.msg_iov = &vec;
	msg.msg_iovlen = 1;

	return recvmsg(sockfd, &msg, 0);
}
1367
1368static int init_addrs(const struct sock_addr_test *test,
1369		      struct sockaddr_storage *requested_addr,
1370		      struct sockaddr_storage *expected_addr,
1371		      struct sockaddr_storage *expected_src_addr)
1372{
1373	socklen_t addr_len = sizeof(struct sockaddr_storage);
1374
1375	if (mk_sockaddr(test->domain, test->expected_ip, test->expected_port,
1376			(struct sockaddr *)expected_addr, addr_len) == -1)
1377		goto err;
1378
1379	if (mk_sockaddr(test->domain, test->requested_ip, test->requested_port,
1380			(struct sockaddr *)requested_addr, addr_len) == -1)
1381		goto err;
1382
1383	if (test->expected_src_ip &&
1384	    mk_sockaddr(test->domain, test->expected_src_ip, 0,
1385			(struct sockaddr *)expected_src_addr, addr_len) == -1)
1386		goto err;
1387
1388	return 0;
1389err:
1390	return -1;
1391}
1392
1393static int run_bind_test_case(const struct sock_addr_test *test)
1394{
1395	socklen_t addr_len = sizeof(struct sockaddr_storage);
1396	struct sockaddr_storage requested_addr;
1397	struct sockaddr_storage expected_addr;
1398	int clientfd = -1;
1399	int servfd = -1;
1400	int err = 0;
1401
1402	if (init_addrs(test, &requested_addr, &expected_addr, NULL))
1403		goto err;
1404
1405	servfd = start_server(test->type, &requested_addr, addr_len);
1406	if (servfd == -1)
1407		goto err;
1408
1409	if (cmp_local_addr(servfd, &expected_addr))
1410		goto err;
1411
1412	/* Try to connect to server just in case */
1413	clientfd = connect_to_server(test->type, &expected_addr, addr_len);
1414	if (clientfd == -1)
1415		goto err;
1416
1417	goto out;
1418err:
1419	err = -1;
1420out:
1421	close(clientfd);
1422	close(servfd);
1423	return err;
1424}
1425
1426static int run_connect_test_case(const struct sock_addr_test *test)
1427{
1428	socklen_t addr_len = sizeof(struct sockaddr_storage);
1429	struct sockaddr_storage expected_src_addr;
1430	struct sockaddr_storage requested_addr;
1431	struct sockaddr_storage expected_addr;
1432	int clientfd = -1;
1433	int servfd = -1;
1434	int err = 0;
1435
1436	if (init_addrs(test, &requested_addr, &expected_addr,
1437		       &expected_src_addr))
1438		goto err;
1439
1440	/* Prepare server to connect to */
1441	servfd = start_server(test->type, &expected_addr, addr_len);
1442	if (servfd == -1)
1443		goto err;
1444
1445	clientfd = connect_to_server(test->type, &requested_addr, addr_len);
1446	if (clientfd == -1)
1447		goto err;
1448
1449	/* Make sure src and dst addrs were overridden properly */
1450	if (cmp_peer_addr(clientfd, &expected_addr))
1451		goto err;
1452
1453	if (cmp_local_ip(clientfd, &expected_src_addr))
1454		goto err;
1455
1456	if (test->type == SOCK_STREAM) {
1457		/* Test TCP Fast Open scenario */
1458		clientfd = fastconnect_to_server(&requested_addr, addr_len);
1459		if (clientfd == -1)
1460			goto err;
1461
1462		/* Make sure src and dst addrs were overridden properly */
1463		if (cmp_peer_addr(clientfd, &expected_addr))
1464			goto err;
1465
1466		if (cmp_local_ip(clientfd, &expected_src_addr))
1467			goto err;
1468	}
1469
1470	goto out;
1471err:
1472	err = -1;
1473out:
1474	close(clientfd);
1475	close(servfd);
1476	return err;
1477}
1478
1479static int run_xmsg_test_case(const struct sock_addr_test *test, int max_cmsg)
1480{
1481	socklen_t addr_len = sizeof(struct sockaddr_storage);
1482	struct sockaddr_storage expected_addr;
1483	struct sockaddr_storage server_addr;
1484	struct sockaddr_storage sendmsg_addr;
1485	struct sockaddr_storage recvmsg_addr;
1486	int clientfd = -1;
1487	int servfd = -1;
1488	int set_cmsg;
1489	int err = 0;
1490
1491	if (test->type != SOCK_DGRAM)
1492		goto err;
1493
1494	if (init_addrs(test, &sendmsg_addr, &server_addr, &expected_addr))
1495		goto err;
1496
1497	/* Prepare server to sendmsg to */
1498	servfd = start_server(test->type, &server_addr, addr_len);
1499	if (servfd == -1)
1500		goto err;
1501
1502	for (set_cmsg = 0; set_cmsg <= max_cmsg; ++set_cmsg) {
1503		if (clientfd >= 0)
1504			close(clientfd);
1505
1506		clientfd = sendmsg_to_server(test->type, &sendmsg_addr,
1507					     addr_len, set_cmsg, /*flags*/0,
1508					     &err);
1509		if (err)
1510			goto out;
1511		else if (clientfd == -1)
1512			goto err;
1513
1514		/* Try to receive message on server instead of using
1515		 * getpeername(2) on client socket, to check that client's
1516		 * destination address was rewritten properly, since
1517		 * getpeername(2) doesn't work with unconnected datagram
1518		 * sockets.
1519		 *
1520		 * Get source address from recvmsg(2) as well to make sure
1521		 * source was rewritten properly: getsockname(2) can't be used
1522		 * since socket is unconnected and source defined for one
1523		 * specific packet may differ from the one used by default and
1524		 * returned by getsockname(2).
1525		 */
1526		if (recvmsg_from_client(servfd, &recvmsg_addr) == -1)
1527			goto err;
1528
1529		if (cmp_addr(&recvmsg_addr, &expected_addr, /*cmp_port*/0))
1530			goto err;
1531	}
1532
1533	goto out;
1534err:
1535	err = -1;
1536out:
1537	close(clientfd);
1538	close(servfd);
1539	return err;
1540}
1541
1542static int run_test_case(int cgfd, const struct sock_addr_test *test)
1543{
1544	int progfd = -1;
1545	int err = 0;
1546
1547	printf("Test case: %s .. ", test->descr);
1548
1549	progfd = test->loadfn(test);
1550	if (test->expected_result == LOAD_REJECT && progfd < 0)
1551		goto out;
1552	else if (test->expected_result == LOAD_REJECT || progfd < 0)
1553		goto err;
1554
1555	err = bpf_prog_attach(progfd, cgfd, test->attach_type,
1556			      BPF_F_ALLOW_OVERRIDE);
1557	if (test->expected_result == ATTACH_REJECT && err) {
1558		err = 0; /* error was expected, reset it */
1559		goto out;
1560	} else if (test->expected_result == ATTACH_REJECT || err) {
1561		goto err;
1562	} else if (test->expected_result == ATTACH_OKAY) {
1563		err = 0;
1564		goto out;
1565	}
1566
1567	switch (test->attach_type) {
1568	case BPF_CGROUP_INET4_BIND:
1569	case BPF_CGROUP_INET6_BIND:
1570		err = run_bind_test_case(test);
1571		break;
1572	case BPF_CGROUP_INET4_CONNECT:
1573	case BPF_CGROUP_INET6_CONNECT:
1574		err = run_connect_test_case(test);
1575		break;
1576	case BPF_CGROUP_UDP4_SENDMSG:
1577	case BPF_CGROUP_UDP6_SENDMSG:
1578		err = run_xmsg_test_case(test, 1);
1579		break;
1580	case BPF_CGROUP_UDP4_RECVMSG:
1581	case BPF_CGROUP_UDP6_RECVMSG:
1582		err = run_xmsg_test_case(test, 0);
1583		break;
1584	default:
1585		goto err;
1586	}
1587
1588	if (test->expected_result == SYSCALL_EPERM && err == EPERM) {
1589		err = 0; /* error was expected, reset it */
1590		goto out;
1591	}
1592
1593	if (test->expected_result == SYSCALL_ENOTSUPP && err == ENOTSUPP) {
1594		err = 0; /* error was expected, reset it */
1595		goto out;
1596	}
1597
1598	if (err || test->expected_result != SUCCESS)
1599		goto err;
1600
1601	goto out;
1602err:
1603	err = -1;
1604out:
1605	/* Detaching w/o checking return code: best effort attempt. */
1606	if (progfd != -1)
1607		bpf_prog_detach(cgfd, test->attach_type);
1608	close(progfd);
1609	printf("[%s]\n", err ? "FAIL" : "PASS");
1610	return err;
1611}
1612
1613static int run_tests(int cgfd)
1614{
1615	int passes = 0;
1616	int fails = 0;
1617	int i;
1618
1619	for (i = 0; i < ARRAY_SIZE(tests); ++i) {
1620		if (run_test_case(cgfd, &tests[i]))
1621			++fails;
1622		else
1623			++passes;
1624	}
1625	printf("Summary: %d PASSED, %d FAILED\n", passes, fails);
1626	return fails ? -1 : 0;
1627}
1628
1629int main(int argc, char **argv)
1630{
1631	int cgfd = -1;
1632	int err = 0;
1633
1634	if (argc < 2) {
1635		fprintf(stderr,
1636			"%s has to be run via %s.sh. Skip direct run.\n",
1637			argv[0], argv[0]);
1638		exit(err);
1639	}
1640
1641	cgfd = cgroup_setup_and_join(CG_PATH);
1642	if (cgfd < 0)
1643		goto err;
1644
1645	if (run_tests(cgfd))
1646		goto err;
1647
1648	goto out;
1649err:
1650	err = -1;
1651out:
1652	close(cgfd);
1653	cleanup_cgroup_environment();
1654	return err;
1655}
1656