1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * Copyright (C) 2015-2019 ARM Limited.
4 * Original author: Dave Martin <Dave.Martin@arm.com>
5 */
6#define _GNU_SOURCE
7#include <assert.h>
8#include <errno.h>
9#include <limits.h>
10#include <stddef.h>
11#include <stdio.h>
12#include <stdlib.h>
13#include <string.h>
14#include <getopt.h>
15#include <unistd.h>
16#include <sys/auxv.h>
17#include <sys/prctl.h>
18#include <asm/hwcap.h>
19#include <asm/sigcontext.h>
20
/*
 * Command-line state, filled in by parse_options() and consumed by main().
 * Everything starts out zero: no options seen, vector length not given.
 */
static int inherit;		/* -i/--inherit: request PR_SVE_VL_INHERIT */
static int no_inherit;		/* --no-inherit: written directly by getopt_long() */
static int force;		/* -f/--force: proceed even if SVE is absent */
static unsigned long vl;	/* requested vector length; 0 = not yet given */

/*
 * Long options.  Entries with a NULL .flag make getopt_long() return .val;
 * the --no-inherit entry instead stores 1 into no_inherit and returns 0.
 */
static const struct option options[] = {
	{ .name = "force",	.has_arg = no_argument,	.flag = NULL,		.val = 'f' },
	{ .name = "inherit",	.has_arg = no_argument,	.flag = NULL,		.val = 'i' },
	{ .name = "max",	.has_arg = no_argument,	.flag = NULL,		.val = 'M' },
	{ .name = "no-inherit",	.has_arg = no_argument,	.flag = &no_inherit,	.val = 1 },
	{ .name = "help",	.has_arg = no_argument,	.flag = NULL,		.val = '?' },
	{ NULL, 0, NULL, 0 }	/* end-of-table marker */
};

/* Basename of argv[0]; used as the prefix of every diagnostic. */
static const char *program_name;
36
37static int parse_options(int argc, char **argv)
38{
39	int c;
40	char *rest;
41
42	program_name = strrchr(argv[0], '/');
43	if (program_name)
44		++program_name;
45	else
46		program_name = argv[0];
47
48	while ((c = getopt_long(argc, argv, "Mfhi", options, NULL)) != -1)
49		switch (c) {
50		case 'M':	vl = SVE_VL_MAX; break;
51		case 'f':	force = 1; break;
52		case 'i':	inherit = 1; break;
53		case 0:		break;
54		default:	goto error;
55		}
56
57	if (inherit && no_inherit)
58		goto error;
59
60	if (!vl) {
61		/* vector length */
62		if (optind >= argc)
63			goto error;
64
65		errno = 0;
66		vl = strtoul(argv[optind], &rest, 0);
67		if (*rest) {
68			vl = ULONG_MAX;
69			errno = EINVAL;
70		}
71		if (vl == ULONG_MAX && errno) {
72			fprintf(stderr, "%s: %s: %s\n",
73				program_name, argv[optind], strerror(errno));
74			goto error;
75		}
76
77		++optind;
78	}
79
80	/* command */
81	if (optind >= argc)
82		goto error;
83
84	return 0;
85
86error:
87	fprintf(stderr,
88		"Usage: %s [-f | --force] "
89		"[-i | --inherit | --no-inherit] "
90		"{-M | --max | <vector length>} "
91		"<command> [<arguments> ...]\n",
92		program_name);
93	return -1;
94}
95
96int main(int argc, char **argv)
97{
98	int ret = 126;	/* same as sh(1) command-not-executable error */
99	long flags;
100	char *path;
101	int t, e;
102
103	if (parse_options(argc, argv))
104		return 2;	/* same as sh(1) builtin incorrect-usage */
105
106	if (vl & ~(vl & PR_SVE_VL_LEN_MASK)) {
107		fprintf(stderr, "%s: Invalid vector length %lu\n",
108			program_name, vl);
109		return 2;	/* same as sh(1) builtin incorrect-usage */
110	}
111
112	if (!(getauxval(AT_HWCAP) & HWCAP_SVE)) {
113		fprintf(stderr, "%s: Scalable Vector Extension not present\n",
114			program_name);
115
116		if (!force)
117			goto error;
118
119		fputs("Going ahead anyway (--force):  "
120		      "This is a debug option.  Don't rely on it.\n",
121		      stderr);
122	}
123
124	flags = PR_SVE_SET_VL_ONEXEC;
125	if (inherit)
126		flags |= PR_SVE_VL_INHERIT;
127
128	t = prctl(PR_SVE_SET_VL, vl | flags);
129	if (t < 0) {
130		fprintf(stderr, "%s: PR_SVE_SET_VL: %s\n",
131			program_name, strerror(errno));
132		goto error;
133	}
134
135	t = prctl(PR_SVE_GET_VL);
136	if (t == -1) {
137		fprintf(stderr, "%s: PR_SVE_GET_VL: %s\n",
138			program_name, strerror(errno));
139		goto error;
140	}
141	flags = PR_SVE_VL_LEN_MASK;
142	flags = t & ~flags;
143
144	assert(optind < argc);
145	path = argv[optind];
146
147	execvp(path, &argv[optind]);
148	e = errno;
149	if (errno == ENOENT)
150		ret = 127;	/* same as sh(1) not-found error */
151	fprintf(stderr, "%s: %s: %s\n", program_name, path, strerror(e));
152
153error:
154	return ret;		/* same as sh(1) not-executable error */
155}
156