// SPDX-License-Identifier: GPL-2.0-only
/*
 * BPF JIT compiler for ARM64
 *
 * Copyright (C) 2014-2016 Zi Shen Lim <zlim.lnx@gmail.com>
 */

#define pr_fmt(fmt) "bpf_jit: " fmt

#include <linux/bitfield.h>
#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/printk.h>
#include <linux/slab.h>

#include <asm/byteorder.h>
#include <asm/cacheflush.h>
#include <asm/debug-monitors.h>
#include <asm/set_memory.h>

#include "bpf_jit.h"

#define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
#define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
#define TCALL_CNT (MAX_BPF_JIT_REG + 2)
#define TMP_REG_3 (MAX_BPF_JIT_REG + 3)

/* Map BPF registers to A64 registers */
static const int bpf2a64[] = {
	/* return value from in-kernel function, and exit value from eBPF */
	[BPF_REG_0] = A64_R(7),
	/* arguments from eBPF program to in-kernel function */
	[BPF_REG_1] = A64_R(0),
	[BPF_REG_2] = A64_R(1),
	[BPF_REG_3] = A64_R(2),
	[BPF_REG_4] = A64_R(3),
	[BPF_REG_5] = A64_R(4),
	/* callee saved registers that in-kernel function will preserve */
	[BPF_REG_6] = A64_R(19),
	[BPF_REG_7] = A64_R(20),
	[BPF_REG_8] = A64_R(21),
	[BPF_REG_9] = A64_R(22),
	/* read-only frame pointer to access stack */
	[BPF_REG_FP] = A64_R(25),
	/* temporary registers for internal BPF JIT */
	[TMP_REG_1] = A64_R(10),
	[TMP_REG_2] = A64_R(11),
	[TMP_REG_3] = A64_R(12),
	/* tail_call_cnt */
	[TCALL_CNT] = A64_R(26),
	/* temporary register for blinding constants */
	[BPF_REG_AX] = A64_R(9),
};

struct jit_ctx {
	const struct bpf_prog *prog;
	int idx;
	int epilogue_offset;
	int *offset;
	int exentry_idx;
	__le32 *image;
	u32 stack_size;
};

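/*
 * Emit one AArch64 instruction. On the initial sizing pass ctx->image is
 * NULL, so only ctx->idx advances; the instruction is written out on a
 * later pass, once the image buffer has been allocated.
 */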
static inline void emit(const u32 insn, struct jit_ctx *ctx)
{
	if (ctx->image != NULL)
		ctx->image[ctx->idx] = cpu_to_le32(insn);

	ctx->idx++;
}

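/*
 * Move a 32-bit immediate into a register using at most two instructions:
 * MOVN (for negative values) or MOVZ for one halfword, followed by a MOVK
 * for the other halfword when it is not already implied by the first insn.
 */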
static inline void emit_a64_mov_i(const int is64, const int reg,
				  const s32 val, struct jit_ctx *ctx)
{
	u16 hi = val >> 16;
	u16 lo = val & 0xffff;

	if (hi & 0x8000) {
		if (hi == 0xffff) {
			emit(A64_MOVN(is64, reg, (u16)~lo, 0), ctx);
		} else {
			emit(A64_MOVN(is64, reg, (u16)~hi, 16), ctx);
			if (lo != 0xffff)
				emit(A64_MOVK(is64, reg, lo, 0), ctx);
		}
	} else {
		emit(A64_MOVZ(is64, reg, lo, 0), ctx);
		if (hi)
			emit(A64_MOVK(is64, reg, hi, 16), ctx);
	}
}

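/*
 * Count the 16-bit halfwords of @val that differ from the fill pattern
 * (all-ones when @inverse, all-zeroes otherwise), i.e. how many MOVKs a
 * MOVN- or MOVZ-based sequence would need.
 */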
static int i64_i16_blocks(const u64 val, bool inverse)
{
	return (((val >>  0) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
	       (((val >> 16) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
	       (((val >> 32) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
	       (((val >> 48) & 0xffff) != (inverse ? 0xffff : 0x0000));
}

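/*
 * Move a 64-bit immediate into a register. Values that fit in 32 bits are
 * delegated to emit_a64_mov_i(); otherwise the sequence starts with MOVN or
 * MOVZ, whichever leaves fewer halfwords to patch up with MOVK.
 */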
static inline void emit_a64_mov_i64(const int reg, const u64 val,
				    struct jit_ctx *ctx)
{
	u64 nrm_tmp = val, rev_tmp = ~val;
	bool inverse;
	int shift;

	if (!(nrm_tmp >> 32))
		return emit_a64_mov_i(0, reg, (u32)val, ctx);

	inverse = i64_i16_blocks(nrm_tmp, true) < i64_i16_blocks(nrm_tmp, false);
	shift = max(round_down((inverse ? (fls64(rev_tmp) - 1) :
					  (fls64(nrm_tmp) - 1)), 16), 0);
	if (inverse)
		emit(A64_MOVN(1, reg, (rev_tmp >> shift) & 0xffff, shift), ctx);
	else
		emit(A64_MOVZ(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
	shift -= 16;
	while (shift >= 0) {
		if (((nrm_tmp >> shift) & 0xffff) != (inverse ? 0xffff : 0x0000))
			emit(A64_MOVK(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
		shift -= 16;
	}
}

/*
 * Kernel addresses in the vmalloc space use at most 48 bits, and the
 * remaining bits are guaranteed to be 0x1. So we can compose the address
 * with a fixed length movn/movk/movk sequence.
 */
static inline void emit_addr_mov_i64(const int reg, const u64 val,
				     struct jit_ctx *ctx)
{
	u64 tmp = val;
	int shift = 0;

	emit(A64_MOVN(1, reg, ~tmp & 0xffff, shift), ctx);
	while (shift < 32) {
		tmp >>= 16;
		shift += 16;
		emit(A64_MOVK(1, reg, tmp & 0xffff, shift), ctx);
	}
}

static inline int bpf2a64_offset(int bpf_insn, int off,
				 const struct jit_ctx *ctx)
{
	/* BPF JMP offset is relative to the next instruction */
	bpf_insn++;
	/*
	 * Whereas arm64 branch instructions encode the offset
	 * from the branch itself, so we must subtract 1 from the
	 * instruction offset.
	 */
	return ctx->offset[bpf_insn + off] - (ctx->offset[bpf_insn] - 1);
}

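/* Pad unused image space with BRK instructions so stray execution traps. */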
static void jit_fill_hole(void *area, unsigned int size)
{
	__le32 *ptr;
	/* We are guaranteed to have aligned memory. */
	for (ptr = area; size >= sizeof(u32); size -= sizeof(u32))
		*ptr++ = cpu_to_le32(AARCH64_BREAK_FAULT);
}

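/* Branch displacement, in instructions, from the current point to the epilogue. */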
static inline int epilogue_offset(const struct jit_ctx *ctx)
{
	int to = ctx->epilogue_offset;
	int from = ctx->idx;

	return to - from;
}

static bool is_addsub_imm(u32 imm)
{
	/* Either imm12 or shifted imm12. */
	return !(imm & ~0xfff) || !(imm & ~0xfff000);
}

/* Stack must be multiples of 16B */
#define STACK_ALIGN(sz) (((sz) + 15) & ~15)

/* Tail call offset to jump into */
#if IS_ENABLED(CONFIG_ARM64_BTI_KERNEL)
#define PROLOGUE_OFFSET 8
#else
#define PROLOGUE_OFFSET 7
#endif

static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
{
	const struct bpf_prog *prog = ctx->prog;
	const u8 r6 = bpf2a64[BPF_REG_6];
	const u8 r7 = bpf2a64[BPF_REG_7];
	const u8 r8 = bpf2a64[BPF_REG_8];
	const u8 r9 = bpf2a64[BPF_REG_9];
	const u8 fp = bpf2a64[BPF_REG_FP];
	const u8 tcc = bpf2a64[TCALL_CNT];
	const int idx0 = ctx->idx;
	int cur_offset;

	/*
	 * BPF prog stack layout
	 *
	 *                         high
	 * original A64_SP =>   0:+-----+ BPF prologue
	 *                        |FP/LR|
	 * current A64_FP =>  -16:+-----+
	 *                        | ... | callee saved registers
	 * BPF fp register => -64:+-----+ <= (BPF_FP)
	 *                        |     |
	 *                        | ... | BPF prog stack
	 *                        |     |
	 *                        +-----+ <= (BPF_FP - prog->aux->stack_depth)
	 *                        |RSVD | padding
	 * current A64_SP =>      +-----+ <= (BPF_FP - ctx->stack_size)
	 *                        |     |
	 *                        | ... | Function call stack
	 *                        |     |
	 *                        +-----+
	 *                          low
	 *
	 */

	/* BTI landing pad */
	if (IS_ENABLED(CONFIG_ARM64_BTI_KERNEL))
		emit(A64_BTI_C, ctx);

	/* Save FP and LR registers to stay aligned with the ARM64 AAPCS */
	emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
	emit(A64_MOV(1, A64_FP, A64_SP), ctx);

	/* Save callee-saved registers */
	emit(A64_PUSH(r6, r7, A64_SP), ctx);
	emit(A64_PUSH(r8, r9, A64_SP), ctx);
	emit(A64_PUSH(fp, tcc, A64_SP), ctx);

	/* Set up BPF prog stack base register */
	emit(A64_MOV(1, fp, A64_SP), ctx);

	if (!ebpf_from_cbpf) {
		/* Initialize tail_call_cnt */
		emit(A64_MOVZ(1, tcc, 0, 0), ctx);

		cur_offset = ctx->idx - idx0;
		if (cur_offset != PROLOGUE_OFFSET) {
			pr_err_once("PROLOGUE_OFFSET = %d, expected %d!\n",
				    cur_offset, PROLOGUE_OFFSET);
			return -1;
		}

		/* BTI landing pad for the tail call, done with a BR */
		if (IS_ENABLED(CONFIG_ARM64_BTI_KERNEL))
			emit(A64_BTI_J, ctx);
	}

	ctx->stack_size = STACK_ALIGN(prog->aux->stack_depth);

	/* Set up function call stack */
	emit(A64_SUB_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
	return 0;
}

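/*
 * The tail-call sequence must have the same length at every call site: the
 * "out:" offset recorded on the first emission is checked against every
 * later one, and a mismatch aborts the JIT.
 */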
static int out_offset = -1; /* initialized on the first pass of build_body() */
static int emit_bpf_tail_call(struct jit_ctx *ctx)
{
	/* bpf_tail_call(void *prog_ctx, struct bpf_array *array, u64 index) */
	const u8 r2 = bpf2a64[BPF_REG_2];
	const u8 r3 = bpf2a64[BPF_REG_3];

	const u8 tmp = bpf2a64[TMP_REG_1];
	const u8 prg = bpf2a64[TMP_REG_2];
	const u8 tcc = bpf2a64[TCALL_CNT];
	const int idx0 = ctx->idx;
#define cur_offset (ctx->idx - idx0)
#define jmp_offset (out_offset - (cur_offset))
	size_t off;

	/* if (index >= array->map.max_entries)
	 *     goto out;
	 */
	off = offsetof(struct bpf_array, map.max_entries);
	emit_a64_mov_i64(tmp, off, ctx);
	emit(A64_LDR32(tmp, r2, tmp), ctx);
	emit(A64_MOV(0, r3, r3), ctx);
	emit(A64_CMP(0, r3, tmp), ctx);
	emit(A64_B_(A64_COND_CS, jmp_offset), ctx);

	/* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
	 *     goto out;
	 * tail_call_cnt++;
	 */
	emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
	emit(A64_CMP(1, tcc, tmp), ctx);
	emit(A64_B_(A64_COND_HI, jmp_offset), ctx);
	emit(A64_ADD_I(1, tcc, tcc, 1), ctx);

	/* prog = array->ptrs[index];
	 * if (prog == NULL)
	 *     goto out;
	 */
	off = offsetof(struct bpf_array, ptrs);
	emit_a64_mov_i64(tmp, off, ctx);
	emit(A64_ADD(1, tmp, r2, tmp), ctx);
	emit(A64_LSL(1, prg, r3, 3), ctx);
	emit(A64_LDR64(prg, tmp, prg), ctx);
	emit(A64_CBZ(1, prg, jmp_offset), ctx);

	/* goto *(prog->bpf_func + prologue_offset); */
	off = offsetof(struct bpf_prog, bpf_func);
	emit_a64_mov_i64(tmp, off, ctx);
	emit(A64_LDR64(tmp, prg, tmp), ctx);
	emit(A64_ADD_I(1, tmp, tmp, sizeof(u32) * PROLOGUE_OFFSET), ctx);
	emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
	emit(A64_BR(tmp), ctx);

	/* out: */
	if (out_offset == -1)
		out_offset = cur_offset;
	if (cur_offset != out_offset) {
		pr_err_once("tail_call out_offset = %d, expected %d!\n",
			    cur_offset, out_offset);
		return -1;
	}
	return 0;
#undef cur_offset
#undef jmp_offset
}

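/*
 * Tear down the stack frame in the reverse order of build_prologue() and
 * return to the caller with the BPF return value copied into x0.
 */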
static void build_epilogue(struct jit_ctx *ctx)
{
	const u8 r0 = bpf2a64[BPF_REG_0];
	const u8 r6 = bpf2a64[BPF_REG_6];
	const u8 r7 = bpf2a64[BPF_REG_7];
	const u8 r8 = bpf2a64[BPF_REG_8];
	const u8 r9 = bpf2a64[BPF_REG_9];
	const u8 fp = bpf2a64[BPF_REG_FP];

	/* We're done with BPF stack */
	emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);

	/* Restore fp (x25) and x26 */
	emit(A64_POP(fp, A64_R(26), A64_SP), ctx);

	/* Restore callee-saved registers */
	emit(A64_POP(r8, r9, A64_SP), ctx);
	emit(A64_POP(r6, r7, A64_SP), ctx);

	/* Restore FP/LR registers */
	emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);

	/* Set return value */
	emit(A64_MOV(1, A64_R(0), r0), ctx);

	emit(A64_RET(A64_LR), ctx);
}

#define BPF_FIXUP_OFFSET_MASK	GENMASK(26, 0)
#define BPF_FIXUP_REG_MASK	GENMASK(31, 27)

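/*
 * Handle a fault on a BPF_PROBE_MEM load: decode the packed fixup entry,
 * zero the destination register and resume execution at the instruction
 * after the faulting load.
 */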
int arm64_bpf_fixup_exception(const struct exception_table_entry *ex,
			      struct pt_regs *regs)
{
	off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
	int dst_reg = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);

	regs->regs[dst_reg] = 0;
	regs->pc = (unsigned long)&ex->fixup - offset;
	return 1;
}

/* For accesses to BTF pointers, add an entry to the exception table */
static int add_exception_handler(const struct bpf_insn *insn,
				 struct jit_ctx *ctx,
				 int dst_reg)
{
	off_t offset;
	unsigned long pc;
	struct exception_table_entry *ex;

	if (!ctx->image)
		/* First pass */
		return 0;

	if (BPF_MODE(insn->code) != BPF_PROBE_MEM)
		return 0;

	if (!ctx->prog->aux->extable ||
	    WARN_ON_ONCE(ctx->exentry_idx >= ctx->prog->aux->num_exentries))
		return -EINVAL;

	ex = &ctx->prog->aux->extable[ctx->exentry_idx];
	pc = (unsigned long)&ctx->image[ctx->idx - 1];

	offset = pc - (long)&ex->insn;
	if (WARN_ON_ONCE(offset >= 0 || offset < INT_MIN))
		return -ERANGE;
	ex->insn = offset;

	/*
	 * Since the extable follows the program, the fixup offset is always
	 * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
	 * to keep things simple, and put the destination register in the upper
	 * bits. We don't need to worry about buildtime or runtime sort
	 * modifying the upper bits because the table is already sorted, and
	 * isn't part of the main exception table.
	 */
	offset = (long)&ex->fixup - (pc + AARCH64_INSN_SIZE);
	if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, offset))
		return -ERANGE;

	ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, offset) |
		    FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);

	ctx->exentry_idx++;
	return 0;
}

/* JITs an eBPF instruction.
 * Returns:
 * 0  - successfully JITed an 8-byte eBPF instruction.
 * >0 - successfully JITed a 16-byte eBPF instruction.
 * <0 - failed to JIT.
 */
static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
		      bool extra_pass)
{
	const u8 code = insn->code;
	const u8 dst = bpf2a64[insn->dst_reg];
	const u8 src = bpf2a64[insn->src_reg];
	const u8 tmp = bpf2a64[TMP_REG_1];
	const u8 tmp2 = bpf2a64[TMP_REG_2];
	const u8 tmp3 = bpf2a64[TMP_REG_3];
	const s16 off = insn->off;
	const s32 imm = insn->imm;
	const int i = insn - ctx->prog->insnsi;
	const bool is64 = BPF_CLASS(code) == BPF_ALU64 ||
			  BPF_CLASS(code) == BPF_JMP;
	const bool isdw = BPF_SIZE(code) == BPF_DW;
	u8 jmp_cond, reg;
	s32 jmp_offset;
	u32 a64_insn;
	int ret;

#define check_imm(bits, imm) do {				\
	if ((((imm) > 0) && ((imm) >> (bits))) ||		\
	    (((imm) < 0) && (~(imm) >> (bits)))) {		\
		pr_info("[%2d] imm=%d(0x%x) out of range\n",	\
			i, imm, imm);				\
		return -EINVAL;					\
	}							\
} while (0)
#define check_imm19(imm) check_imm(19, imm)
#define check_imm26(imm) check_imm(26, imm)

	switch (code) {
	/* dst = src */
	case BPF_ALU | BPF_MOV | BPF_X:
	case BPF_ALU64 | BPF_MOV | BPF_X:
		emit(A64_MOV(is64, dst, src), ctx);
		break;
	/* dst = dst OP src */
	case BPF_ALU | BPF_ADD | BPF_X:
	case BPF_ALU64 | BPF_ADD | BPF_X:
		emit(A64_ADD(is64, dst, dst, src), ctx);
		break;
	case BPF_ALU | BPF_SUB | BPF_X:
	case BPF_ALU64 | BPF_SUB | BPF_X:
		emit(A64_SUB(is64, dst, dst, src), ctx);
		break;
	case BPF_ALU | BPF_AND | BPF_X:
	case BPF_ALU64 | BPF_AND | BPF_X:
		emit(A64_AND(is64, dst, dst, src), ctx);
		break;
	case BPF_ALU | BPF_OR | BPF_X:
	case BPF_ALU64 | BPF_OR | BPF_X:
		emit(A64_ORR(is64, dst, dst, src), ctx);
		break;
	case BPF_ALU | BPF_XOR | BPF_X:
	case BPF_ALU64 | BPF_XOR | BPF_X:
		emit(A64_EOR(is64, dst, dst, src), ctx);
		break;
	case BPF_ALU | BPF_MUL | BPF_X:
	case BPF_ALU64 | BPF_MUL | BPF_X:
		emit(A64_MUL(is64, dst, dst, src), ctx);
		break;
	case BPF_ALU | BPF_DIV | BPF_X:
	case BPF_ALU64 | BPF_DIV | BPF_X:
	case BPF_ALU | BPF_MOD | BPF_X:
	case BPF_ALU64 | BPF_MOD | BPF_X:
		switch (BPF_OP(code)) {
		case BPF_DIV:
			emit(A64_UDIV(is64, dst, dst, src), ctx);
			break;
		case BPF_MOD:
			emit(A64_UDIV(is64, tmp, dst, src), ctx);
			emit(A64_MSUB(is64, dst, dst, tmp, src), ctx);
			break;
		}
		break;
	case BPF_ALU | BPF_LSH | BPF_X:
	case BPF_ALU64 | BPF_LSH | BPF_X:
		emit(A64_LSLV(is64, dst, dst, src), ctx);
		break;
	case BPF_ALU | BPF_RSH | BPF_X:
	case BPF_ALU64 | BPF_RSH | BPF_X:
		emit(A64_LSRV(is64, dst, dst, src), ctx);
		break;
	case BPF_ALU | BPF_ARSH | BPF_X:
	case BPF_ALU64 | BPF_ARSH | BPF_X:
		emit(A64_ASRV(is64, dst, dst, src), ctx);
		break;
	/* dst = -dst */
	case BPF_ALU | BPF_NEG:
	case BPF_ALU64 | BPF_NEG:
		emit(A64_NEG(is64, dst, dst), ctx);
		break;
	/* dst = BSWAP##imm(dst) */
	case BPF_ALU | BPF_END | BPF_FROM_LE:
	case BPF_ALU | BPF_END | BPF_FROM_BE:
#ifdef CONFIG_CPU_BIG_ENDIAN
		if (BPF_SRC(code) == BPF_FROM_BE)
			goto emit_bswap_uxt;
#else /* !CONFIG_CPU_BIG_ENDIAN */
		if (BPF_SRC(code) == BPF_FROM_LE)
			goto emit_bswap_uxt;
#endif
		switch (imm) {
		case 16:
			emit(A64_REV16(is64, dst, dst), ctx);
			/* zero-extend 16 bits into 64 bits */
			emit(A64_UXTH(is64, dst, dst), ctx);
			break;
		case 32:
			emit(A64_REV32(is64, dst, dst), ctx);
			/* upper 32 bits already cleared */
			break;
		case 64:
			emit(A64_REV64(dst, dst), ctx);
			break;
		}
		break;
emit_bswap_uxt:
		switch (imm) {
		case 16:
			/* zero-extend 16 bits into 64 bits */
			emit(A64_UXTH(is64, dst, dst), ctx);
			break;
		case 32:
			/* zero-extend 32 bits into 64 bits */
			emit(A64_UXTW(is64, dst, dst), ctx);
			break;
		case 64:
			/* nop */
			break;
		}
		break;
	/* dst = imm */
	case BPF_ALU | BPF_MOV | BPF_K:
	case BPF_ALU64 | BPF_MOV | BPF_K:
		emit_a64_mov_i(is64, dst, imm, ctx);
		break;
	/* dst = dst OP imm */
	case BPF_ALU | BPF_ADD | BPF_K:
	case BPF_ALU64 | BPF_ADD | BPF_K:
		if (is_addsub_imm(imm)) {
			emit(A64_ADD_I(is64, dst, dst, imm), ctx);
		} else if (is_addsub_imm(-imm)) {
			emit(A64_SUB_I(is64, dst, dst, -imm), ctx);
		} else {
			emit_a64_mov_i(is64, tmp, imm, ctx);
			emit(A64_ADD(is64, dst, dst, tmp), ctx);
		}
		break;
	case BPF_ALU | BPF_SUB | BPF_K:
	case BPF_ALU64 | BPF_SUB | BPF_K:
		if (is_addsub_imm(imm)) {
			emit(A64_SUB_I(is64, dst, dst, imm), ctx);
		} else if (is_addsub_imm(-imm)) {
			emit(A64_ADD_I(is64, dst, dst, -imm), ctx);
		} else {
			emit_a64_mov_i(is64, tmp, imm, ctx);
			emit(A64_SUB(is64, dst, dst, tmp), ctx);
		}
		break;
	case BPF_ALU | BPF_AND | BPF_K:
	case BPF_ALU64 | BPF_AND | BPF_K:
		a64_insn = A64_AND_I(is64, dst, dst, imm);
		if (a64_insn != AARCH64_BREAK_FAULT) {
			emit(a64_insn, ctx);
		} else {
			emit_a64_mov_i(is64, tmp, imm, ctx);
			emit(A64_AND(is64, dst, dst, tmp), ctx);
		}
		break;
	case BPF_ALU | BPF_OR | BPF_K:
	case BPF_ALU64 | BPF_OR | BPF_K:
		a64_insn = A64_ORR_I(is64, dst, dst, imm);
		if (a64_insn != AARCH64_BREAK_FAULT) {
			emit(a64_insn, ctx);
		} else {
			emit_a64_mov_i(is64, tmp, imm, ctx);
			emit(A64_ORR(is64, dst, dst, tmp), ctx);
		}
		break;
	case BPF_ALU | BPF_XOR | BPF_K:
	case BPF_ALU64 | BPF_XOR | BPF_K:
		a64_insn = A64_EOR_I(is64, dst, dst, imm);
		if (a64_insn != AARCH64_BREAK_FAULT) {
			emit(a64_insn, ctx);
		} else {
			emit_a64_mov_i(is64, tmp, imm, ctx);
			emit(A64_EOR(is64, dst, dst, tmp), ctx);
		}
		break;
	case BPF_ALU | BPF_MUL | BPF_K:
	case BPF_ALU64 | BPF_MUL | BPF_K:
		emit_a64_mov_i(is64, tmp, imm, ctx);
		emit(A64_MUL(is64, dst, dst, tmp), ctx);
		break;
	case BPF_ALU | BPF_DIV | BPF_K:
	case BPF_ALU64 | BPF_DIV | BPF_K:
		emit_a64_mov_i(is64, tmp, imm, ctx);
		emit(A64_UDIV(is64, dst, dst, tmp), ctx);
		break;
	case BPF_ALU | BPF_MOD | BPF_K:
	case BPF_ALU64 | BPF_MOD | BPF_K:
		emit_a64_mov_i(is64, tmp2, imm, ctx);
		emit(A64_UDIV(is64, tmp, dst, tmp2), ctx);
		emit(A64_MSUB(is64, dst, dst, tmp, tmp2), ctx);
		break;
	case BPF_ALU | BPF_LSH | BPF_K:
	case BPF_ALU64 | BPF_LSH | BPF_K:
		emit(A64_LSL(is64, dst, dst, imm), ctx);
		break;
	case BPF_ALU | BPF_RSH | BPF_K:
	case BPF_ALU64 | BPF_RSH | BPF_K:
		emit(A64_LSR(is64, dst, dst, imm), ctx);
		break;
	case BPF_ALU | BPF_ARSH | BPF_K:
	case BPF_ALU64 | BPF_ARSH | BPF_K:
		emit(A64_ASR(is64, dst, dst, imm), ctx);
		break;

	/* JUMP off */
	case BPF_JMP | BPF_JA:
		jmp_offset = bpf2a64_offset(i, off, ctx);
		check_imm26(jmp_offset);
		emit(A64_B(jmp_offset), ctx);
		break;
	/* IF (dst COND src) JUMP off */
	case BPF_JMP | BPF_JEQ | BPF_X:
	case BPF_JMP | BPF_JGT | BPF_X:
	case BPF_JMP | BPF_JLT | BPF_X:
	case BPF_JMP | BPF_JGE | BPF_X:
	case BPF_JMP | BPF_JLE | BPF_X:
	case BPF_JMP | BPF_JNE | BPF_X:
	case BPF_JMP | BPF_JSGT | BPF_X:
	case BPF_JMP | BPF_JSLT | BPF_X:
	case BPF_JMP | BPF_JSGE | BPF_X:
	case BPF_JMP | BPF_JSLE | BPF_X:
	case BPF_JMP32 | BPF_JEQ | BPF_X:
	case BPF_JMP32 | BPF_JGT | BPF_X:
	case BPF_JMP32 | BPF_JLT | BPF_X:
	case BPF_JMP32 | BPF_JGE | BPF_X:
	case BPF_JMP32 | BPF_JLE | BPF_X:
	case BPF_JMP32 | BPF_JNE | BPF_X:
	case BPF_JMP32 | BPF_JSGT | BPF_X:
	case BPF_JMP32 | BPF_JSLT | BPF_X:
	case BPF_JMP32 | BPF_JSGE | BPF_X:
	case BPF_JMP32 | BPF_JSLE | BPF_X:
		emit(A64_CMP(is64, dst, src), ctx);
emit_cond_jmp:
		jmp_offset = bpf2a64_offset(i, off, ctx);
		check_imm19(jmp_offset);
		switch (BPF_OP(code)) {
		case BPF_JEQ:
			jmp_cond = A64_COND_EQ;
			break;
		case BPF_JGT:
			jmp_cond = A64_COND_HI;
			break;
		case BPF_JLT:
			jmp_cond = A64_COND_CC;
			break;
		case BPF_JGE:
			jmp_cond = A64_COND_CS;
			break;
		case BPF_JLE:
			jmp_cond = A64_COND_LS;
			break;
		case BPF_JSET:
		case BPF_JNE:
			jmp_cond = A64_COND_NE;
			break;
		case BPF_JSGT:
			jmp_cond = A64_COND_GT;
			break;
		case BPF_JSLT:
			jmp_cond = A64_COND_LT;
			break;
		case BPF_JSGE:
			jmp_cond = A64_COND_GE;
			break;
		case BPF_JSLE:
			jmp_cond = A64_COND_LE;
			break;
		default:
			return -EFAULT;
		}
		emit(A64_B_(jmp_cond, jmp_offset), ctx);
		break;
	case BPF_JMP | BPF_JSET | BPF_X:
	case BPF_JMP32 | BPF_JSET | BPF_X:
		emit(A64_TST(is64, dst, src), ctx);
		goto emit_cond_jmp;
	/* IF (dst COND imm) JUMP off */
	case BPF_JMP | BPF_JEQ | BPF_K:
	case BPF_JMP | BPF_JGT | BPF_K:
	case BPF_JMP | BPF_JLT | BPF_K:
	case BPF_JMP | BPF_JGE | BPF_K:
	case BPF_JMP | BPF_JLE | BPF_K:
	case BPF_JMP | BPF_JNE | BPF_K:
	case BPF_JMP | BPF_JSGT | BPF_K:
	case BPF_JMP | BPF_JSLT | BPF_K:
	case BPF_JMP | BPF_JSGE | BPF_K:
	case BPF_JMP | BPF_JSLE | BPF_K:
	case BPF_JMP32 | BPF_JEQ | BPF_K:
	case BPF_JMP32 | BPF_JGT | BPF_K:
	case BPF_JMP32 | BPF_JLT | BPF_K:
	case BPF_JMP32 | BPF_JGE | BPF_K:
	case BPF_JMP32 | BPF_JLE | BPF_K:
	case BPF_JMP32 | BPF_JNE | BPF_K:
	case BPF_JMP32 | BPF_JSGT | BPF_K:
	case BPF_JMP32 | BPF_JSLT | BPF_K:
	case BPF_JMP32 | BPF_JSGE | BPF_K:
	case BPF_JMP32 | BPF_JSLE | BPF_K:
		if (is_addsub_imm(imm)) {
			emit(A64_CMP_I(is64, dst, imm), ctx);
		} else if (is_addsub_imm(-imm)) {
			emit(A64_CMN_I(is64, dst, -imm), ctx);
		} else {
			emit_a64_mov_i(is64, tmp, imm, ctx);
			emit(A64_CMP(is64, dst, tmp), ctx);
		}
		goto emit_cond_jmp;
	case BPF_JMP | BPF_JSET | BPF_K:
	case BPF_JMP32 | BPF_JSET | BPF_K:
		a64_insn = A64_TST_I(is64, dst, imm);
		if (a64_insn != AARCH64_BREAK_FAULT) {
			emit(a64_insn, ctx);
		} else {
			emit_a64_mov_i(is64, tmp, imm, ctx);
			emit(A64_TST(is64, dst, tmp), ctx);
		}
		goto emit_cond_jmp;
	/* function call */
	case BPF_JMP | BPF_CALL:
	{
		const u8 r0 = bpf2a64[BPF_REG_0];
		bool func_addr_fixed;
		u64 func_addr;

		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
					    &func_addr, &func_addr_fixed);
		if (ret < 0)
			return ret;
		emit_addr_mov_i64(tmp, func_addr, ctx);
		emit(A64_BLR(tmp), ctx);
		emit(A64_MOV(1, r0, A64_R(0)), ctx);
		break;
	}
	/* tail call */
	case BPF_JMP | BPF_TAIL_CALL:
		if (emit_bpf_tail_call(ctx))
			return -EFAULT;
		break;
	/* function return */
	case BPF_JMP | BPF_EXIT:
		/* Optimization: when last instruction is EXIT,
		   simply fallthrough to epilogue. */
		if (i == ctx->prog->len - 1)
			break;
		jmp_offset = epilogue_offset(ctx);
		check_imm26(jmp_offset);
		emit(A64_B(jmp_offset), ctx);
		break;

	/* dst = imm64 */
	case BPF_LD | BPF_IMM | BPF_DW:
	{
		const struct bpf_insn insn1 = insn[1];
		u64 imm64;

		imm64 = (u64)insn1.imm << 32 | (u32)imm;
		emit_a64_mov_i64(dst, imm64, ctx);

		return 1;
	}

	/* LDX: dst = *(size *)(src + off) */
	case BPF_LDX | BPF_MEM | BPF_W:
	case BPF_LDX | BPF_MEM | BPF_H:
	case BPF_LDX | BPF_MEM | BPF_B:
	case BPF_LDX | BPF_MEM | BPF_DW:
	case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
	case BPF_LDX | BPF_PROBE_MEM | BPF_W:
	case BPF_LDX | BPF_PROBE_MEM | BPF_H:
	case BPF_LDX | BPF_PROBE_MEM | BPF_B:
		emit_a64_mov_i(1, tmp, off, ctx);
		switch (BPF_SIZE(code)) {
		case BPF_W:
			emit(A64_LDR32(dst, src, tmp), ctx);
			break;
		case BPF_H:
			emit(A64_LDRH(dst, src, tmp), ctx);
			break;
		case BPF_B:
			emit(A64_LDRB(dst, src, tmp), ctx);
			break;
		case BPF_DW:
			emit(A64_LDR64(dst, src, tmp), ctx);
			break;
		}

		ret = add_exception_handler(insn, ctx, dst);
		if (ret)
			return ret;
		break;

	/* speculation barrier */
	case BPF_ST | BPF_NOSPEC:
		/*
		 * Nothing required here.
		 *
		 * In case of arm64, we rely on the firmware mitigation of
		 * Speculative Store Bypass as controlled via the ssbd kernel
		 * parameter. Whenever the mitigation is enabled, it works
		 * for all of the kernel code with no need to provide any
		 * additional instructions.
		 */
		break;

	/* ST: *(size *)(dst + off) = imm */
	case BPF_ST | BPF_MEM | BPF_W:
	case BPF_ST | BPF_MEM | BPF_H:
	case BPF_ST | BPF_MEM | BPF_B:
	case BPF_ST | BPF_MEM | BPF_DW:
		/* Load imm to a register then store it */
		emit_a64_mov_i(1, tmp2, off, ctx);
		emit_a64_mov_i(1, tmp, imm, ctx);
		switch (BPF_SIZE(code)) {
		case BPF_W:
			emit(A64_STR32(tmp, dst, tmp2), ctx);
			break;
		case BPF_H:
			emit(A64_STRH(tmp, dst, tmp2), ctx);
			break;
		case BPF_B:
			emit(A64_STRB(tmp, dst, tmp2), ctx);
			break;
		case BPF_DW:
			emit(A64_STR64(tmp, dst, tmp2), ctx);
			break;
		}
		break;

	/* STX: *(size *)(dst + off) = src */
	case BPF_STX | BPF_MEM | BPF_W:
	case BPF_STX | BPF_MEM | BPF_H:
	case BPF_STX | BPF_MEM | BPF_B:
	case BPF_STX | BPF_MEM | BPF_DW:
		emit_a64_mov_i(1, tmp, off, ctx);
		switch (BPF_SIZE(code)) {
		case BPF_W:
			emit(A64_STR32(src, dst, tmp), ctx);
			break;
		case BPF_H:
			emit(A64_STRH(src, dst, tmp), ctx);
			break;
		case BPF_B:
			emit(A64_STRB(src, dst, tmp), ctx);
			break;
		case BPF_DW:
			emit(A64_STR64(src, dst, tmp), ctx);
			break;
		}
		break;

	/* STX XADD: lock *(u32 *)(dst + off) += src */
	case BPF_STX | BPF_XADD | BPF_W:
	/* STX XADD: lock *(u64 *)(dst + off) += src */
	case BPF_STX | BPF_XADD | BPF_DW:
		if (!off) {
			reg = dst;
		} else {
			emit_a64_mov_i(1, tmp, off, ctx);
			emit(A64_ADD(1, tmp, tmp, dst), ctx);
			reg = tmp;
		}
		if (cpus_have_cap(ARM64_HAS_LSE_ATOMICS)) {
			emit(A64_STADD(isdw, reg, src), ctx);
		} else {
			emit(A64_LDXR(isdw, tmp2, reg), ctx);
			emit(A64_ADD(isdw, tmp2, tmp2, src), ctx);
			emit(A64_STXR(isdw, tmp2, reg, tmp3), ctx);
			jmp_offset = -3;
			check_imm19(jmp_offset);
			emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
		}
		break;

	default:
		pr_err_once("unknown opcode %02x\n", code);
		return -EINVAL;
	}

	return 0;
}

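/*
 * JIT the program body. On the sizing pass (ctx->image == NULL) this also
 * records, for each BPF instruction, its offset in the JITed image.
 */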
static int build_body(struct jit_ctx *ctx, bool extra_pass)
{
	const struct bpf_prog *prog = ctx->prog;
	int i;

	/*
	 * - offset[0] - offset of the end of prologue,
	 *   start of the 1st instruction.
	 * - offset[1] - offset of the end of 1st instruction,
	 *   start of the 2nd instruction
	 * [....]
	 * - offset[3] - offset of the end of 3rd instruction,
	 *   start of 4th instruction
	 */
	for (i = 0; i < prog->len; i++) {
		const struct bpf_insn *insn = &prog->insnsi[i];
		int ret;

		if (ctx->image == NULL)
			ctx->offset[i] = ctx->idx;
		ret = build_insn(insn, ctx, extra_pass);
		if (ret > 0) {
			i++;
			if (ctx->image == NULL)
				ctx->offset[i] = ctx->idx;
			continue;
		}
		if (ret)
			return ret;
	}
	/*
	 * offset is allocated with prog->len + 1 so fill in
	 * the last element with the offset after the last
	 * instruction (end of program)
	 */
	if (ctx->image == NULL)
		ctx->offset[i] = ctx->idx;

	return 0;
}

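/*
 * Reject the image if any placeholder BRK left by jit_fill_hole() survived
 * code generation, or if the number of exception table entries emitted
 * differs from prog->aux->num_exentries.
 */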
static int validate_code(struct jit_ctx *ctx)
{
	int i;

	for (i = 0; i < ctx->idx; i++) {
		u32 a64_insn = le32_to_cpu(ctx->image[i]);

		if (a64_insn == AARCH64_BREAK_FAULT)
			return -1;
	}

	if (WARN_ON_ONCE(ctx->exentry_idx != ctx->prog->aux->num_exentries))
		return -1;

	return 0;
}

static inline void bpf_flush_icache(void *start, void *end)
{
	flush_icache_range((unsigned long)start, (unsigned long)end);
}

struct arm64_jit_data {
	struct bpf_binary_header *header;
	u8 *image;
	struct jit_ctx ctx;
};

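/*
 * Main JIT entry point: an initial pass computes instruction offsets and the
 * image size, a second pass emits code into the allocated image, and a final
 * pass validates the result. For multi-function programs an extra pass is
 * run later, once all function addresses are known.
 */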
struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
{
	int image_size, prog_size, extable_size;
	struct bpf_prog *tmp, *orig_prog = prog;
	struct bpf_binary_header *header;
	struct arm64_jit_data *jit_data;
	bool was_classic = bpf_prog_was_classic(prog);
	bool tmp_blinded = false;
	bool extra_pass = false;
	struct jit_ctx ctx;
	u8 *image_ptr;

	if (!prog->jit_requested)
		return orig_prog;

	tmp = bpf_jit_blind_constants(prog);
	/* If blinding was requested and we failed during blinding,
	 * we must fall back to the interpreter.
	 */
	if (IS_ERR(tmp))
		return orig_prog;
	if (tmp != prog) {
		tmp_blinded = true;
		prog = tmp;
	}

	jit_data = prog->aux->jit_data;
	if (!jit_data) {
		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
		if (!jit_data) {
			prog = orig_prog;
			goto out;
		}
		prog->aux->jit_data = jit_data;
	}
	if (jit_data->ctx.offset) {
		ctx = jit_data->ctx;
		image_ptr = jit_data->image;
		header = jit_data->header;
		extra_pass = true;
		prog_size = sizeof(u32) * ctx.idx;
		goto skip_init_ctx;
	}
	memset(&ctx, 0, sizeof(ctx));
	ctx.prog = prog;

	ctx.offset = kcalloc(prog->len + 1, sizeof(int), GFP_KERNEL);
	if (ctx.offset == NULL) {
		prog = orig_prog;
		goto out_off;
	}

	/*
	 * 1. Initial fake pass to compute ctx->idx and ctx->offset.
	 *
	 * BPF line info needs ctx->offset[i] to be the offset of
	 * instruction[i] in jited image, so build prologue first.
	 */
	if (build_prologue(&ctx, was_classic)) {
		prog = orig_prog;
		goto out_off;
	}

	if (build_body(&ctx, extra_pass)) {
		prog = orig_prog;
		goto out_off;
	}

	ctx.epilogue_offset = ctx.idx;
	build_epilogue(&ctx);

	extable_size = prog->aux->num_exentries *
		sizeof(struct exception_table_entry);

	/* Now we know the actual image size. */
	prog_size = sizeof(u32) * ctx.idx;
	image_size = prog_size + extable_size;
	header = bpf_jit_binary_alloc(image_size, &image_ptr,
				      sizeof(u32), jit_fill_hole);
	if (header == NULL) {
		prog = orig_prog;
		goto out_off;
	}

	/* 2. Now, the actual pass. */

	ctx.image = (__le32 *)image_ptr;
	if (extable_size)
		prog->aux->extable = (void *)image_ptr + prog_size;
skip_init_ctx:
	ctx.idx = 0;
	ctx.exentry_idx = 0;

	build_prologue(&ctx, was_classic);

	if (build_body(&ctx, extra_pass)) {
		bpf_jit_binary_free(header);
		prog = orig_prog;
		goto out_off;
	}

	build_epilogue(&ctx);

	/* 3. Extra pass to validate JITed code. */
	if (validate_code(&ctx)) {
		bpf_jit_binary_free(header);
		prog = orig_prog;
		goto out_off;
	}

	/* And we're done. */
	if (bpf_jit_enable > 1)
		bpf_jit_dump(prog->len, prog_size, 2, ctx.image);

	bpf_flush_icache(header, ctx.image + ctx.idx);

	if (!prog->is_func || extra_pass) {
		if (extra_pass && ctx.idx != jit_data->ctx.idx) {
			pr_err_once("multi-func JIT bug %d != %d\n",
				    ctx.idx, jit_data->ctx.idx);
			bpf_jit_binary_free(header);
			prog->bpf_func = NULL;
			prog->jited = 0;
			prog->jited_len = 0;
			goto out_off;
		}
		bpf_jit_binary_lock_ro(header);
	} else {
		jit_data->ctx = ctx;
		jit_data->image = image_ptr;
		jit_data->header = header;
	}
	prog->bpf_func = (void *)ctx.image;
	prog->jited = 1;
	prog->jited_len = prog_size;

	if (!prog->is_func || extra_pass) {
		int i;

		/* offset[prog->len] is the size of program */
		for (i = 0; i <= prog->len; i++)
			ctx.offset[i] *= AARCH64_INSN_SIZE;
		bpf_prog_fill_jited_linfo(prog, ctx.offset + 1);
out_off:
		kfree(ctx.offset);
		kfree(jit_data);
		prog->aux->jit_data = NULL;
	}
out:
	if (tmp_blinded)
		bpf_jit_prog_release_other(prog, prog == orig_prog ?
					   tmp : orig_prog);
	return prog;
}

u64 bpf_jit_alloc_exec_limit(void)
{
	return BPF_JIT_REGION_SIZE;
}

void *bpf_jit_alloc_exec(unsigned long size)
{
	return __vmalloc_node_range(size, PAGE_SIZE, BPF_JIT_REGION_START,
				    BPF_JIT_REGION_END, GFP_KERNEL,
				    PAGE_KERNEL, 0, NUMA_NO_NODE,
				    __builtin_return_address(0));
}

void bpf_jit_free_exec(void *addr)
{
	return vfree(addr);
}
