1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * Copyright (C) 2021 ARM Limited.
4 */
5#include <errno.h>
6#include <stdbool.h>
7#include <stddef.h>
8#include <stdio.h>
9#include <stdlib.h>
10#include <string.h>
11#include <unistd.h>
12#include <sys/auxv.h>
13#include <sys/prctl.h>
14#include <sys/ptrace.h>
15#include <sys/types.h>
16#include <sys/uio.h>
17#include <sys/wait.h>
18#include <asm/sigcontext.h>
19#include <asm/ptrace.h>
20
21#include "../../kselftest.h"
22
/*
 * <linux/elf.h> and <sys/auxv.h> cannot be included together, so
 * provide the regset note types here when no header has defined them.
 */
#if !defined(NT_ARM_ZA)
#define NT_ARM_ZA 0x40c
#endif
#if !defined(NT_ARM_ZT)
#define NT_ARM_ZT 0x40d
#endif

/* Number of test results announced through ksft_set_plan() */
#define EXPECTED_TESTS 3

/* SME vector length; assigned in main() from prctl(PR_SME_GET_VL) */
static int sme_vl;
34
/*
 * Fill the first @size bytes of @buf with pseudo-random data.
 *
 * Each byte is the low byte of a random() result, so the sequence is
 * determined by whatever seed was last given to srandom().  A @size of
 * zero leaves @buf untouched.
 */
static void fill_buf(char *buf, size_t size)
{
	/* Index with size_t so it can never be compared as signed */
	size_t i;

	for (i = 0; i < size; i++)
		buf[i] = (char)random();
}
42
/*
 * Child side of the test.
 *
 * Asks to be traced by the parent and then stops itself with SIGSTOP so
 * that the parent can take over.  Any failure aborts the test through
 * ksft_exit_fail_msg() with the errno description.
 *
 * Returns EXIT_SUCCESS if the child is ever resumed and runs to here.
 */
static int do_child(void)
{
	if (ptrace(PTRACE_TRACEME, -1, NULL, NULL))
		ksft_exit_fail_msg("PTRACE_TRACEME: %s\n", strerror(errno));

	if (raise(SIGSTOP))
		ksft_exit_fail_msg("raise(SIGSTOP): %s\n", strerror(errno));

	return EXIT_SUCCESS;
}
53
54static struct user_za_header *get_za(pid_t pid, void **buf, size_t *size)
55{
56	struct user_za_header *za;
57	void *p;
58	size_t sz = sizeof(*za);
59	struct iovec iov;
60
61	while (1) {
62		if (*size < sz) {
63			p = realloc(*buf, sz);
64			if (!p) {
65				errno = ENOMEM;
66				goto error;
67			}
68
69			*buf = p;
70			*size = sz;
71		}
72
73		iov.iov_base = *buf;
74		iov.iov_len = sz;
75		if (ptrace(PTRACE_GETREGSET, pid, NT_ARM_ZA, &iov))
76			goto error;
77
78		za = *buf;
79		if (za->size <= sz)
80			break;
81
82		sz = za->size;
83	}
84
85	return za;
86
87error:
88	return NULL;
89}
90
91static int set_za(pid_t pid, const struct user_za_header *za)
92{
93	struct iovec iov;
94
95	iov.iov_base = (void *)za;
96	iov.iov_len = za->size;
97	return ptrace(PTRACE_SETREGSET, pid, NT_ARM_ZA, &iov);
98}
99
100static int get_zt(pid_t pid, char zt[ZT_SIG_REG_BYTES])
101{
102	struct iovec iov;
103
104	iov.iov_base = zt;
105	iov.iov_len = ZT_SIG_REG_BYTES;
106	return ptrace(PTRACE_GETREGSET, pid, NT_ARM_ZT, &iov);
107}
108
109
110static int set_zt(pid_t pid, const char zt[ZT_SIG_REG_BYTES])
111{
112	struct iovec iov;
113
114	iov.iov_base = (void *)zt;
115	iov.iov_len = ZT_SIG_REG_BYTES;
116	return ptrace(PTRACE_SETREGSET, pid, NT_ARM_ZT, &iov);
117}
118
119/* Reading with ZA disabled returns all zeros */
120static void ptrace_za_disabled_read_zt(pid_t child)
121{
122	struct user_za_header za;
123	char zt[ZT_SIG_REG_BYTES];
124	int ret, i;
125	bool fail = false;
126
127	/* Disable PSTATE.ZA using the ZA interface */
128	memset(&za, 0, sizeof(za));
129	za.vl = sme_vl;
130	za.size = sizeof(za);
131
132	ret = set_za(child, &za);
133	if (ret != 0) {
134		ksft_print_msg("Failed to disable ZA\n");
135		fail = true;
136	}
137
138	/* Read back ZT */
139	ret = get_zt(child, zt);
140	if (ret != 0) {
141		ksft_print_msg("Failed to read ZT\n");
142		fail = true;
143	}
144
145	for (i = 0; i < ARRAY_SIZE(zt); i++) {
146		if (zt[i]) {
147			ksft_print_msg("zt[%d]: 0x%x != 0\n", i, zt[i]);
148			fail = true;
149		}
150	}
151
152	ksft_test_result(!fail, "ptrace_za_disabled_read_zt\n");
153}
154
155/* Writing then reading ZT should return the data written */
156static void ptrace_set_get_zt(pid_t child)
157{
158	char zt_in[ZT_SIG_REG_BYTES];
159	char zt_out[ZT_SIG_REG_BYTES];
160	int ret, i;
161	bool fail = false;
162
163	fill_buf(zt_in, sizeof(zt_in));
164
165	ret = set_zt(child, zt_in);
166	if (ret != 0) {
167		ksft_print_msg("Failed to set ZT\n");
168		fail = true;
169	}
170
171	ret = get_zt(child, zt_out);
172	if (ret != 0) {
173		ksft_print_msg("Failed to read ZT\n");
174		fail = true;
175	}
176
177	for (i = 0; i < ARRAY_SIZE(zt_in); i++) {
178		if (zt_in[i] != zt_out[i]) {
179			ksft_print_msg("zt[%d]: 0x%x != 0x%x\n", i,
180				       zt_in[i], zt_out[i]);
181			fail = true;
182		}
183	}
184
185	ksft_test_result(!fail, "ptrace_set_get_zt\n");
186}
187
188/* Writing ZT should set PSTATE.ZA */
189static void ptrace_enable_za_via_zt(pid_t child)
190{
191	struct user_za_header za_in;
192	struct user_za_header *za_out;
193	char zt[ZT_SIG_REG_BYTES];
194	char *za_data;
195	size_t za_out_size;
196	int ret, i, vq;
197	bool fail = false;
198
199	/* Disable PSTATE.ZA using the ZA interface */
200	memset(&za_in, 0, sizeof(za_in));
201	za_in.vl = sme_vl;
202	za_in.size = sizeof(za_in);
203
204	ret = set_za(child, &za_in);
205	if (ret != 0) {
206		ksft_print_msg("Failed to disable ZA\n");
207		fail = true;
208	}
209
210	/* Write ZT */
211	fill_buf(zt, sizeof(zt));
212	ret = set_zt(child, zt);
213	if (ret != 0) {
214		ksft_print_msg("Failed to set ZT\n");
215		fail = true;
216	}
217
218	/* Read back ZA and check for register data */
219	za_out = NULL;
220	za_out_size = 0;
221	if (get_za(child, (void **)&za_out, &za_out_size)) {
222		/* Should have an unchanged VL */
223		if (za_out->vl != sme_vl) {
224			ksft_print_msg("VL changed from %d to %d\n",
225				       sme_vl, za_out->vl);
226			fail = true;
227		}
228		vq = __sve_vq_from_vl(za_out->vl);
229		za_data = (char *)za_out + ZA_PT_ZA_OFFSET;
230
231		/* Should have register data */
232		if (za_out->size < ZA_PT_SIZE(vq)) {
233			ksft_print_msg("ZA data less than expected: %u < %u\n",
234				       za_out->size, ZA_PT_SIZE(vq));
235			fail = true;
236			vq = 0;
237		}
238
239		/* That register data should be non-zero */
240		for (i = 0; i < ZA_PT_ZA_SIZE(vq); i++) {
241			if (za_data[i]) {
242				ksft_print_msg("ZA byte %d is %x\n",
243					       i, za_data[i]);
244				fail = true;
245			}
246		}
247	} else {
248		ksft_print_msg("Failed to read ZA\n");
249		fail = true;
250	}
251
252	ksft_test_result(!fail, "ptrace_enable_za_via_zt\n");
253}
254
255static int do_parent(pid_t child)
256{
257	int ret = EXIT_FAILURE;
258	pid_t pid;
259	int status;
260	siginfo_t si;
261
262	/* Attach to the child */
263	while (1) {
264		int sig;
265
266		pid = wait(&status);
267		if (pid == -1) {
268			perror("wait");
269			goto error;
270		}
271
272		/*
273		 * This should never happen but it's hard to flag in
274		 * the framework.
275		 */
276		if (pid != child)
277			continue;
278
279		if (WIFEXITED(status) || WIFSIGNALED(status))
280			ksft_exit_fail_msg("Child died unexpectedly\n");
281
282		if (!WIFSTOPPED(status))
283			goto error;
284
285		sig = WSTOPSIG(status);
286
287		if (ptrace(PTRACE_GETSIGINFO, pid, NULL, &si)) {
288			if (errno == ESRCH)
289				goto disappeared;
290
291			if (errno == EINVAL) {
292				sig = 0; /* bust group-stop */
293				goto cont;
294			}
295
296			ksft_test_result_fail("PTRACE_GETSIGINFO: %s\n",
297					      strerror(errno));
298			goto error;
299		}
300
301		if (sig == SIGSTOP && si.si_code == SI_TKILL &&
302		    si.si_pid == pid)
303			break;
304
305	cont:
306		if (ptrace(PTRACE_CONT, pid, NULL, sig)) {
307			if (errno == ESRCH)
308				goto disappeared;
309
310			ksft_test_result_fail("PTRACE_CONT: %s\n",
311					      strerror(errno));
312			goto error;
313		}
314	}
315
316	ksft_print_msg("Parent is %d, child is %d\n", getpid(), child);
317
318	ptrace_za_disabled_read_zt(child);
319	ptrace_set_get_zt(child);
320	ptrace_enable_za_via_zt(child);
321
322	ret = EXIT_SUCCESS;
323
324error:
325	kill(child, SIGKILL);
326
327disappeared:
328	return ret;
329}
330
331int main(void)
332{
333	int ret = EXIT_SUCCESS;
334	pid_t child;
335
336	srandom(getpid());
337
338	ksft_print_header();
339
340	if (!(getauxval(AT_HWCAP2) & HWCAP2_SME2)) {
341		ksft_set_plan(1);
342		ksft_exit_skip("SME2 not available\n");
343	}
344
345	/* We need a valid SME VL to enable/disable ZA */
346	sme_vl = prctl(PR_SME_GET_VL);
347	if (sme_vl == -1) {
348		ksft_set_plan(1);
349		ksft_exit_skip("Failed to read SME VL: %d (%s)\n",
350			       errno, strerror(errno));
351	}
352
353	ksft_set_plan(EXPECTED_TESTS);
354
355	child = fork();
356	if (!child)
357		return do_child();
358
359	if (do_parent(child))
360		ret = EXIT_FAILURE;
361
362	ksft_print_cnts();
363
364	return ret;
365}
366