1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * Copyright (C) 2021 ARM Limited.
4 * Original author: Mark Brown <broonie@kernel.org>
5 */
6#include <assert.h>
7#include <errno.h>
8#include <fcntl.h>
9#include <stdbool.h>
10#include <stddef.h>
11#include <stdio.h>
12#include <stdlib.h>
13#include <string.h>
14#include <unistd.h>
15#include <sys/auxv.h>
16#include <sys/prctl.h>
17#include <sys/types.h>
18#include <sys/wait.h>
19#include <asm/sigcontext.h>
20#include <asm/hwcap.h>
21
22#include "../../kselftest.h"
23#include "rdvl.h"
24
/*
 * Vector length written to the default VL file by proc_write_min();
 * an alias for SVE_VL_MIN, used for every entry in vec_data[].
 */
#define ARCH_MIN_VL SVE_VL_MIN
26
/*
 * Description of one vector extension under test: how to detect it,
 * how to read and set its vector length (VL), and the VLs the tests
 * have discovered so far.
 */
struct vec_data {
	/* Name used as a prefix in test output */
	const char *name;
	/* getauxval() key and the bit within it that signals support */
	unsigned long hwcap_type;
	unsigned long hwcap;
	/* Program exec()ed by get_child_rdvl(); it prints its VL on stdout */
	const char *rdvl_binary;
	/* Returns the VL currently seen by this process */
	int (*rdvl)(void);

	/* prctl() options for reading and setting the VL */
	int prctl_get;
	int prctl_set;
	/* procfs file holding the default VL for new processes */
	const char *default_vl_file;

	/*
	 * Recorded by proc_read_default(), proc_write_min() and
	 * proc_write_max() respectively when those tests pass.
	 */
	int default_vl;
	int min_vl;
	int max_vl;
};
42
/* Indices into vec_data[] */
#define VEC_SVE 0
#define VEC_SME 1

/*
 * One entry per vector extension under test.  default_vl, min_vl and
 * max_vl have no initialiser here, so they start out as zero until
 * the proc_* tests fill them in.
 */
static struct vec_data vec_data[] = {
	[VEC_SVE] = {
		.name = "SVE",
		.hwcap_type = AT_HWCAP,
		.hwcap = HWCAP_SVE,
		.rdvl = rdvl_sve,
		.rdvl_binary = "./rdvl-sve",
		.prctl_get = PR_SVE_GET_VL,
		.prctl_set = PR_SVE_SET_VL,
		.default_vl_file = "/proc/sys/abi/sve_default_vector_length",
	},
	[VEC_SME] = {
		.name = "SME",
		.hwcap_type = AT_HWCAP2,
		.hwcap = HWCAP2_SME,
		.rdvl = rdvl_sme,
		.rdvl_binary = "./rdvl-sme",
		.prctl_get = PR_SME_GET_VL,
		.prctl_set = PR_SME_SET_VL,
		.default_vl_file = "/proc/sys/abi/sme_default_vector_length",
	},
};
68
/*
 * Parse a decimal integer followed by a newline from f into *val.
 * what names the input source in the diagnostic.
 *
 * Returns 0 on success, -1 if the input did not have that form.
 */
static int stdio_read_integer(FILE *f, const char *what, int *val)
{
	int consumed = 0;
	int converted = fscanf(f, "%d%*1[\n]%n", val, &consumed);

	/*
	 * %n is only reached once both the number and the newline have
	 * matched, so consumed stays at zero for a bare number.
	 */
	if (converted >= 1 && consumed >= 1)
		return 0;

	ksft_print_msg("failed to parse integer from %s\n", what);
	return -1;
}
82
83/* Start a new process and return the vector length it sees */
84static int get_child_rdvl(struct vec_data *data)
85{
86	FILE *out;
87	int pipefd[2];
88	pid_t pid, child;
89	int read_vl, ret;
90
91	ret = pipe(pipefd);
92	if (ret == -1) {
93		ksft_print_msg("pipe() failed: %d (%s)\n",
94			       errno, strerror(errno));
95		return -1;
96	}
97
98	fflush(stdout);
99
100	child = fork();
101	if (child == -1) {
102		ksft_print_msg("fork() failed: %d (%s)\n",
103			       errno, strerror(errno));
104		close(pipefd[0]);
105		close(pipefd[1]);
106		return -1;
107	}
108
109	/* Child: put vector length on the pipe */
110	if (child == 0) {
111		/*
112		 * Replace stdout with the pipe, errors to stderr from
113		 * here as kselftest prints to stdout.
114		 */
115		ret = dup2(pipefd[1], 1);
116		if (ret == -1) {
117			fprintf(stderr, "dup2() %d\n", errno);
118			exit(EXIT_FAILURE);
119		}
120
121		/* exec() a new binary which puts the VL on stdout */
122		ret = execl(data->rdvl_binary, data->rdvl_binary, NULL);
123		fprintf(stderr, "execl(%s) failed: %d (%s)\n",
124			data->rdvl_binary, errno, strerror(errno));
125
126		exit(EXIT_FAILURE);
127	}
128
129	close(pipefd[1]);
130
131	/* Parent; wait for the exit status from the child & verify it */
132	do {
133		pid = wait(&ret);
134		if (pid == -1) {
135			ksft_print_msg("wait() failed: %d (%s)\n",
136				       errno, strerror(errno));
137			close(pipefd[0]);
138			return -1;
139		}
140	} while (pid != child);
141
142	assert(pid == child);
143
144	if (!WIFEXITED(ret)) {
145		ksft_print_msg("child exited abnormally\n");
146		close(pipefd[0]);
147		return -1;
148	}
149
150	if (WEXITSTATUS(ret) != 0) {
151		ksft_print_msg("child returned error %d\n",
152			       WEXITSTATUS(ret));
153		close(pipefd[0]);
154		return -1;
155	}
156
157	out = fdopen(pipefd[0], "r");
158	if (!out) {
159		ksft_print_msg("failed to open child stdout\n");
160		close(pipefd[0]);
161		return -1;
162	}
163
164	ret = stdio_read_integer(out, "child", &read_vl);
165	fclose(out);
166	if (ret != 0)
167		return ret;
168
169	return read_vl;
170}
171
/*
 * Read a single integer from the file called name into *val.
 *
 * Returns 0 on success and non-zero on failure.  Every failure is
 * reported as a failed test result: callers abandon their test without
 * reporting anything themselves when this returns an error, so a silent
 * failure here would leave the run short of its planned result count.
 */
static int file_read_integer(const char *name, int *val)
{
	FILE *f;
	int ret;

	f = fopen(name, "r");
	if (!f) {
		ksft_test_result_fail("Unable to open %s: %d (%s)\n",
				      name, errno,
				      strerror(errno));
		return -1;
	}

	ret = stdio_read_integer(f, name, val);
	fclose(f);

	/* stdio_read_integer() only logs, it does not report a result */
	if (ret != 0)
		ksft_test_result_fail("Unable to read integer from %s\n",
				      name);

	return ret;
}
190
/*
 * Write val as a decimal integer to the file called name.
 *
 * Returns 0 on success and -1 on failure.  As with a failure to open
 * the file, a failure to write is reported as a failed test result.
 */
static int file_write_integer(const char *name, int val)
{
	FILE *f;
	int err = 0;

	f = fopen(name, "w");
	if (!f) {
		ksft_test_result_fail("Unable to open %s: %d (%s)\n",
				      name, errno,
				      strerror(errno));
		return -1;
	}

	if (fprintf(f, "%d", val) < 0)
		err = errno;

	/*
	 * stdio buffers the data, so a write that gets rejected may only
	 * show up when the stream is flushed by fclose().  Keep the
	 * first error seen.
	 */
	if (fclose(f) != 0 && !err)
		err = errno;

	if (err) {
		ksft_test_result_fail("Unable to write %d to %s: %d (%s)\n",
				      val, name, err, strerror(err));
		return -1;
	}

	return 0;
}
208
209/*
210 * Verify that we can read the default VL via proc, checking that it
211 * is set in a freshly spawned child.
212 */
213static void proc_read_default(struct vec_data *data)
214{
215	int default_vl, child_vl, ret;
216
217	ret = file_read_integer(data->default_vl_file, &default_vl);
218	if (ret != 0)
219		return;
220
221	/* Is this the actual default seen by new processes? */
222	child_vl = get_child_rdvl(data);
223	if (child_vl != default_vl) {
224		ksft_test_result_fail("%s is %d but child VL is %d\n",
225				      data->default_vl_file,
226				      default_vl, child_vl);
227		return;
228	}
229
230	ksft_test_result_pass("%s default vector length %d\n", data->name,
231			      default_vl);
232	data->default_vl = default_vl;
233}
234
235/* Verify that we can write a minimum value and have it take effect */
236static void proc_write_min(struct vec_data *data)
237{
238	int ret, new_default, child_vl;
239
240	if (geteuid() != 0) {
241		ksft_test_result_skip("Need to be root to write to /proc\n");
242		return;
243	}
244
245	ret = file_write_integer(data->default_vl_file, ARCH_MIN_VL);
246	if (ret != 0)
247		return;
248
249	/* What was the new value? */
250	ret = file_read_integer(data->default_vl_file, &new_default);
251	if (ret != 0)
252		return;
253
254	/* Did it take effect in a new process? */
255	child_vl = get_child_rdvl(data);
256	if (child_vl != new_default) {
257		ksft_test_result_fail("%s is %d but child VL is %d\n",
258				      data->default_vl_file,
259				      new_default, child_vl);
260		return;
261	}
262
263	ksft_test_result_pass("%s minimum vector length %d\n", data->name,
264			      new_default);
265	data->min_vl = new_default;
266
267	file_write_integer(data->default_vl_file, data->default_vl);
268}
269
270/* Verify that we can write a maximum value and have it take effect */
271static void proc_write_max(struct vec_data *data)
272{
273	int ret, new_default, child_vl;
274
275	if (geteuid() != 0) {
276		ksft_test_result_skip("Need to be root to write to /proc\n");
277		return;
278	}
279
280	/* -1 is accepted by the /proc interface as the maximum VL */
281	ret = file_write_integer(data->default_vl_file, -1);
282	if (ret != 0)
283		return;
284
285	/* What was the new value? */
286	ret = file_read_integer(data->default_vl_file, &new_default);
287	if (ret != 0)
288		return;
289
290	/* Did it take effect in a new process? */
291	child_vl = get_child_rdvl(data);
292	if (child_vl != new_default) {
293		ksft_test_result_fail("%s is %d but child VL is %d\n",
294				      data->default_vl_file,
295				      new_default, child_vl);
296		return;
297	}
298
299	ksft_test_result_pass("%s maximum vector length %d\n", data->name,
300			      new_default);
301	data->max_vl = new_default;
302
303	file_write_integer(data->default_vl_file, data->default_vl);
304}
305
306/* Can we read back a VL from prctl? */
307static void prctl_get(struct vec_data *data)
308{
309	int ret;
310
311	ret = prctl(data->prctl_get);
312	if (ret == -1) {
313		ksft_test_result_fail("%s prctl() read failed: %d (%s)\n",
314				      data->name, errno, strerror(errno));
315		return;
316	}
317
318	/* Mask out any flags */
319	ret &= PR_SVE_VL_LEN_MASK;
320
321	/* Is that what we can read back directly? */
322	if (ret == data->rdvl())
323		ksft_test_result_pass("%s current VL is %d\n",
324				      data->name, ret);
325	else
326		ksft_test_result_fail("%s prctl() VL %d but RDVL is %d\n",
327				      data->name, ret, data->rdvl());
328}
329
330/* Does the prctl let us set the VL we already have? */
331static void prctl_set_same(struct vec_data *data)
332{
333	int cur_vl = data->rdvl();
334	int ret;
335
336	ret = prctl(data->prctl_set, cur_vl);
337	if (ret < 0) {
338		ksft_test_result_fail("%s prctl set failed: %d (%s)\n",
339				      data->name, errno, strerror(errno));
340		return;
341	}
342
343	ksft_test_result(cur_vl == data->rdvl(),
344			 "%s set VL %d and have VL %d\n",
345			 data->name, cur_vl, data->rdvl());
346}
347
348/* Can we set a new VL for this process? */
349static void prctl_set(struct vec_data *data)
350{
351	int ret;
352
353	if (data->min_vl == data->max_vl) {
354		ksft_test_result_skip("%s only one VL supported\n",
355				      data->name);
356		return;
357	}
358
359	/* Try to set the minimum VL */
360	ret = prctl(data->prctl_set, data->min_vl);
361	if (ret < 0) {
362		ksft_test_result_fail("%s prctl set failed for %d: %d (%s)\n",
363				      data->name, data->min_vl,
364				      errno, strerror(errno));
365		return;
366	}
367
368	if ((ret & PR_SVE_VL_LEN_MASK) != data->min_vl) {
369		ksft_test_result_fail("%s prctl set %d but return value is %d\n",
370				      data->name, data->min_vl, data->rdvl());
371		return;
372	}
373
374	if (data->rdvl() != data->min_vl) {
375		ksft_test_result_fail("%s set %d but RDVL is %d\n",
376				      data->name, data->min_vl, data->rdvl());
377		return;
378	}
379
380	/* Try to set the maximum VL */
381	ret = prctl(data->prctl_set, data->max_vl);
382	if (ret < 0) {
383		ksft_test_result_fail("%s prctl set failed for %d: %d (%s)\n",
384				      data->name, data->max_vl,
385				      errno, strerror(errno));
386		return;
387	}
388
389	if ((ret & PR_SVE_VL_LEN_MASK) != data->max_vl) {
390		ksft_test_result_fail("%s prctl() set %d but return value is %d\n",
391				      data->name, data->max_vl, data->rdvl());
392		return;
393	}
394
395	/* The _INHERIT flag should not be present when we read the VL */
396	ret = prctl(data->prctl_get);
397	if (ret == -1) {
398		ksft_test_result_fail("%s prctl() read failed: %d (%s)\n",
399				      data->name, errno, strerror(errno));
400		return;
401	}
402
403	if (ret & PR_SVE_VL_INHERIT) {
404		ksft_test_result_fail("%s prctl() reports _INHERIT\n",
405				      data->name);
406		return;
407	}
408
409	ksft_test_result_pass("%s prctl() set min/max\n", data->name);
410}
411
412/* If we didn't request it a new VL shouldn't affect the child */
413static void prctl_set_no_child(struct vec_data *data)
414{
415	int ret, child_vl;
416
417	if (data->min_vl == data->max_vl) {
418		ksft_test_result_skip("%s only one VL supported\n",
419				      data->name);
420		return;
421	}
422
423	ret = prctl(data->prctl_set, data->min_vl);
424	if (ret < 0) {
425		ksft_test_result_fail("%s prctl set failed for %d: %d (%s)\n",
426				      data->name, data->min_vl,
427				      errno, strerror(errno));
428		return;
429	}
430
431	/* Ensure the default VL is different */
432	ret = file_write_integer(data->default_vl_file, data->max_vl);
433	if (ret != 0)
434		return;
435
436	/* Check that the child has the default we just set */
437	child_vl = get_child_rdvl(data);
438	if (child_vl != data->max_vl) {
439		ksft_test_result_fail("%s is %d but child VL is %d\n",
440				      data->default_vl_file,
441				      data->max_vl, child_vl);
442		return;
443	}
444
445	ksft_test_result_pass("%s vector length used default\n", data->name);
446
447	file_write_integer(data->default_vl_file, data->default_vl);
448}
449
450/* If we didn't request it a new VL shouldn't affect the child */
451static void prctl_set_for_child(struct vec_data *data)
452{
453	int ret, child_vl;
454
455	if (data->min_vl == data->max_vl) {
456		ksft_test_result_skip("%s only one VL supported\n",
457				      data->name);
458		return;
459	}
460
461	ret = prctl(data->prctl_set, data->min_vl | PR_SVE_VL_INHERIT);
462	if (ret < 0) {
463		ksft_test_result_fail("%s prctl set failed for %d: %d (%s)\n",
464				      data->name, data->min_vl,
465				      errno, strerror(errno));
466		return;
467	}
468
469	/* The _INHERIT flag should be present when we read the VL */
470	ret = prctl(data->prctl_get);
471	if (ret == -1) {
472		ksft_test_result_fail("%s prctl() read failed: %d (%s)\n",
473				      data->name, errno, strerror(errno));
474		return;
475	}
476	if (!(ret & PR_SVE_VL_INHERIT)) {
477		ksft_test_result_fail("%s prctl() does not report _INHERIT\n",
478				      data->name);
479		return;
480	}
481
482	/* Ensure the default VL is different */
483	ret = file_write_integer(data->default_vl_file, data->max_vl);
484	if (ret != 0)
485		return;
486
487	/* Check that the child inherited our VL */
488	child_vl = get_child_rdvl(data);
489	if (child_vl != data->min_vl) {
490		ksft_test_result_fail("%s is %d but child VL is %d\n",
491				      data->default_vl_file,
492				      data->min_vl, child_vl);
493		return;
494	}
495
496	ksft_test_result_pass("%s vector length was inherited\n", data->name);
497
498	file_write_integer(data->default_vl_file, data->default_vl);
499}
500
501/* _ONEXEC takes effect only in the child process */
502static void prctl_set_onexec(struct vec_data *data)
503{
504	int ret, child_vl;
505
506	if (data->min_vl == data->max_vl) {
507		ksft_test_result_skip("%s only one VL supported\n",
508				      data->name);
509		return;
510	}
511
512	/* Set a known value for the default and our current VL */
513	ret = file_write_integer(data->default_vl_file, data->max_vl);
514	if (ret != 0)
515		return;
516
517	ret = prctl(data->prctl_set, data->max_vl);
518	if (ret < 0) {
519		ksft_test_result_fail("%s prctl set failed for %d: %d (%s)\n",
520				      data->name, data->min_vl,
521				      errno, strerror(errno));
522		return;
523	}
524
525	/* Set a different value for the child to have on exec */
526	ret = prctl(data->prctl_set, data->min_vl | PR_SVE_SET_VL_ONEXEC);
527	if (ret < 0) {
528		ksft_test_result_fail("%s prctl set failed for %d: %d (%s)\n",
529				      data->name, data->min_vl,
530				      errno, strerror(errno));
531		return;
532	}
533
534	/* Our current VL should stay the same */
535	if (data->rdvl() != data->max_vl) {
536		ksft_test_result_fail("%s VL changed by _ONEXEC prctl()\n",
537				      data->name);
538		return;
539	}
540
541	/* Check that the child inherited our VL */
542	child_vl = get_child_rdvl(data);
543	if (child_vl != data->min_vl) {
544		ksft_test_result_fail("Set %d _ONEXEC but child VL is %d\n",
545				      data->min_vl, child_vl);
546		return;
547	}
548
549	ksft_test_result_pass("%s vector length set on exec\n", data->name);
550
551	file_write_integer(data->default_vl_file, data->default_vl);
552}
553
554/* For each VQ verify that setting via prctl() does the right thing */
555static void prctl_set_all_vqs(struct vec_data *data)
556{
557	int ret, vq, vl, new_vl, i;
558	int orig_vls[ARRAY_SIZE(vec_data)];
559	int errors = 0;
560
561	if (!data->min_vl || !data->max_vl) {
562		ksft_test_result_skip("%s Failed to enumerate VLs, not testing VL setting\n",
563				      data->name);
564		return;
565	}
566
567	for (i = 0; i < ARRAY_SIZE(vec_data); i++)
568		orig_vls[i] = vec_data[i].rdvl();
569
570	for (vq = SVE_VQ_MIN; vq <= SVE_VQ_MAX; vq++) {
571		vl = sve_vl_from_vq(vq);
572
573		/* Attempt to set the VL */
574		ret = prctl(data->prctl_set, vl);
575		if (ret < 0) {
576			errors++;
577			ksft_print_msg("%s prctl set failed for %d: %d (%s)\n",
578				       data->name, vl,
579				       errno, strerror(errno));
580			continue;
581		}
582
583		new_vl = ret & PR_SVE_VL_LEN_MASK;
584
585		/* Check that we actually have the reported new VL */
586		if (data->rdvl() != new_vl) {
587			ksft_print_msg("Set %s VL %d but RDVL reports %d\n",
588				       data->name, new_vl, data->rdvl());
589			errors++;
590		}
591
592		/* Did any other VLs change? */
593		for (i = 0; i < ARRAY_SIZE(vec_data); i++) {
594			if (&vec_data[i] == data)
595				continue;
596
597			if (!(getauxval(vec_data[i].hwcap_type) & vec_data[i].hwcap))
598				continue;
599
600			if (vec_data[i].rdvl() != orig_vls[i]) {
601				ksft_print_msg("%s VL changed from %d to %d\n",
602					       vec_data[i].name, orig_vls[i],
603					       vec_data[i].rdvl());
604				errors++;
605			}
606		}
607
608		/* Was that the VL we asked for? */
609		if (new_vl == vl)
610			continue;
611
612		/* Should round up to the minimum VL if below it */
613		if (vl < data->min_vl) {
614			if (new_vl != data->min_vl) {
615				ksft_print_msg("%s VL %d returned %d not minimum %d\n",
616					       data->name, vl, new_vl,
617					       data->min_vl);
618				errors++;
619			}
620
621			continue;
622		}
623
624		/* Should round down to maximum VL if above it */
625		if (vl > data->max_vl) {
626			if (new_vl != data->max_vl) {
627				ksft_print_msg("%s VL %d returned %d not maximum %d\n",
628					       data->name, vl, new_vl,
629					       data->max_vl);
630				errors++;
631			}
632
633			continue;
634		}
635
636		/* Otherwise we should've rounded down */
637		if (!(new_vl < vl)) {
638			ksft_print_msg("%s VL %d returned %d, did not round down\n",
639				       data->name, vl, new_vl);
640			errors++;
641
642			continue;
643		}
644	}
645
646	ksft_test_result(errors == 0, "%s prctl() set all VLs, %d errors\n",
647			 data->name, errors);
648}
649
/* Signature shared by every test that is run once per vector type */
typedef void (*test_type)(struct vec_data *);

/*
 * Tests that main() runs, in this order, for each entry in vec_data[].
 * Each is expected to report one test result per call: the plan in
 * main() counts one result per test per vector type.
 */
static const test_type tests[] = {
	/*
	 * The default/min/max tests must be first and in this order
	 * to provide data for other tests.
	 */
	proc_read_default,
	proc_write_min,
	proc_write_max,

	prctl_get,
	prctl_set_same,
	prctl_set,
	prctl_set_no_child,
	prctl_set_for_child,
	prctl_set_onexec,
	prctl_set_all_vqs,
};
669
/*
 * Write the zero register to the system register encoding
 * S0_3_C4_C7_3.  change_sve_with_za() uses this to "Enable SM and ZA";
 * presumably this is SMSTART spelled as a raw encoding so that the
 * assembler does not need to know the mnemonic - TODO confirm against
 * the architecture reference.
 */
static inline void smstart(void)
{
	asm volatile("msr S0_3_C4_C7_3, xzr");
}
674
/*
 * Write the zero register to the system register encoding
 * S0_3_C4_C3_3.  Going by the name this is presumably SMSTART SM,
 * enabling streaming mode only - TODO confirm against the architecture
 * reference.  change_sve_with_za() issues it repeatedly while spinning
 * with SM enabled.
 */
static inline void smstart_sm(void)
{
	asm volatile("msr S0_3_C4_C3_3, xzr");
}
679
/*
 * Write the zero register to the system register encoding
 * S0_3_C4_C6_3.  Going by the name this is presumably SMSTOP - TODO
 * confirm against the architecture reference.  change_sve_with_za()
 * calls it to clean up before making further syscalls.
 */
static inline void smstop(void)
{
	asm volatile("msr S0_3_C4_C6_3, xzr");
}
684
685
686/*
687 * Verify we can change the SVE vector length while SME is active and
688 * continue to use SME afterwards.
689 */
690static void change_sve_with_za(void)
691{
692	struct vec_data *sve_data = &vec_data[VEC_SVE];
693	bool pass = true;
694	int ret, i;
695
696	if (sve_data->min_vl == sve_data->max_vl) {
697		ksft_print_msg("Only one SVE VL supported, can't change\n");
698		ksft_test_result_skip("change_sve_while_sme\n");
699		return;
700	}
701
702	/* Ensure we will trigger a change when we set the maximum */
703	ret = prctl(sve_data->prctl_set, sve_data->min_vl);
704	if (ret != sve_data->min_vl) {
705		ksft_print_msg("Failed to set SVE VL %d: %d\n",
706			       sve_data->min_vl, ret);
707		pass = false;
708	}
709
710	/* Enable SM and ZA */
711	smstart();
712
713	/* Trigger another VL change */
714	ret = prctl(sve_data->prctl_set, sve_data->max_vl);
715	if (ret != sve_data->max_vl) {
716		ksft_print_msg("Failed to set SVE VL %d: %d\n",
717			       sve_data->max_vl, ret);
718		pass = false;
719	}
720
721	/*
722	 * Spin for a bit with SM enabled to try to trigger another
723	 * save/restore.  We can't use syscalls without exiting
724	 * streaming mode.
725	 */
726	for (i = 0; i < 100000000; i++)
727		smstart_sm();
728
729	/*
730	 * TODO: Verify that ZA was preserved over the VL change and
731	 * spin.
732	 */
733
734	/* Clean up after ourselves */
735	smstop();
736	ret = prctl(sve_data->prctl_set, sve_data->default_vl);
737	if (ret != sve_data->default_vl) {
738	        ksft_print_msg("Failed to restore SVE VL %d: %d\n",
739			       sve_data->default_vl, ret);
740		pass = false;
741	}
742
743	ksft_test_result(pass, "change_sve_with_za\n");
744}
745
/* Signature of a test that is run once rather than once per vector type */
typedef void (*test_all_type)(void);

/*
 * Tests that main() only runs when every entry in vec_data[] is
 * supported.  name is what main() reports when it skips the test
 * instead of running it.
 */
static const struct {
	const char *name;
	test_all_type test;
}  all_types_tests[] = {
	{ "change_sve_with_za", change_sve_with_za },
};
754
755int main(void)
756{
757	bool all_supported = true;
758	int i, j;
759
760	ksft_print_header();
761	ksft_set_plan(ARRAY_SIZE(tests) * ARRAY_SIZE(vec_data) +
762		      ARRAY_SIZE(all_types_tests));
763
764	for (i = 0; i < ARRAY_SIZE(vec_data); i++) {
765		struct vec_data *data = &vec_data[i];
766		unsigned long supported;
767
768		supported = getauxval(data->hwcap_type) & data->hwcap;
769		if (!supported)
770			all_supported = false;
771
772		for (j = 0; j < ARRAY_SIZE(tests); j++) {
773			if (supported)
774				tests[j](data);
775			else
776				ksft_test_result_skip("%s not supported\n",
777						      data->name);
778		}
779	}
780
781	for (i = 0; i < ARRAY_SIZE(all_types_tests); i++) {
782		if (all_supported)
783			all_types_tests[i].test();
784		else
785			ksft_test_result_skip("%s\n", all_types_tests[i].name);
786	}
787
788	ksft_exit_pass();
789}
790