1// SPDX-License-Identifier: GPL-2.0-or-later
2/*
3 * Copyright (c) 2020 SUSE LLC <mdoucha@suse.cz>
4 */
5
6/*
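 * Check that the kernel correctly handles send()/sendto()/sendmsg() calls
 * with the MSG_MORE flag: data sent with MSG_MORE must be held back until a
 * subsequent send without the flag, which then delivers all of it at once.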
9 */
10
#define _GNU_SOURCE
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <sys/ioctl.h>
#include <net/if.h>
#include <poll.h>
#include <sched.h>

#include "tst_test.h"
#include "tst_net.h"
21
/* Size of every full-size payload sent by the test. */
#define SENDSIZE 16
/* recv() buffer size; larger than the biggest expected read (SENDSIZE + 1). */
#define RECVSIZE 32

/*
 * sock is the sending socket, listen_sock the bound receiving socket and
 * dst_sock the socket data is read from. When a test case needs no accept(),
 * dst_sock aliases listen_sock.
 */
static int sock = -1;
static int dst_sock = -1;
static int listen_sock = -1;

/* Address of listen_sock; run() fills in the port via getsockname(). */
static struct sockaddr_in addr;

/* Payload buffer, filled by setup(). */
static char sendbuf[SENDSIZE];
28
29static void do_send(int sock, void *buf, size_t size, int flags)
30{
31	SAFE_SEND(1, sock, buf, size, flags);
32}
33
34static void do_sendto(int sock, void *buf, size_t size, int flags)
35{
36	SAFE_SENDTO(1, sock, buf, size, flags, (struct sockaddr *)&addr,
37		sizeof(addr));
38}
39
40static void do_sendmsg(int sock, void *buf, size_t size, int flags)
41{
42	struct msghdr msg;
43	struct iovec iov;
44
45	iov.iov_base = buf;
46	iov.iov_len = size;
47	msg.msg_name = &addr;
48	msg.msg_namelen = sizeof(addr);
49	msg.msg_iov = &iov;
50	msg.msg_iovlen = 1;
51	msg.msg_control = NULL;
52	msg.msg_controllen = 0;
53	msg.msg_flags = 0;
54	SAFE_SENDMSG(size, sock, &msg, flags);
55}
56
57static struct test_case {
58	int domain, type, protocol;
59	void (*send)(int sock, void *buf, size_t size, int flags);
60	int needs_connect, needs_accept;
61	const char *name;
62} testcase_list[] = {
63	{AF_INET, SOCK_STREAM, 0, do_send, 1, 1, "TCP send"},
64	{AF_INET, SOCK_DGRAM, 0, do_send, 1, 0, "UDP send"},
65	{AF_INET, SOCK_DGRAM, 0, do_sendto, 0, 0, "UDP sendto"},
66	{AF_INET, SOCK_DGRAM, 0, do_sendmsg, 0, 0, "UDP sendmsg"}
67};
68
69static void setup(void)
70{
71	memset(sendbuf, 0x42, SENDSIZE);
72}
73
74static int check_recv(int sock, long expsize, int loop)
75{
76	char recvbuf[RECVSIZE] = {0};
77
78	while (1) {
79		TEST(recv(sock, recvbuf, RECVSIZE, MSG_DONTWAIT));
80
81		if (TST_RET == -1) {
82			/* expected error immediately after send(MSG_MORE) */
83			if (TST_ERR == EAGAIN || TST_ERR == EWOULDBLOCK) {
84				if (expsize)
85					continue;
86				else
87					break;
88			}
89
90			/* unexpected error */
91			tst_res(TFAIL | TTERRNO, "recv() error at step %d, expsize %ld",
92				loop, expsize);
93			return 0;
94		}
95
96		if (TST_RET < 0) {
97			tst_res(TFAIL | TTERRNO, "recv() returns %ld at step %d, expsize %ld",
98				TST_RET, loop, expsize);
99			return 0;
100		}
101
102		if (TST_RET != expsize) {
103			tst_res(TFAIL, "recv() read %ld bytes, expected %ld, step %d",
104				TST_RET, expsize, loop);
105			return 0;
106		}
107		return 1;
108	}
109
110	return 1;
111}
112
113static void cleanup(void)
114{
115	if (sock >= 0)
116		SAFE_CLOSE(sock);
117
118	if (dst_sock >= 0 && dst_sock != listen_sock)
119		SAFE_CLOSE(dst_sock);
120
121	if (listen_sock >= 0)
122		SAFE_CLOSE(listen_sock);
123}
124
125static void run(unsigned int n)
126{
127	int i, ret;
128	struct test_case *tc = testcase_list + n;
129	socklen_t len = sizeof(addr);
130
131	tst_res(TINFO, "Tesing %s", tc->name);
132
133	tst_init_sockaddr_inet_bin(&addr, INADDR_LOOPBACK, 0);
134	listen_sock = SAFE_SOCKET(tc->domain, tc->type, tc->protocol);
135	dst_sock = listen_sock;
136	SAFE_BIND(listen_sock, (struct sockaddr *)&addr, sizeof(addr));
137	SAFE_GETSOCKNAME(listen_sock, (struct sockaddr *)&addr, &len);
138
139	if (tc->needs_accept)
140		SAFE_LISTEN(listen_sock, 1);
141
142	for (i = 0; i < 1000; i++) {
143		sock = SAFE_SOCKET(tc->domain, tc->type, tc->protocol);
144
145		if (tc->needs_connect)
146			SAFE_CONNECT(sock, (struct sockaddr *)&addr, len);
147
148		if (tc->needs_accept)
149			dst_sock = SAFE_ACCEPT(listen_sock, NULL, NULL);
150
151		tc->send(sock, sendbuf, SENDSIZE, 0);
152		ret = check_recv(dst_sock, SENDSIZE, i + 1);
153
154		if (!ret)
155			break;
156
157		tc->send(sock, sendbuf, SENDSIZE, MSG_MORE);
158		ret = check_recv(dst_sock, 0, i + 1);
159
160		if (!ret)
161			break;
162
163		tc->send(sock, sendbuf, 1, 0);
164		ret = check_recv(dst_sock, SENDSIZE + 1, i + 1);
165
166		if (!ret)
167			break;
168
169		SAFE_CLOSE(sock);
170
171		if (dst_sock != listen_sock)
172			SAFE_CLOSE(dst_sock);
173	}
174
175	if (ret)
176		tst_res(TPASS, "MSG_MORE works correctly");
177
178	cleanup();
179	dst_sock = -1;
180}
181
182static struct tst_test test = {
183	.test = run,
184	.tcnt = ARRAY_SIZE(testcase_list),
185	.setup = setup,
186	.cleanup = cleanup
187};
188