1// SPDX-License-Identifier: GPL-2.0
2/* Copyright (c) 2020, Tessares SA. */
3/* Copyright (c) 2022, SUSE. */
4
5#include <linux/const.h>
6#include <netinet/in.h>
7#include <test_progs.h>
8#include "cgroup_helpers.h"
9#include "network_helpers.h"
10#include "mptcp_sock.skel.h"
11#include "mptcpify.skel.h"
12
/* Name of the network namespace created by create_netns() for each subtest. */
#define NS_TEST "mptcp_ns"

/*
 * Fallback definitions for build environments whose system headers do not
 * (yet) provide the MPTCP constants. Each one is only defined when the headers
 * lack it; the values are presumably meant to mirror the kernel uapi
 * (linux/in.h, linux/mptcp.h) -- keep them in sync if those ever change.
 */
#ifndef IPPROTO_MPTCP
#define IPPROTO_MPTCP 262
#endif

/* Socket option level / option name used with getsockopt() in verify_mptcpify(). */
#ifndef SOL_MPTCP
#define SOL_MPTCP 284
#endif
#ifndef MPTCP_INFO
#define MPTCP_INFO		1
#endif
/* Bits tested in __mptcp_info.mptcpi_flags by verify_mptcpify(). */
#ifndef MPTCP_INFO_FLAG_FALLBACK
#define MPTCP_INFO_FLAG_FALLBACK		_BITUL(0)
#endif
#ifndef MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED
#define MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED	_BITUL(1)
#endif

/*
 * Size of the congestion-control name buffers below; presumably mirrors the
 * kernel's TCP_CA_NAME_MAX (which counts the terminating NUL) -- confirm.
 */
#ifndef TCP_CA_NAME_MAX
#define TCP_CA_NAME_MAX	16
#endif
35
/*
 * Local copy of the MPTCP_INFO socket option payload, filled in by
 * getsockopt(client_fd, SOL_MPTCP, MPTCP_INFO, ...) in verify_mptcpify().
 * NOTE(review): presumably this must match the kernel's uapi struct
 * mptcp_info field-for-field; do not reorder or resize fields without
 * checking linux/mptcp.h. Only mptcpi_flags is inspected by this test.
 */
struct __mptcp_info {
	__u8	mptcpi_subflows;
	__u8	mptcpi_add_addr_signal;
	__u8	mptcpi_add_addr_accepted;
	__u8	mptcpi_subflows_max;
	__u8	mptcpi_add_addr_signal_max;
	__u8	mptcpi_add_addr_accepted_max;
	/* MPTCP_INFO_FLAG_* bits. */
	__u32	mptcpi_flags;
	__u32	mptcpi_token;
	__u64	mptcpi_write_seq;
	__u64	mptcpi_snd_una;
	__u64	mptcpi_rcv_nxt;
	__u8	mptcpi_local_addr_used;
	__u8	mptcpi_local_addr_max;
	__u8	mptcpi_csum_enabled;
	__u32	mptcpi_retransmits;
	__u64	mptcpi_bytes_retrans;
	__u64	mptcpi_bytes_sent;
	__u64	mptcpi_bytes_received;
	__u64	mptcpi_bytes_acked;
};
57
/*
 * Value type of the BPF program's socket_storage_map, read back with
 * bpf_map_lookup_elem() in verify_tsk() and verify_msk().
 * NOTE(review): presumably must match the struct of the same name in the
 * mptcp_sock BPF program byte-for-byte -- keep both in sync.
 */
struct mptcp_storage {
	/* Number of times the program recorded this socket (expected: 1). */
	__u32 invoked;
	/* Non-zero when the socket was seen as MPTCP. */
	__u32 is_mptcp;
	/* Kernel pointers; only compared against each other (first == sk). */
	struct sock *sk;
	/* Compared against the token the program exported via .bss. */
	__u32 token;
	struct sock *first;
	/* Compared against /proc/sys/net/ipv4/tcp_congestion_control. */
	char ca_name[TCP_CA_NAME_MAX];
};
66
67static struct nstoken *create_netns(void)
68{
69	SYS(fail, "ip netns add %s", NS_TEST);
70	SYS(fail, "ip -net %s link set dev lo up", NS_TEST);
71
72	return open_netns(NS_TEST);
73fail:
74	return NULL;
75}
76
77static void cleanup_netns(struct nstoken *nstoken)
78{
79	if (nstoken)
80		close_netns(nstoken);
81
82	SYS_NOFAIL("ip netns del %s &> /dev/null", NS_TEST);
83}
84
85static int verify_tsk(int map_fd, int client_fd)
86{
87	int err, cfd = client_fd;
88	struct mptcp_storage val;
89
90	err = bpf_map_lookup_elem(map_fd, &cfd, &val);
91	if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
92		return err;
93
94	if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
95		err++;
96
97	if (!ASSERT_EQ(val.is_mptcp, 0, "unexpected is_mptcp"))
98		err++;
99
100	return err;
101}
102
103static void get_msk_ca_name(char ca_name[])
104{
105	size_t len;
106	int fd;
107
108	fd = open("/proc/sys/net/ipv4/tcp_congestion_control", O_RDONLY);
109	if (!ASSERT_GE(fd, 0, "failed to open tcp_congestion_control"))
110		return;
111
112	len = read(fd, ca_name, TCP_CA_NAME_MAX);
113	if (!ASSERT_GT(len, 0, "failed to read ca_name"))
114		goto err;
115
116	if (len > 0 && ca_name[len - 1] == '\n')
117		ca_name[len - 1] = '\0';
118
119err:
120	close(fd);
121}
122
123static int verify_msk(int map_fd, int client_fd, __u32 token)
124{
125	char ca_name[TCP_CA_NAME_MAX];
126	int err, cfd = client_fd;
127	struct mptcp_storage val;
128
129	if (!ASSERT_GT(token, 0, "invalid token"))
130		return -1;
131
132	get_msk_ca_name(ca_name);
133
134	err = bpf_map_lookup_elem(map_fd, &cfd, &val);
135	if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
136		return err;
137
138	if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
139		err++;
140
141	if (!ASSERT_EQ(val.is_mptcp, 1, "unexpected is_mptcp"))
142		err++;
143
144	if (!ASSERT_EQ(val.token, token, "unexpected token"))
145		err++;
146
147	if (!ASSERT_EQ(val.first, val.sk, "unexpected first"))
148		err++;
149
150	if (!ASSERT_STRNEQ(val.ca_name, ca_name, TCP_CA_NAME_MAX, "unexpected ca_name"))
151		err++;
152
153	return err;
154}
155
156static int run_test(int cgroup_fd, int server_fd, bool is_mptcp)
157{
158	int client_fd, prog_fd, map_fd, err;
159	struct mptcp_sock *sock_skel;
160
161	sock_skel = mptcp_sock__open_and_load();
162	if (!ASSERT_OK_PTR(sock_skel, "skel_open_load"))
163		return libbpf_get_error(sock_skel);
164
165	err = mptcp_sock__attach(sock_skel);
166	if (!ASSERT_OK(err, "skel_attach"))
167		goto out;
168
169	prog_fd = bpf_program__fd(sock_skel->progs._sockops);
170	map_fd = bpf_map__fd(sock_skel->maps.socket_storage_map);
171	err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0);
172	if (!ASSERT_OK(err, "bpf_prog_attach"))
173		goto out;
174
175	client_fd = connect_to_fd(server_fd, 0);
176	if (!ASSERT_GE(client_fd, 0, "connect to fd")) {
177		err = -EIO;
178		goto out;
179	}
180
181	err += is_mptcp ? verify_msk(map_fd, client_fd, sock_skel->bss->token) :
182			  verify_tsk(map_fd, client_fd);
183
184	close(client_fd);
185
186out:
187	mptcp_sock__destroy(sock_skel);
188	return err;
189}
190
/*
 * "base" subtest: inside a fresh netns and the "/mptcp" cgroup, run the
 * sockops checks first against a plain TCP server and then against an MPTCP
 * server. A failure to start the TCP server does not skip the MPTCP half.
 */
static void test_base(void)
{
	struct nstoken *nstoken = NULL;
	int cgroup_fd, server_fd;

	cgroup_fd = test__join_cgroup("/mptcp");
	if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup"))
		return;

	nstoken = create_netns();
	if (!ASSERT_OK_PTR(nstoken, "create_netns"))
		goto cleanup;

	/* Plain TCP server: the client must be recorded as non-MPTCP. */
	server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0);
	if (ASSERT_GE(server_fd, 0, "start_server")) {
		ASSERT_OK(run_test(cgroup_fd, server_fd, false), "run_test tcp");
		close(server_fd);
	}

	/* MPTCP server: the client must be recorded as MPTCP. */
	server_fd = start_mptcp_server(AF_INET, NULL, 0, 0);
	if (ASSERT_GE(server_fd, 0, "start_mptcp_server")) {
		ASSERT_OK(run_test(cgroup_fd, server_fd, true), "run_test mptcp");
		close(server_fd);
	}

cleanup:
	cleanup_netns(nstoken);
	close(cgroup_fd);
}
227
/* Write a single marker byte (0x55) to @fd and assert that it went out. */
static void send_byte(int fd)
{
	const char payload = 0x55;
	ssize_t written = write(fd, &payload, sizeof(payload));

	ASSERT_EQ(written, 1, "send single byte");
}
234
235static int verify_mptcpify(int server_fd, int client_fd)
236{
237	struct __mptcp_info info;
238	socklen_t optlen;
239	int protocol;
240	int err = 0;
241
242	optlen = sizeof(protocol);
243	if (!ASSERT_OK(getsockopt(server_fd, SOL_SOCKET, SO_PROTOCOL, &protocol, &optlen),
244		       "getsockopt(SOL_PROTOCOL)"))
245		return -1;
246
247	if (!ASSERT_EQ(protocol, IPPROTO_MPTCP, "protocol isn't MPTCP"))
248		err++;
249
250	optlen = sizeof(info);
251	if (!ASSERT_OK(getsockopt(client_fd, SOL_MPTCP, MPTCP_INFO, &info, &optlen),
252		       "getsockopt(MPTCP_INFO)"))
253		return -1;
254
255	if (!ASSERT_GE(info.mptcpi_flags, 0, "unexpected mptcpi_flags"))
256		err++;
257	if (!ASSERT_FALSE(info.mptcpi_flags & MPTCP_INFO_FLAG_FALLBACK,
258			  "MPTCP fallback"))
259		err++;
260	if (!ASSERT_TRUE(info.mptcpi_flags & MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED,
261			 "no remote key received"))
262		err++;
263
264	return err;
265}
266
/*
 * Load and attach the mptcpify skeleton, then open a TCP-type server and
 * client pair, push one byte through it and check via verify_mptcpify() that
 * the connection ended up as MPTCP.
 * NOTE(review): @cgroup_fd is currently unused here.
 *
 * Returns 0 on success, a negative error on setup failure, or the result of
 * verify_mptcpify().
 */
static int run_mptcpify(int cgroup_fd)
{
	struct mptcpify *skel;
	int srv_fd, cli_fd, ret = 0;

	skel = mptcpify__open_and_load();
	if (!ASSERT_OK_PTR(skel, "skel_open_load"))
		return libbpf_get_error(skel);

	ret = mptcpify__attach(skel);
	if (!ASSERT_OK(ret, "skel_attach"))
		goto destroy;

	/* Request a plain TCP socket; the check is that it becomes MPTCP. */
	srv_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0);
	if (!ASSERT_GE(srv_fd, 0, "start_server")) {
		ret = -EIO;
		goto destroy;
	}

	cli_fd = connect_to_fd(srv_fd, 0);
	if (ASSERT_GE(cli_fd, 0, "connect to fd")) {
		send_byte(cli_fd);
		ret = verify_mptcpify(srv_fd, cli_fd);
		close(cli_fd);
	} else {
		ret = -EIO;
	}

	close(srv_fd);
destroy:
	mptcpify__destroy(skel);
	return ret;
}
304
/*
 * "mptcpify" subtest: join the "/mptcpify" cgroup, move into a fresh netns
 * and run run_mptcpify() there. The netns and cgroup fd are always released.
 */
static void test_mptcpify(void)
{
	struct nstoken *nstoken;
	int cgroup_fd;

	cgroup_fd = test__join_cgroup("/mptcpify");
	if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup"))
		return;

	nstoken = create_netns();
	if (ASSERT_OK_PTR(nstoken, "create_netns"))
		ASSERT_OK(run_mptcpify(cgroup_fd), "run_mptcpify");

	cleanup_netns(nstoken);
	close(cgroup_fd);
}
324
325void test_mptcp(void)
326{
327	if (test__start_subtest("base"))
328		test_base();
329	if (test__start_subtest("mptcpify"))
330		test_mptcpify();
331}
332