1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * vsock_test - vsock.ko test suite
4 *
5 * Copyright (C) 2017 Red Hat, Inc.
6 *
7 * Author: Stefan Hajnoczi <stefanha@redhat.com>
8 */
9
10#include <getopt.h>
11#include <stdio.h>
12#include <stdlib.h>
13#include <string.h>
14#include <errno.h>
15#include <unistd.h>
16#include <linux/kernel.h>
17#include <sys/types.h>
18#include <sys/socket.h>
19#include <time.h>
20#include <sys/mman.h>
21#include <poll.h>
22
23#include "timeout.h"
24#include "control.h"
25#include "util.h"
26
27static void test_stream_connection_reset(const struct test_opts *opts)
28{
29	union {
30		struct sockaddr sa;
31		struct sockaddr_vm svm;
32	} addr = {
33		.svm = {
34			.svm_family = AF_VSOCK,
35			.svm_port = 1234,
36			.svm_cid = opts->peer_cid,
37		},
38	};
39	int ret;
40	int fd;
41
42	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
43
44	timeout_begin(TIMEOUT);
45	do {
46		ret = connect(fd, &addr.sa, sizeof(addr.svm));
47		timeout_check("connect");
48	} while (ret < 0 && errno == EINTR);
49	timeout_end();
50
51	if (ret != -1) {
52		fprintf(stderr, "expected connect(2) failure, got %d\n", ret);
53		exit(EXIT_FAILURE);
54	}
55	if (errno != ECONNRESET) {
56		fprintf(stderr, "unexpected connect(2) errno %d\n", errno);
57		exit(EXIT_FAILURE);
58	}
59
60	close(fd);
61}
62
63static void test_stream_bind_only_client(const struct test_opts *opts)
64{
65	union {
66		struct sockaddr sa;
67		struct sockaddr_vm svm;
68	} addr = {
69		.svm = {
70			.svm_family = AF_VSOCK,
71			.svm_port = 1234,
72			.svm_cid = opts->peer_cid,
73		},
74	};
75	int ret;
76	int fd;
77
78	/* Wait for the server to be ready */
79	control_expectln("BIND");
80
81	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
82
83	timeout_begin(TIMEOUT);
84	do {
85		ret = connect(fd, &addr.sa, sizeof(addr.svm));
86		timeout_check("connect");
87	} while (ret < 0 && errno == EINTR);
88	timeout_end();
89
90	if (ret != -1) {
91		fprintf(stderr, "expected connect(2) failure, got %d\n", ret);
92		exit(EXIT_FAILURE);
93	}
94	if (errno != ECONNRESET) {
95		fprintf(stderr, "unexpected connect(2) errno %d\n", errno);
96		exit(EXIT_FAILURE);
97	}
98
99	/* Notify the server that the client has finished */
100	control_writeln("DONE");
101
102	close(fd);
103}
104
105static void test_stream_bind_only_server(const struct test_opts *opts)
106{
107	union {
108		struct sockaddr sa;
109		struct sockaddr_vm svm;
110	} addr = {
111		.svm = {
112			.svm_family = AF_VSOCK,
113			.svm_port = 1234,
114			.svm_cid = VMADDR_CID_ANY,
115		},
116	};
117	int fd;
118
119	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
120
121	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
122		perror("bind");
123		exit(EXIT_FAILURE);
124	}
125
126	/* Notify the client that the server is ready */
127	control_writeln("BIND");
128
129	/* Wait for the client to finish */
130	control_expectln("DONE");
131
132	close(fd);
133}
134
135static void test_stream_client_close_client(const struct test_opts *opts)
136{
137	int fd;
138
139	fd = vsock_stream_connect(opts->peer_cid, 1234);
140	if (fd < 0) {
141		perror("connect");
142		exit(EXIT_FAILURE);
143	}
144
145	send_byte(fd, 1, 0);
146	close(fd);
147}
148
149static void test_stream_client_close_server(const struct test_opts *opts)
150{
151	int fd;
152
153	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
154	if (fd < 0) {
155		perror("accept");
156		exit(EXIT_FAILURE);
157	}
158
159	/* Wait for the remote to close the connection, before check
160	 * -EPIPE error on send.
161	 */
162	vsock_wait_remote_close(fd);
163
164	send_byte(fd, -EPIPE, 0);
165	recv_byte(fd, 1, 0);
166	recv_byte(fd, 0, 0);
167	close(fd);
168}
169
170static void test_stream_server_close_client(const struct test_opts *opts)
171{
172	int fd;
173
174	fd = vsock_stream_connect(opts->peer_cid, 1234);
175	if (fd < 0) {
176		perror("connect");
177		exit(EXIT_FAILURE);
178	}
179
180	/* Wait for the remote to close the connection, before check
181	 * -EPIPE error on send.
182	 */
183	vsock_wait_remote_close(fd);
184
185	send_byte(fd, -EPIPE, 0);
186	recv_byte(fd, 1, 0);
187	recv_byte(fd, 0, 0);
188	close(fd);
189}
190
191static void test_stream_server_close_server(const struct test_opts *opts)
192{
193	int fd;
194
195	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
196	if (fd < 0) {
197		perror("accept");
198		exit(EXIT_FAILURE);
199	}
200
201	send_byte(fd, 1, 0);
202	close(fd);
203}
204
205/* With the standard socket sizes, VMCI is able to support about 100
206 * concurrent stream connections.
207 */
208#define MULTICONN_NFDS 100
209
210static void test_stream_multiconn_client(const struct test_opts *opts)
211{
212	int fds[MULTICONN_NFDS];
213	int i;
214
215	for (i = 0; i < MULTICONN_NFDS; i++) {
216		fds[i] = vsock_stream_connect(opts->peer_cid, 1234);
217		if (fds[i] < 0) {
218			perror("connect");
219			exit(EXIT_FAILURE);
220		}
221	}
222
223	for (i = 0; i < MULTICONN_NFDS; i++) {
224		if (i % 2)
225			recv_byte(fds[i], 1, 0);
226		else
227			send_byte(fds[i], 1, 0);
228	}
229
230	for (i = 0; i < MULTICONN_NFDS; i++)
231		close(fds[i]);
232}
233
234static void test_stream_multiconn_server(const struct test_opts *opts)
235{
236	int fds[MULTICONN_NFDS];
237	int i;
238
239	for (i = 0; i < MULTICONN_NFDS; i++) {
240		fds[i] = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
241		if (fds[i] < 0) {
242			perror("accept");
243			exit(EXIT_FAILURE);
244		}
245	}
246
247	for (i = 0; i < MULTICONN_NFDS; i++) {
248		if (i % 2)
249			send_byte(fds[i], 1, 0);
250		else
251			recv_byte(fds[i], 1, 0);
252	}
253
254	for (i = 0; i < MULTICONN_NFDS; i++)
255		close(fds[i]);
256}
257
258#define MSG_PEEK_BUF_LEN 64
259
260static void test_msg_peek_client(const struct test_opts *opts,
261				 bool seqpacket)
262{
263	unsigned char buf[MSG_PEEK_BUF_LEN];
264	ssize_t send_size;
265	int fd;
266	int i;
267
268	if (seqpacket)
269		fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
270	else
271		fd = vsock_stream_connect(opts->peer_cid, 1234);
272
273	if (fd < 0) {
274		perror("connect");
275		exit(EXIT_FAILURE);
276	}
277
278	for (i = 0; i < sizeof(buf); i++)
279		buf[i] = rand() & 0xFF;
280
281	control_expectln("SRVREADY");
282
283	send_size = send(fd, buf, sizeof(buf), 0);
284
285	if (send_size < 0) {
286		perror("send");
287		exit(EXIT_FAILURE);
288	}
289
290	if (send_size != sizeof(buf)) {
291		fprintf(stderr, "Invalid send size %zi\n", send_size);
292		exit(EXIT_FAILURE);
293	}
294
295	close(fd);
296}
297
298static void test_msg_peek_server(const struct test_opts *opts,
299				 bool seqpacket)
300{
301	unsigned char buf_half[MSG_PEEK_BUF_LEN / 2];
302	unsigned char buf_normal[MSG_PEEK_BUF_LEN];
303	unsigned char buf_peek[MSG_PEEK_BUF_LEN];
304	ssize_t res;
305	int fd;
306
307	if (seqpacket)
308		fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
309	else
310		fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
311
312	if (fd < 0) {
313		perror("accept");
314		exit(EXIT_FAILURE);
315	}
316
317	/* Peek from empty socket. */
318	res = recv(fd, buf_peek, sizeof(buf_peek), MSG_PEEK | MSG_DONTWAIT);
319	if (res != -1) {
320		fprintf(stderr, "expected recv(2) failure, got %zi\n", res);
321		exit(EXIT_FAILURE);
322	}
323
324	if (errno != EAGAIN) {
325		perror("EAGAIN expected");
326		exit(EXIT_FAILURE);
327	}
328
329	control_writeln("SRVREADY");
330
331	/* Peek part of data. */
332	res = recv(fd, buf_half, sizeof(buf_half), MSG_PEEK);
333	if (res != sizeof(buf_half)) {
334		fprintf(stderr, "recv(2) + MSG_PEEK, expected %zu, got %zi\n",
335			sizeof(buf_half), res);
336		exit(EXIT_FAILURE);
337	}
338
339	/* Peek whole data. */
340	res = recv(fd, buf_peek, sizeof(buf_peek), MSG_PEEK);
341	if (res != sizeof(buf_peek)) {
342		fprintf(stderr, "recv(2) + MSG_PEEK, expected %zu, got %zi\n",
343			sizeof(buf_peek), res);
344		exit(EXIT_FAILURE);
345	}
346
347	/* Compare partial and full peek. */
348	if (memcmp(buf_half, buf_peek, sizeof(buf_half))) {
349		fprintf(stderr, "Partial peek data mismatch\n");
350		exit(EXIT_FAILURE);
351	}
352
353	if (seqpacket) {
354		/* This type of socket supports MSG_TRUNC flag,
355		 * so check it with MSG_PEEK. We must get length
356		 * of the message.
357		 */
358		res = recv(fd, buf_half, sizeof(buf_half), MSG_PEEK |
359			   MSG_TRUNC);
360		if (res != sizeof(buf_peek)) {
361			fprintf(stderr,
362				"recv(2) + MSG_PEEK | MSG_TRUNC, exp %zu, got %zi\n",
363				sizeof(buf_half), res);
364			exit(EXIT_FAILURE);
365		}
366	}
367
368	res = recv(fd, buf_normal, sizeof(buf_normal), 0);
369	if (res != sizeof(buf_normal)) {
370		fprintf(stderr, "recv(2), expected %zu, got %zi\n",
371			sizeof(buf_normal), res);
372		exit(EXIT_FAILURE);
373	}
374
375	/* Compare full peek and normal read. */
376	if (memcmp(buf_peek, buf_normal, sizeof(buf_peek))) {
377		fprintf(stderr, "Full peek data mismatch\n");
378		exit(EXIT_FAILURE);
379	}
380
381	close(fd);
382}
383
/* SOCK_STREAM front-ends for the shared MSG_PEEK test. The helpers return
 * void, so they are called as statements: `return f();` in a void function
 * is a constraint violation in ISO C.
 */
static void test_stream_msg_peek_client(const struct test_opts *opts)
{
	test_msg_peek_client(opts, false);
}

static void test_stream_msg_peek_server(const struct test_opts *opts)
{
	test_msg_peek_server(opts, false);
}
393
394#define SOCK_BUF_SIZE (2 * 1024 * 1024)
395#define MAX_MSG_PAGES 4
396
397static void test_seqpacket_msg_bounds_client(const struct test_opts *opts)
398{
399	unsigned long curr_hash;
400	size_t max_msg_size;
401	int page_size;
402	int msg_count;
403	int fd;
404
405	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
406	if (fd < 0) {
407		perror("connect");
408		exit(EXIT_FAILURE);
409	}
410
411	/* Wait, until receiver sets buffer size. */
412	control_expectln("SRVREADY");
413
414	curr_hash = 0;
415	page_size = getpagesize();
416	max_msg_size = MAX_MSG_PAGES * page_size;
417	msg_count = SOCK_BUF_SIZE / max_msg_size;
418
419	for (int i = 0; i < msg_count; i++) {
420		ssize_t send_size;
421		size_t buf_size;
422		int flags;
423		void *buf;
424
425		/* Use "small" buffers and "big" buffers. */
426		if (i & 1)
427			buf_size = page_size +
428					(rand() % (max_msg_size - page_size));
429		else
430			buf_size = 1 + (rand() % page_size);
431
432		buf = malloc(buf_size);
433
434		if (!buf) {
435			perror("malloc");
436			exit(EXIT_FAILURE);
437		}
438
439		memset(buf, rand() & 0xff, buf_size);
440		/* Set at least one MSG_EOR + some random. */
441		if (i == (msg_count / 2) || (rand() & 1)) {
442			flags = MSG_EOR;
443			curr_hash++;
444		} else {
445			flags = 0;
446		}
447
448		send_size = send(fd, buf, buf_size, flags);
449
450		if (send_size < 0) {
451			perror("send");
452			exit(EXIT_FAILURE);
453		}
454
455		if (send_size != buf_size) {
456			fprintf(stderr, "Invalid send size\n");
457			exit(EXIT_FAILURE);
458		}
459
460		/*
461		 * Hash sum is computed at both client and server in
462		 * the same way:
463		 * H += hash('message data')
464		 * Such hash "controls" both data integrity and message
465		 * bounds. After data exchange, both sums are compared
466		 * using control socket, and if message bounds wasn't
467		 * broken - two values must be equal.
468		 */
469		curr_hash += hash_djb2(buf, buf_size);
470		free(buf);
471	}
472
473	control_writeln("SENDDONE");
474	control_writeulong(curr_hash);
475	close(fd);
476}
477
478static void test_seqpacket_msg_bounds_server(const struct test_opts *opts)
479{
480	unsigned long sock_buf_size;
481	unsigned long remote_hash;
482	unsigned long curr_hash;
483	int fd;
484	struct msghdr msg = {0};
485	struct iovec iov = {0};
486
487	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
488	if (fd < 0) {
489		perror("accept");
490		exit(EXIT_FAILURE);
491	}
492
493	sock_buf_size = SOCK_BUF_SIZE;
494
495	if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_MAX_SIZE,
496		       &sock_buf_size, sizeof(sock_buf_size))) {
497		perror("setsockopt(SO_VM_SOCKETS_BUFFER_MAX_SIZE)");
498		exit(EXIT_FAILURE);
499	}
500
501	if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE,
502		       &sock_buf_size, sizeof(sock_buf_size))) {
503		perror("setsockopt(SO_VM_SOCKETS_BUFFER_SIZE)");
504		exit(EXIT_FAILURE);
505	}
506
507	/* Ready to receive data. */
508	control_writeln("SRVREADY");
509	/* Wait, until peer sends whole data. */
510	control_expectln("SENDDONE");
511	iov.iov_len = MAX_MSG_PAGES * getpagesize();
512	iov.iov_base = malloc(iov.iov_len);
513	if (!iov.iov_base) {
514		perror("malloc");
515		exit(EXIT_FAILURE);
516	}
517
518	msg.msg_iov = &iov;
519	msg.msg_iovlen = 1;
520
521	curr_hash = 0;
522
523	while (1) {
524		ssize_t recv_size;
525
526		recv_size = recvmsg(fd, &msg, 0);
527
528		if (!recv_size)
529			break;
530
531		if (recv_size < 0) {
532			perror("recvmsg");
533			exit(EXIT_FAILURE);
534		}
535
536		if (msg.msg_flags & MSG_EOR)
537			curr_hash++;
538
539		curr_hash += hash_djb2(msg.msg_iov[0].iov_base, recv_size);
540	}
541
542	free(iov.iov_base);
543	close(fd);
544	remote_hash = control_readulong();
545
546	if (curr_hash != remote_hash) {
547		fprintf(stderr, "Message bounds broken\n");
548		exit(EXIT_FAILURE);
549	}
550}
551
552#define MESSAGE_TRUNC_SZ 32
553static void test_seqpacket_msg_trunc_client(const struct test_opts *opts)
554{
555	int fd;
556	char buf[MESSAGE_TRUNC_SZ];
557
558	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
559	if (fd < 0) {
560		perror("connect");
561		exit(EXIT_FAILURE);
562	}
563
564	if (send(fd, buf, sizeof(buf), 0) != sizeof(buf)) {
565		perror("send failed");
566		exit(EXIT_FAILURE);
567	}
568
569	control_writeln("SENDDONE");
570	close(fd);
571}
572
573static void test_seqpacket_msg_trunc_server(const struct test_opts *opts)
574{
575	int fd;
576	char buf[MESSAGE_TRUNC_SZ / 2];
577	struct msghdr msg = {0};
578	struct iovec iov = {0};
579
580	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
581	if (fd < 0) {
582		perror("accept");
583		exit(EXIT_FAILURE);
584	}
585
586	control_expectln("SENDDONE");
587	iov.iov_base = buf;
588	iov.iov_len = sizeof(buf);
589	msg.msg_iov = &iov;
590	msg.msg_iovlen = 1;
591
592	ssize_t ret = recvmsg(fd, &msg, MSG_TRUNC);
593
594	if (ret != MESSAGE_TRUNC_SZ) {
595		printf("%zi\n", ret);
596		perror("MSG_TRUNC doesn't work");
597		exit(EXIT_FAILURE);
598	}
599
600	if (!(msg.msg_flags & MSG_TRUNC)) {
601		fprintf(stderr, "MSG_TRUNC expected\n");
602		exit(EXIT_FAILURE);
603	}
604
605	close(fd);
606}
607
608static time_t current_nsec(void)
609{
610	struct timespec ts;
611
612	if (clock_gettime(CLOCK_REALTIME, &ts)) {
613		perror("clock_gettime(3) failed");
614		exit(EXIT_FAILURE);
615	}
616
617	return (ts.tv_sec * 1000000000ULL) + ts.tv_nsec;
618}
619
620#define RCVTIMEO_TIMEOUT_SEC 1
621#define READ_OVERHEAD_NSEC 250000000 /* 0.25 sec */
622
623static void test_seqpacket_timeout_client(const struct test_opts *opts)
624{
625	int fd;
626	struct timeval tv;
627	char dummy;
628	time_t read_enter_ns;
629	time_t read_overhead_ns;
630
631	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
632	if (fd < 0) {
633		perror("connect");
634		exit(EXIT_FAILURE);
635	}
636
637	tv.tv_sec = RCVTIMEO_TIMEOUT_SEC;
638	tv.tv_usec = 0;
639
640	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, (void *)&tv, sizeof(tv)) == -1) {
641		perror("setsockopt(SO_RCVTIMEO)");
642		exit(EXIT_FAILURE);
643	}
644
645	read_enter_ns = current_nsec();
646
647	if (read(fd, &dummy, sizeof(dummy)) != -1) {
648		fprintf(stderr,
649			"expected 'dummy' read(2) failure\n");
650		exit(EXIT_FAILURE);
651	}
652
653	if (errno != EAGAIN) {
654		perror("EAGAIN expected");
655		exit(EXIT_FAILURE);
656	}
657
658	read_overhead_ns = current_nsec() - read_enter_ns -
659			1000000000ULL * RCVTIMEO_TIMEOUT_SEC;
660
661	if (read_overhead_ns > READ_OVERHEAD_NSEC) {
662		fprintf(stderr,
663			"too much time in read(2), %lu > %i ns\n",
664			read_overhead_ns, READ_OVERHEAD_NSEC);
665		exit(EXIT_FAILURE);
666	}
667
668	control_writeln("WAITDONE");
669	close(fd);
670}
671
672static void test_seqpacket_timeout_server(const struct test_opts *opts)
673{
674	int fd;
675
676	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
677	if (fd < 0) {
678		perror("accept");
679		exit(EXIT_FAILURE);
680	}
681
682	control_expectln("WAITDONE");
683	close(fd);
684}
685
686static void test_seqpacket_bigmsg_client(const struct test_opts *opts)
687{
688	unsigned long sock_buf_size;
689	ssize_t send_size;
690	socklen_t len;
691	void *data;
692	int fd;
693
694	len = sizeof(sock_buf_size);
695
696	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
697	if (fd < 0) {
698		perror("connect");
699		exit(EXIT_FAILURE);
700	}
701
702	if (getsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE,
703		       &sock_buf_size, &len)) {
704		perror("getsockopt");
705		exit(EXIT_FAILURE);
706	}
707
708	sock_buf_size++;
709
710	data = malloc(sock_buf_size);
711	if (!data) {
712		perror("malloc");
713		exit(EXIT_FAILURE);
714	}
715
716	send_size = send(fd, data, sock_buf_size, 0);
717	if (send_size != -1) {
718		fprintf(stderr, "expected 'send(2)' failure, got %zi\n",
719			send_size);
720		exit(EXIT_FAILURE);
721	}
722
723	if (errno != EMSGSIZE) {
724		fprintf(stderr, "expected EMSGSIZE in 'errno', got %i\n",
725			errno);
726		exit(EXIT_FAILURE);
727	}
728
729	control_writeln("CLISENT");
730
731	free(data);
732	close(fd);
733}
734
735static void test_seqpacket_bigmsg_server(const struct test_opts *opts)
736{
737	int fd;
738
739	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
740	if (fd < 0) {
741		perror("accept");
742		exit(EXIT_FAILURE);
743	}
744
745	control_expectln("CLISENT");
746
747	close(fd);
748}
749
750#define BUF_PATTERN_1 'a'
751#define BUF_PATTERN_2 'b'
752
753static void test_seqpacket_invalid_rec_buffer_client(const struct test_opts *opts)
754{
755	int fd;
756	unsigned char *buf1;
757	unsigned char *buf2;
758	int buf_size = getpagesize() * 3;
759
760	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
761	if (fd < 0) {
762		perror("connect");
763		exit(EXIT_FAILURE);
764	}
765
766	buf1 = malloc(buf_size);
767	if (!buf1) {
768		perror("'malloc()' for 'buf1'");
769		exit(EXIT_FAILURE);
770	}
771
772	buf2 = malloc(buf_size);
773	if (!buf2) {
774		perror("'malloc()' for 'buf2'");
775		exit(EXIT_FAILURE);
776	}
777
778	memset(buf1, BUF_PATTERN_1, buf_size);
779	memset(buf2, BUF_PATTERN_2, buf_size);
780
781	if (send(fd, buf1, buf_size, 0) != buf_size) {
782		perror("send failed");
783		exit(EXIT_FAILURE);
784	}
785
786	if (send(fd, buf2, buf_size, 0) != buf_size) {
787		perror("send failed");
788		exit(EXIT_FAILURE);
789	}
790
791	close(fd);
792}
793
794static void test_seqpacket_invalid_rec_buffer_server(const struct test_opts *opts)
795{
796	int fd;
797	unsigned char *broken_buf;
798	unsigned char *valid_buf;
799	int page_size = getpagesize();
800	int buf_size = page_size * 3;
801	ssize_t res;
802	int prot = PROT_READ | PROT_WRITE;
803	int flags = MAP_PRIVATE | MAP_ANONYMOUS;
804	int i;
805
806	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
807	if (fd < 0) {
808		perror("accept");
809		exit(EXIT_FAILURE);
810	}
811
812	/* Setup first buffer. */
813	broken_buf = mmap(NULL, buf_size, prot, flags, -1, 0);
814	if (broken_buf == MAP_FAILED) {
815		perror("mmap for 'broken_buf'");
816		exit(EXIT_FAILURE);
817	}
818
819	/* Unmap "hole" in buffer. */
820	if (munmap(broken_buf + page_size, page_size)) {
821		perror("'broken_buf' setup");
822		exit(EXIT_FAILURE);
823	}
824
825	valid_buf = mmap(NULL, buf_size, prot, flags, -1, 0);
826	if (valid_buf == MAP_FAILED) {
827		perror("mmap for 'valid_buf'");
828		exit(EXIT_FAILURE);
829	}
830
831	/* Try to fill buffer with unmapped middle. */
832	res = read(fd, broken_buf, buf_size);
833	if (res != -1) {
834		fprintf(stderr,
835			"expected 'broken_buf' read(2) failure, got %zi\n",
836			res);
837		exit(EXIT_FAILURE);
838	}
839
840	if (errno != EFAULT) {
841		perror("unexpected errno of 'broken_buf'");
842		exit(EXIT_FAILURE);
843	}
844
845	/* Try to fill valid buffer. */
846	res = read(fd, valid_buf, buf_size);
847	if (res < 0) {
848		perror("unexpected 'valid_buf' read(2) failure");
849		exit(EXIT_FAILURE);
850	}
851
852	if (res != buf_size) {
853		fprintf(stderr,
854			"invalid 'valid_buf' read(2), expected %i, got %zi\n",
855			buf_size, res);
856		exit(EXIT_FAILURE);
857	}
858
859	for (i = 0; i < buf_size; i++) {
860		if (valid_buf[i] != BUF_PATTERN_2) {
861			fprintf(stderr,
862				"invalid pattern for 'valid_buf' at %i, expected %hhX, got %hhX\n",
863				i, BUF_PATTERN_2, valid_buf[i]);
864			exit(EXIT_FAILURE);
865		}
866	}
867
868	/* Unmap buffers. */
869	munmap(broken_buf, page_size);
870	munmap(broken_buf + page_size * 2, page_size);
871	munmap(valid_buf, buf_size);
872	close(fd);
873}
874
875#define RCVLOWAT_BUF_SIZE 128
876
877static void test_stream_poll_rcvlowat_server(const struct test_opts *opts)
878{
879	int fd;
880	int i;
881
882	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
883	if (fd < 0) {
884		perror("accept");
885		exit(EXIT_FAILURE);
886	}
887
888	/* Send 1 byte. */
889	send_byte(fd, 1, 0);
890
891	control_writeln("SRVSENT");
892
893	/* Wait until client is ready to receive rest of data. */
894	control_expectln("CLNSENT");
895
896	for (i = 0; i < RCVLOWAT_BUF_SIZE - 1; i++)
897		send_byte(fd, 1, 0);
898
899	/* Keep socket in active state. */
900	control_expectln("POLLDONE");
901
902	close(fd);
903}
904
905static void test_stream_poll_rcvlowat_client(const struct test_opts *opts)
906{
907	unsigned long lowat_val = RCVLOWAT_BUF_SIZE;
908	char buf[RCVLOWAT_BUF_SIZE];
909	struct pollfd fds;
910	ssize_t read_res;
911	short poll_flags;
912	int fd;
913
914	fd = vsock_stream_connect(opts->peer_cid, 1234);
915	if (fd < 0) {
916		perror("connect");
917		exit(EXIT_FAILURE);
918	}
919
920	if (setsockopt(fd, SOL_SOCKET, SO_RCVLOWAT,
921		       &lowat_val, sizeof(lowat_val))) {
922		perror("setsockopt(SO_RCVLOWAT)");
923		exit(EXIT_FAILURE);
924	}
925
926	control_expectln("SRVSENT");
927
928	/* At this point, server sent 1 byte. */
929	fds.fd = fd;
930	poll_flags = POLLIN | POLLRDNORM;
931	fds.events = poll_flags;
932
933	/* Try to wait for 1 sec. */
934	if (poll(&fds, 1, 1000) < 0) {
935		perror("poll");
936		exit(EXIT_FAILURE);
937	}
938
939	/* poll() must return nothing. */
940	if (fds.revents) {
941		fprintf(stderr, "Unexpected poll result %hx\n",
942			fds.revents);
943		exit(EXIT_FAILURE);
944	}
945
946	/* Tell server to send rest of data. */
947	control_writeln("CLNSENT");
948
949	/* Poll for data. */
950	if (poll(&fds, 1, 10000) < 0) {
951		perror("poll");
952		exit(EXIT_FAILURE);
953	}
954
955	/* Only these two bits are expected. */
956	if (fds.revents != poll_flags) {
957		fprintf(stderr, "Unexpected poll result %hx\n",
958			fds.revents);
959		exit(EXIT_FAILURE);
960	}
961
962	/* Use MSG_DONTWAIT, if call is going to wait, EAGAIN
963	 * will be returned.
964	 */
965	read_res = recv(fd, buf, sizeof(buf), MSG_DONTWAIT);
966	if (read_res != RCVLOWAT_BUF_SIZE) {
967		fprintf(stderr, "Unexpected recv result %zi\n",
968			read_res);
969		exit(EXIT_FAILURE);
970	}
971
972	control_writeln("POLLDONE");
973
974	close(fd);
975}
976
977#define INV_BUF_TEST_DATA_LEN 512
978
979static void test_inv_buf_client(const struct test_opts *opts, bool stream)
980{
981	unsigned char data[INV_BUF_TEST_DATA_LEN] = {0};
982	ssize_t ret;
983	int fd;
984
985	if (stream)
986		fd = vsock_stream_connect(opts->peer_cid, 1234);
987	else
988		fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
989
990	if (fd < 0) {
991		perror("connect");
992		exit(EXIT_FAILURE);
993	}
994
995	control_expectln("SENDDONE");
996
997	/* Use invalid buffer here. */
998	ret = recv(fd, NULL, sizeof(data), 0);
999	if (ret != -1) {
1000		fprintf(stderr, "expected recv(2) failure, got %zi\n", ret);
1001		exit(EXIT_FAILURE);
1002	}
1003
1004	if (errno != EFAULT) {
1005		fprintf(stderr, "unexpected recv(2) errno %d\n", errno);
1006		exit(EXIT_FAILURE);
1007	}
1008
1009	ret = recv(fd, data, sizeof(data), MSG_DONTWAIT);
1010
1011	if (stream) {
1012		/* For SOCK_STREAM we must continue reading. */
1013		if (ret != sizeof(data)) {
1014			fprintf(stderr, "expected recv(2) success, got %zi\n", ret);
1015			exit(EXIT_FAILURE);
1016		}
1017		/* Don't check errno in case of success. */
1018	} else {
1019		/* For SOCK_SEQPACKET socket's queue must be empty. */
1020		if (ret != -1) {
1021			fprintf(stderr, "expected recv(2) failure, got %zi\n", ret);
1022			exit(EXIT_FAILURE);
1023		}
1024
1025		if (errno != EAGAIN) {
1026			fprintf(stderr, "unexpected recv(2) errno %d\n", errno);
1027			exit(EXIT_FAILURE);
1028		}
1029	}
1030
1031	control_writeln("DONE");
1032
1033	close(fd);
1034}
1035
1036static void test_inv_buf_server(const struct test_opts *opts, bool stream)
1037{
1038	unsigned char data[INV_BUF_TEST_DATA_LEN] = {0};
1039	ssize_t res;
1040	int fd;
1041
1042	if (stream)
1043		fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
1044	else
1045		fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
1046
1047	if (fd < 0) {
1048		perror("accept");
1049		exit(EXIT_FAILURE);
1050	}
1051
1052	res = send(fd, data, sizeof(data), 0);
1053	if (res != sizeof(data)) {
1054		fprintf(stderr, "unexpected send(2) result %zi\n", res);
1055		exit(EXIT_FAILURE);
1056	}
1057
1058	control_writeln("SENDDONE");
1059
1060	control_expectln("DONE");
1061
1062	close(fd);
1063}
1064
/* Socket-type front-ends for the shared invalid-buffer test helpers. */
static void test_stream_inv_buf_client(const struct test_opts *opts)
{
	const bool stream = true;

	test_inv_buf_client(opts, stream);
}

static void test_stream_inv_buf_server(const struct test_opts *opts)
{
	const bool stream = true;

	test_inv_buf_server(opts, stream);
}

static void test_seqpacket_inv_buf_client(const struct test_opts *opts)
{
	const bool stream = false;

	test_inv_buf_client(opts, stream);
}

static void test_seqpacket_inv_buf_server(const struct test_opts *opts)
{
	const bool stream = false;

	test_inv_buf_server(opts, stream);
}
1084
1085#define HELLO_STR "HELLO"
1086#define WORLD_STR "WORLD"
1087
1088static void test_stream_virtio_skb_merge_client(const struct test_opts *opts)
1089{
1090	ssize_t res;
1091	int fd;
1092
1093	fd = vsock_stream_connect(opts->peer_cid, 1234);
1094	if (fd < 0) {
1095		perror("connect");
1096		exit(EXIT_FAILURE);
1097	}
1098
1099	/* Send first skbuff. */
1100	res = send(fd, HELLO_STR, strlen(HELLO_STR), 0);
1101	if (res != strlen(HELLO_STR)) {
1102		fprintf(stderr, "unexpected send(2) result %zi\n", res);
1103		exit(EXIT_FAILURE);
1104	}
1105
1106	control_writeln("SEND0");
1107	/* Peer reads part of first skbuff. */
1108	control_expectln("REPLY0");
1109
1110	/* Send second skbuff, it will be appended to the first. */
1111	res = send(fd, WORLD_STR, strlen(WORLD_STR), 0);
1112	if (res != strlen(WORLD_STR)) {
1113		fprintf(stderr, "unexpected send(2) result %zi\n", res);
1114		exit(EXIT_FAILURE);
1115	}
1116
1117	control_writeln("SEND1");
1118	/* Peer reads merged skbuff packet. */
1119	control_expectln("REPLY1");
1120
1121	close(fd);
1122}
1123
1124static void test_stream_virtio_skb_merge_server(const struct test_opts *opts)
1125{
1126	unsigned char buf[64];
1127	ssize_t res;
1128	int fd;
1129
1130	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
1131	if (fd < 0) {
1132		perror("accept");
1133		exit(EXIT_FAILURE);
1134	}
1135
1136	control_expectln("SEND0");
1137
1138	/* Read skbuff partially. */
1139	res = recv(fd, buf, 2, 0);
1140	if (res != 2) {
1141		fprintf(stderr, "expected recv(2) returns 2 bytes, got %zi\n", res);
1142		exit(EXIT_FAILURE);
1143	}
1144
1145	control_writeln("REPLY0");
1146	control_expectln("SEND1");
1147
1148	res = recv(fd, buf + 2, sizeof(buf) - 2, 0);
1149	if (res != 8) {
1150		fprintf(stderr, "expected recv(2) returns 8 bytes, got %zi\n", res);
1151		exit(EXIT_FAILURE);
1152	}
1153
1154	res = recv(fd, buf, sizeof(buf) - 8 - 2, MSG_DONTWAIT);
1155	if (res != -1) {
1156		fprintf(stderr, "expected recv(2) failure, got %zi\n", res);
1157		exit(EXIT_FAILURE);
1158	}
1159
1160	if (memcmp(buf, HELLO_STR WORLD_STR, strlen(HELLO_STR WORLD_STR))) {
1161		fprintf(stderr, "pattern mismatch\n");
1162		exit(EXIT_FAILURE);
1163	}
1164
1165	control_writeln("REPLY1");
1166
1167	close(fd);
1168}
1169
/* SOCK_SEQPACKET front-ends for the shared MSG_PEEK test. The helpers
 * return void, so they are called as statements: `return f();` in a void
 * function is a constraint violation in ISO C.
 */
static void test_seqpacket_msg_peek_client(const struct test_opts *opts)
{
	test_msg_peek_client(opts, true);
}

static void test_seqpacket_msg_peek_server(const struct test_opts *opts)
{
	test_msg_peek_server(opts, true);
}
1179
1180static struct test_case test_cases[] = {
1181	{
1182		.name = "SOCK_STREAM connection reset",
1183		.run_client = test_stream_connection_reset,
1184	},
1185	{
1186		.name = "SOCK_STREAM bind only",
1187		.run_client = test_stream_bind_only_client,
1188		.run_server = test_stream_bind_only_server,
1189	},
1190	{
1191		.name = "SOCK_STREAM client close",
1192		.run_client = test_stream_client_close_client,
1193		.run_server = test_stream_client_close_server,
1194	},
1195	{
1196		.name = "SOCK_STREAM server close",
1197		.run_client = test_stream_server_close_client,
1198		.run_server = test_stream_server_close_server,
1199	},
1200	{
1201		.name = "SOCK_STREAM multiple connections",
1202		.run_client = test_stream_multiconn_client,
1203		.run_server = test_stream_multiconn_server,
1204	},
1205	{
1206		.name = "SOCK_STREAM MSG_PEEK",
1207		.run_client = test_stream_msg_peek_client,
1208		.run_server = test_stream_msg_peek_server,
1209	},
1210	{
1211		.name = "SOCK_SEQPACKET msg bounds",
1212		.run_client = test_seqpacket_msg_bounds_client,
1213		.run_server = test_seqpacket_msg_bounds_server,
1214	},
1215	{
1216		.name = "SOCK_SEQPACKET MSG_TRUNC flag",
1217		.run_client = test_seqpacket_msg_trunc_client,
1218		.run_server = test_seqpacket_msg_trunc_server,
1219	},
1220	{
1221		.name = "SOCK_SEQPACKET timeout",
1222		.run_client = test_seqpacket_timeout_client,
1223		.run_server = test_seqpacket_timeout_server,
1224	},
1225	{
1226		.name = "SOCK_SEQPACKET invalid receive buffer",
1227		.run_client = test_seqpacket_invalid_rec_buffer_client,
1228		.run_server = test_seqpacket_invalid_rec_buffer_server,
1229	},
1230	{
1231		.name = "SOCK_STREAM poll() + SO_RCVLOWAT",
1232		.run_client = test_stream_poll_rcvlowat_client,
1233		.run_server = test_stream_poll_rcvlowat_server,
1234	},
1235	{
1236		.name = "SOCK_SEQPACKET big message",
1237		.run_client = test_seqpacket_bigmsg_client,
1238		.run_server = test_seqpacket_bigmsg_server,
1239	},
1240	{
1241		.name = "SOCK_STREAM test invalid buffer",
1242		.run_client = test_stream_inv_buf_client,
1243		.run_server = test_stream_inv_buf_server,
1244	},
1245	{
1246		.name = "SOCK_SEQPACKET test invalid buffer",
1247		.run_client = test_seqpacket_inv_buf_client,
1248		.run_server = test_seqpacket_inv_buf_server,
1249	},
1250	{
1251		.name = "SOCK_STREAM virtio skb merge",
1252		.run_client = test_stream_virtio_skb_merge_client,
1253		.run_server = test_stream_virtio_skb_merge_server,
1254	},
1255	{
1256		.name = "SOCK_SEQPACKET MSG_PEEK",
1257		.run_client = test_seqpacket_msg_peek_client,
1258		.run_server = test_seqpacket_msg_peek_server,
1259	},
1260	{},
1261};
1262
/* Only long options are accepted; there are no short-option letters. */
static const char optstring[] = "";

/* getopt_long(3) table. Each .val is the code main() switches on; the
 * all-zero entry terminates the table.
 */
static const struct option longopts[] = {
	{ "control-host", required_argument, NULL, 'H' },
	{ "control-port", required_argument, NULL, 'P' },
	{ "mode",         required_argument, NULL, 'm' },
	{ "peer-cid",     required_argument, NULL, 'p' },
	{ "list",         no_argument,       NULL, 'l' },
	{ "skip",         required_argument, NULL, 's' },
	{ "help",         no_argument,       NULL, '?' },
	{ NULL,           0,                 NULL, 0   },
};
1302
/* Print the command-line help to stderr and terminate with EXIT_FAILURE.
 * The text contains no conversion specifiers, so it is written verbatim
 * with fputs().
 */
static void usage(void)
{
	static const char help_text[] =
		"Usage: vsock_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--list] [--skip=<test_id>]\n"
		"\n"
		"  Server: vsock_test --control-port=1234 --mode=server --peer-cid=3\n"
		"  Client: vsock_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
		"\n"
		"Run vsock.ko tests.  Must be launched in both guest\n"
		"and host.  One side must use --mode=client and\n"
		"the other side must use --mode=server.\n"
		"\n"
		"A TCP control socket connection is used to coordinate tests\n"
		"between the client and the server.  The server requires a\n"
		"listen address and the client requires an address to\n"
		"connect to.\n"
		"\n"
		"The CID of the other side must be given with --peer-cid=<cid>.\n"
		"\n"
		"Options:\n"
		"  --help                 This help message\n"
		"  --control-host <host>  Server IP address to connect to\n"
		"  --control-port <port>  Server port to listen on/connect to\n"
		"  --mode client|server   Server or client mode\n"
		"  --peer-cid <cid>       CID of the other side\n"
		"  --list                 List of tests that will be executed\n"
		"  --skip <test_id>       Test ID to skip;\n"
		"                         use multiple --skip options to skip more tests\n";

	fputs(help_text, stderr);
	exit(EXIT_FAILURE);
}
1333
1334int main(int argc, char **argv)
1335{
1336	const char *control_host = NULL;
1337	const char *control_port = NULL;
1338	struct test_opts opts = {
1339		.mode = TEST_MODE_UNSET,
1340		.peer_cid = VMADDR_CID_ANY,
1341	};
1342
1343	srand(time(NULL));
1344	init_signals();
1345
1346	for (;;) {
1347		int opt = getopt_long(argc, argv, optstring, longopts, NULL);
1348
1349		if (opt == -1)
1350			break;
1351
1352		switch (opt) {
1353		case 'H':
1354			control_host = optarg;
1355			break;
1356		case 'm':
1357			if (strcmp(optarg, "client") == 0)
1358				opts.mode = TEST_MODE_CLIENT;
1359			else if (strcmp(optarg, "server") == 0)
1360				opts.mode = TEST_MODE_SERVER;
1361			else {
1362				fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
1363				return EXIT_FAILURE;
1364			}
1365			break;
1366		case 'p':
1367			opts.peer_cid = parse_cid(optarg);
1368			break;
1369		case 'P':
1370			control_port = optarg;
1371			break;
1372		case 'l':
1373			list_tests(test_cases);
1374			break;
1375		case 's':
1376			skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
1377				  optarg);
1378			break;
1379		case '?':
1380		default:
1381			usage();
1382		}
1383	}
1384
1385	if (!control_port)
1386		usage();
1387	if (opts.mode == TEST_MODE_UNSET)
1388		usage();
1389	if (opts.peer_cid == VMADDR_CID_ANY)
1390		usage();
1391
1392	if (!control_host) {
1393		if (opts.mode != TEST_MODE_SERVER)
1394			usage();
1395		control_host = "0.0.0.0";
1396	}
1397
1398	control_init(control_host, control_port,
1399		     opts.mode == TEST_MODE_SERVER);
1400
1401	run_tests(test_cases, &opts);
1402
1403	control_cleanup();
1404	return EXIT_SUCCESS;
1405}
1406