1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * BPF JIT compiler for ARM64
4 *
5 * Copyright (C) 2014-2016 Zi Shen Lim <zlim.lnx@gmail.com>
6 */
7
8#define pr_fmt(fmt) "bpf_jit: " fmt
9
10#include <linux/bitfield.h>
11#include <linux/bpf.h>
12#include <linux/filter.h>
13#include <linux/memory.h>
14#include <linux/printk.h>
15#include <linux/slab.h>
16
17#include <asm/asm-extable.h>
18#include <asm/byteorder.h>
19#include <asm/cacheflush.h>
20#include <asm/debug-monitors.h>
21#include <asm/insn.h>
22#include <asm/patching.h>
23#include <asm/set_memory.h>
24
25#include "bpf_jit.h"
26
27#define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
28#define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
29#define TCALL_CNT (MAX_BPF_JIT_REG + 2)
30#define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
31#define FP_BOTTOM (MAX_BPF_JIT_REG + 4)
32
33#define check_imm(bits, imm) do {				\
34	if ((((imm) > 0) && ((imm) >> (bits))) ||		\
35	    (((imm) < 0) && (~(imm) >> (bits)))) {		\
36		pr_info("[%2d] imm=%d(0x%x) out of range\n",	\
37			i, imm, imm);				\
38		return -EINVAL;					\
39	}							\
40} while (0)
41#define check_imm19(imm) check_imm(19, imm)
42#define check_imm26(imm) check_imm(26, imm)
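
/*
 * For example, check_imm19(-3) passes because ~(-3) = 2 and 2 >> 19 == 0,
 * while check_imm19(1 << 20) fails because (1 << 20) >> 19 is non-zero and
 * such an offset could not fit in the branch instruction's immediate field.
 */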
43
44/* Map BPF registers to A64 registers */
45static const int bpf2a64[] = {
46	/* return value from in-kernel function, and exit value from eBPF */
47	[BPF_REG_0] = A64_R(7),
48	/* arguments from eBPF program to in-kernel function */
49	[BPF_REG_1] = A64_R(0),
50	[BPF_REG_2] = A64_R(1),
51	[BPF_REG_3] = A64_R(2),
52	[BPF_REG_4] = A64_R(3),
53	[BPF_REG_5] = A64_R(4),
54	/* callee saved registers that in-kernel function will preserve */
55	[BPF_REG_6] = A64_R(19),
56	[BPF_REG_7] = A64_R(20),
57	[BPF_REG_8] = A64_R(21),
58	[BPF_REG_9] = A64_R(22),
59	/* read-only frame pointer to access stack */
60	[BPF_REG_FP] = A64_R(25),
61	/* temporary registers for BPF JIT */
62	[TMP_REG_1] = A64_R(10),
63	[TMP_REG_2] = A64_R(11),
64	[TMP_REG_3] = A64_R(12),
65	/* tail_call_cnt */
66	[TCALL_CNT] = A64_R(26),
67	/* temporary register for blinding constants */
68	[BPF_REG_AX] = A64_R(9),
69	[FP_BOTTOM] = A64_R(27),
70};
71
72struct jit_ctx {
73	const struct bpf_prog *prog;
74	int idx;
75	int epilogue_offset;
76	int *offset;
77	int exentry_idx;
78	__le32 *image;
79	u32 stack_size;
80	int fpb_offset;
81};
82
83struct bpf_plt {
84	u32 insn_ldr; /* load target */
85	u32 insn_br;  /* branch to target */
86	u64 target;   /* target value */
87};
88
89#define PLT_TARGET_SIZE   sizeof_field(struct bpf_plt, target)
90#define PLT_TARGET_OFFSET offsetof(struct bpf_plt, target)
91
92static inline void emit(const u32 insn, struct jit_ctx *ctx)
93{
94	if (ctx->image != NULL)
95		ctx->image[ctx->idx] = cpu_to_le32(insn);
96
97	ctx->idx++;
98}
99
100static inline void emit_a64_mov_i(const int is64, const int reg,
101				  const s32 val, struct jit_ctx *ctx)
102{
103	u16 hi = val >> 16;
104	u16 lo = val & 0xffff;
105
106	if (hi & 0x8000) {
107		if (hi == 0xffff) {
108			emit(A64_MOVN(is64, reg, (u16)~lo, 0), ctx);
109		} else {
110			emit(A64_MOVN(is64, reg, (u16)~hi, 16), ctx);
111			if (lo != 0xffff)
112				emit(A64_MOVK(is64, reg, lo, 0), ctx);
113		}
114	} else {
115		emit(A64_MOVZ(is64, reg, lo, 0), ctx);
116		if (hi)
117			emit(A64_MOVK(is64, reg, hi, 16), ctx);
118	}
119}
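
/*
 * For example, emit_a64_mov_i(0, reg, 0x12345678, ctx) expands to
 * "movz reg, #0x5678" plus "movk reg, #0x1234, lsl #16", while a value
 * such as -2 (0xfffffffe), whose upper half-word is all ones, becomes a
 * single "movn reg, #0x1".
 */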
120
121static int i64_i16_blocks(const u64 val, bool inverse)
122{
123	return (((val >>  0) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
124	       (((val >> 16) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
125	       (((val >> 32) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
126	       (((val >> 48) & 0xffff) != (inverse ? 0xffff : 0x0000));
127}
128
129static inline void emit_a64_mov_i64(const int reg, const u64 val,
130				    struct jit_ctx *ctx)
131{
132	u64 nrm_tmp = val, rev_tmp = ~val;
133	bool inverse;
134	int shift;
135
136	if (!(nrm_tmp >> 32))
137		return emit_a64_mov_i(0, reg, (u32)val, ctx);
138
139	inverse = i64_i16_blocks(nrm_tmp, true) < i64_i16_blocks(nrm_tmp, false);
140	shift = max(round_down((inverse ? (fls64(rev_tmp) - 1) :
141					  (fls64(nrm_tmp) - 1)), 16), 0);
142	if (inverse)
143		emit(A64_MOVN(1, reg, (rev_tmp >> shift) & 0xffff, shift), ctx);
144	else
145		emit(A64_MOVZ(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
146	shift -= 16;
147	while (shift >= 0) {
148		if (((nrm_tmp >> shift) & 0xffff) != (inverse ? 0xffff : 0x0000))
149			emit(A64_MOVK(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
150		shift -= 16;
151	}
152}
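
/*
 * For example, emit_a64_mov_i64(reg, 0x1234000000005678, ctx) emits only
 * "movz reg, #0x1234, lsl #48" plus "movk reg, #0x5678": all-zero 16-bit
 * blocks cost no instruction, and values that are mostly ones start from
 * movn instead so their all-one blocks are skipped the same way.
 */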
153
154static inline void emit_bti(u32 insn, struct jit_ctx *ctx)
155{
156	if (IS_ENABLED(CONFIG_ARM64_BTI_KERNEL))
157		emit(insn, ctx);
158}
159
160/*
161 * Kernel addresses in the vmalloc space use at most 48 bits, and the
162 * remaining bits are guaranteed to be 0x1. So we can compose the address
163 * with a fixed length movn/movk/movk sequence.
164 */
165static inline void emit_addr_mov_i64(const int reg, const u64 val,
166				     struct jit_ctx *ctx)
167{
168	u64 tmp = val;
169	int shift = 0;
170
171	emit(A64_MOVN(1, reg, ~tmp & 0xffff, shift), ctx);
172	while (shift < 32) {
173		tmp >>= 16;
174		shift += 16;
175		emit(A64_MOVK(1, reg, tmp & 0xffff, shift), ctx);
176	}
177}
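
/*
 * For a vmalloc address such as 0xffff800010001234 this is always the same
 * three instructions: "movn reg, #0xedcb" (which also sets bits 63..48 to
 * one), "movk reg, #0x1000, lsl #16" and "movk reg, #0x8000, lsl #32",
 * i.e. a fixed length, unlike the variable-length emit_a64_mov_i64().
 */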
178
179static inline void emit_call(u64 target, struct jit_ctx *ctx)
180{
181	u8 tmp = bpf2a64[TMP_REG_1];
182
183	emit_addr_mov_i64(tmp, target, ctx);
184	emit(A64_BLR(tmp), ctx);
185}
186
187static inline int bpf2a64_offset(int bpf_insn, int off,
188				 const struct jit_ctx *ctx)
189{
190	/* BPF JMP offset is relative to the next instruction */
191	bpf_insn++;
192	/*
193	 * arm64 branch instructions, however, encode the offset
194	 * from the branch itself, so we must subtract 1 from the
195	 * instruction offset.
196	 */
197	return ctx->offset[bpf_insn + off] - (ctx->offset[bpf_insn] - 1);
198}
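
/*
 * For example, if BPF insn #5 is "goto +2" and ctx->offset[] maps BPF
 * insns 6 and 8 to A64 indexes 20 and 25, the A64 branch is the last
 * instruction emitted for insn #5, i.e. index 19, and the function
 * returns ctx->offset[8] - (ctx->offset[6] - 1) = 25 - 19 = 6.
 */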
199
200static void jit_fill_hole(void *area, unsigned int size)
201{
202	__le32 *ptr;
203	/* We are guaranteed to have aligned memory. */
204	for (ptr = area; size >= sizeof(u32); size -= sizeof(u32))
205		*ptr++ = cpu_to_le32(AARCH64_BREAK_FAULT);
206}
207
208static inline int epilogue_offset(const struct jit_ctx *ctx)
209{
210	int to = ctx->epilogue_offset;
211	int from = ctx->idx;
212
213	return to - from;
214}
215
216static bool is_addsub_imm(u32 imm)
217{
218	/* Either imm12 or shifted imm12. */
219	return !(imm & ~0xfff) || !(imm & ~0xfff000);
220}
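
/*
 * For example, 0xabc and 0xabc000 are accepted (plain and shifted imm12),
 * while 0xabc00 straddles both fields and 0x1000000 exceeds them, so such
 * immediates are instead moved into a tmp register and handled with the
 * register-register ADD/SUB/CMP forms below.
 */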
221
222/*
223 * There are 3 types of AArch64 LDR/STR (immediate) instruction:
224 * Post-index, Pre-index, Unsigned offset.
225 *
226 * For BPF ldr/str, the "unsigned offset" type is sufficient.
227 *
228 * "Unsigned offset" type LDR(immediate) format:
229 *
230 *    3                   2                   1                   0
231 *  1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0
232 * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
233 * |x x|1 1 1 0 0 1 0 1|         imm12         |    Rn   |    Rt   |
234 * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
235 * scale
236 *
237 * "Unsigned offset" type STR(immediate) format:
238 *    3                   2                   1                   0
239 *  1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0
240 * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
241 * |x x|1 1 1 0 0 1 0 0|         imm12         |    Rn   |    Rt   |
242 * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
243 * scale
244 *
245 * The offset is calculated from imm12 and scale in the following way:
246 *
247 * offset = (u64)imm12 << scale
248 */
249static bool is_lsi_offset(int offset, int scale)
250{
251	if (offset < 0)
252		return false;
253
254	if (offset > (0xFFF << scale))
255		return false;
256
257	if (offset & ((1 << scale) - 1))
258		return false;
259
260	return true;
261}
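
/*
 * For example, with scale = 3 (BPF_DW) an offset of 24 is encodable as
 * imm12 = 3, an offset of 28 is rejected because it is not 8-byte aligned,
 * and offsets above 0xfff << 3 = 32760 are rejected because they no longer
 * fit in the 12-bit field.
 */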
262
263/* generated prologue:
264 *      bti c // if CONFIG_ARM64_BTI_KERNEL
265 *      mov x9, lr
266 *      nop  // POKE_OFFSET
267 *      paciasp // if CONFIG_ARM64_PTR_AUTH_KERNEL
268 *      stp x29, lr, [sp, #-16]!
269 *      mov x29, sp
270 *      stp x19, x20, [sp, #-16]!
271 *      stp x21, x22, [sp, #-16]!
272 *      stp x25, x26, [sp, #-16]!
273 *      stp x27, x28, [sp, #-16]!
274 *      mov x25, sp
275 *      mov tcc, #0
276 *      // PROLOGUE_OFFSET
277 */
278
279#define BTI_INSNS (IS_ENABLED(CONFIG_ARM64_BTI_KERNEL) ? 1 : 0)
280#define PAC_INSNS (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL) ? 1 : 0)
281
282/* Offset of nop instruction in bpf prog entry to be poked */
283#define POKE_OFFSET (BTI_INSNS + 1)
284
285/* Tail call offset to jump into */
286#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 8)
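
/*
 * The "+ 8" above counts the eight instructions from "stp x29, lr" to
 * "mov tcc, #0" in the prologue listed earlier, so a tail call entering at
 * PROLOGUE_OFFSET lands right after tail_call_cnt is initialized: the
 * target program reuses the caller's register-save frame and keeps the
 * current count instead of resetting it.
 */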
287
288static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
289{
290	const struct bpf_prog *prog = ctx->prog;
291	const bool is_main_prog = prog->aux->func_idx == 0;
292	const u8 r6 = bpf2a64[BPF_REG_6];
293	const u8 r7 = bpf2a64[BPF_REG_7];
294	const u8 r8 = bpf2a64[BPF_REG_8];
295	const u8 r9 = bpf2a64[BPF_REG_9];
296	const u8 fp = bpf2a64[BPF_REG_FP];
297	const u8 tcc = bpf2a64[TCALL_CNT];
298	const u8 fpb = bpf2a64[FP_BOTTOM];
299	const int idx0 = ctx->idx;
300	int cur_offset;
301
302	/*
303	 * BPF prog stack layout
304	 *
305	 *                         high
306	 * original A64_SP =>   0:+-----+ BPF prologue
307	 *                        |FP/LR|
308	 * current A64_FP =>  -16:+-----+
309	 *                        | ... | callee saved registers
310	 * BPF fp register => -64:+-----+ <= (BPF_FP)
311	 *                        |     |
312	 *                        | ... | BPF prog stack
313	 *                        |     |
314	 *                        +-----+ <= (BPF_FP - prog->aux->stack_depth)
315	 *                        |RSVD | padding
316	 * current A64_SP =>      +-----+ <= (BPF_FP - ctx->stack_size)
317	 *                        |     |
318	 *                        | ... | Function call stack
319	 *                        |     |
320	 *                        +-----+
321	 *                          low
322	 *
323	 */
324
325	/* bpf function may be invoked by 3 instruction types:
326	 * 1. bl, attached via freplace to bpf prog via short jump
327	 * 2. br, attached via freplace to bpf prog via long jump
328	 * 3. blr, working as a function pointer, used by emit_call.
329	 * So BTI_JC should be used here to support both br and blr.
330	 */
331	emit_bti(A64_BTI_JC, ctx);
332
333	emit(A64_MOV(1, A64_R(9), A64_LR), ctx);
334	emit(A64_NOP, ctx);
335
336	/* Sign lr */
337	if (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL))
338		emit(A64_PACIASP, ctx);
339
340	/* Save FP and LR registers to stay aligned with the ARM64 AAPCS */
341	emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
342	emit(A64_MOV(1, A64_FP, A64_SP), ctx);
343
344	/* Save callee-saved registers */
345	emit(A64_PUSH(r6, r7, A64_SP), ctx);
346	emit(A64_PUSH(r8, r9, A64_SP), ctx);
347	emit(A64_PUSH(fp, tcc, A64_SP), ctx);
348	emit(A64_PUSH(fpb, A64_R(28), A64_SP), ctx);
349
350	/* Set up BPF prog stack base register */
351	emit(A64_MOV(1, fp, A64_SP), ctx);
352
353	if (!ebpf_from_cbpf && is_main_prog) {
354		/* Initialize tail_call_cnt */
355		emit(A64_MOVZ(1, tcc, 0, 0), ctx);
356
357		cur_offset = ctx->idx - idx0;
358		if (cur_offset != PROLOGUE_OFFSET) {
359			pr_err_once("PROLOGUE_OFFSET = %d, expected %d!\n",
360				    cur_offset, PROLOGUE_OFFSET);
361			return -1;
362		}
363
364		/* BTI landing pad for the tail call, done with a BR */
365		emit_bti(A64_BTI_J, ctx);
366	}
367
368	emit(A64_SUB_I(1, fpb, fp, ctx->fpb_offset), ctx);
369
370	/* Stack size must be a multiple of 16B */
371	ctx->stack_size = round_up(prog->aux->stack_depth, 16);
372
373	/* Set up function call stack */
374	emit(A64_SUB_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
375	return 0;
376}
377
378static int out_offset = -1; /* initialized on the first pass of build_body() */
379static int emit_bpf_tail_call(struct jit_ctx *ctx)
380{
381	/* bpf_tail_call(void *prog_ctx, struct bpf_array *array, u64 index) */
382	const u8 r2 = bpf2a64[BPF_REG_2];
383	const u8 r3 = bpf2a64[BPF_REG_3];
384
385	const u8 tmp = bpf2a64[TMP_REG_1];
386	const u8 prg = bpf2a64[TMP_REG_2];
387	const u8 tcc = bpf2a64[TCALL_CNT];
388	const int idx0 = ctx->idx;
389#define cur_offset (ctx->idx - idx0)
390#define jmp_offset (out_offset - (cur_offset))
391	size_t off;
392
393	/* if (index >= array->map.max_entries)
394	 *     goto out;
395	 */
396	off = offsetof(struct bpf_array, map.max_entries);
397	emit_a64_mov_i64(tmp, off, ctx);
398	emit(A64_LDR32(tmp, r2, tmp), ctx);
399	emit(A64_MOV(0, r3, r3), ctx);
400	emit(A64_CMP(0, r3, tmp), ctx);
401	emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
402
403	/*
404	 * if (tail_call_cnt >= MAX_TAIL_CALL_CNT)
405	 *     goto out;
406	 * tail_call_cnt++;
407	 */
408	emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
409	emit(A64_CMP(1, tcc, tmp), ctx);
410	emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
411	emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
412
413	/* prog = array->ptrs[index];
414	 * if (prog == NULL)
415	 *     goto out;
416	 */
417	off = offsetof(struct bpf_array, ptrs);
418	emit_a64_mov_i64(tmp, off, ctx);
419	emit(A64_ADD(1, tmp, r2, tmp), ctx);
420	emit(A64_LSL(1, prg, r3, 3), ctx);
421	emit(A64_LDR64(prg, tmp, prg), ctx);
422	emit(A64_CBZ(1, prg, jmp_offset), ctx);
423
424	/* goto *(prog->bpf_func + prologue_offset); */
425	off = offsetof(struct bpf_prog, bpf_func);
426	emit_a64_mov_i64(tmp, off, ctx);
427	emit(A64_LDR64(tmp, prg, tmp), ctx);
428	emit(A64_ADD_I(1, tmp, tmp, sizeof(u32) * PROLOGUE_OFFSET), ctx);
429	emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
430	emit(A64_BR(tmp), ctx);
431
432	/* out: */
433	if (out_offset == -1)
434		out_offset = cur_offset;
435	if (cur_offset != out_offset) {
436		pr_err_once("tail_call out_offset = %d, expected %d!\n",
437			    cur_offset, out_offset);
438		return -1;
439	}
440	return 0;
441#undef cur_offset
442#undef jmp_offset
443}
444
445#ifdef CONFIG_ARM64_LSE_ATOMICS
446static int emit_lse_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
447{
448	const u8 code = insn->code;
449	const u8 dst = bpf2a64[insn->dst_reg];
450	const u8 src = bpf2a64[insn->src_reg];
451	const u8 tmp = bpf2a64[TMP_REG_1];
452	const u8 tmp2 = bpf2a64[TMP_REG_2];
453	const bool isdw = BPF_SIZE(code) == BPF_DW;
454	const s16 off = insn->off;
455	u8 reg;
456
457	if (!off) {
458		reg = dst;
459	} else {
460		emit_a64_mov_i(1, tmp, off, ctx);
461		emit(A64_ADD(1, tmp, tmp, dst), ctx);
462		reg = tmp;
463	}
464
465	switch (insn->imm) {
466	/* lock *(u32/u64 *)(dst_reg + off) <op>= src_reg */
467	case BPF_ADD:
468		emit(A64_STADD(isdw, reg, src), ctx);
469		break;
470	case BPF_AND:
471		emit(A64_MVN(isdw, tmp2, src), ctx);
472		emit(A64_STCLR(isdw, reg, tmp2), ctx);
473		break;
474	case BPF_OR:
475		emit(A64_STSET(isdw, reg, src), ctx);
476		break;
477	case BPF_XOR:
478		emit(A64_STEOR(isdw, reg, src), ctx);
479		break;
480	/* src_reg = atomic_fetch_<op>(dst_reg + off, src_reg) */
481	case BPF_ADD | BPF_FETCH:
482		emit(A64_LDADDAL(isdw, src, reg, src), ctx);
483		break;
484	case BPF_AND | BPF_FETCH:
485		emit(A64_MVN(isdw, tmp2, src), ctx);
486		emit(A64_LDCLRAL(isdw, src, reg, tmp2), ctx);
487		break;
488	case BPF_OR | BPF_FETCH:
489		emit(A64_LDSETAL(isdw, src, reg, src), ctx);
490		break;
491	case BPF_XOR | BPF_FETCH:
492		emit(A64_LDEORAL(isdw, src, reg, src), ctx);
493		break;
494	/* src_reg = atomic_xchg(dst_reg + off, src_reg); */
495	case BPF_XCHG:
496		emit(A64_SWPAL(isdw, src, reg, src), ctx);
497		break;
498	/* r0 = atomic_cmpxchg(dst_reg + off, r0, src_reg); */
499	case BPF_CMPXCHG:
500		emit(A64_CASAL(isdw, src, reg, bpf2a64[BPF_REG_0]), ctx);
501		break;
502	default:
503		pr_err_once("unknown atomic op code %02x\n", insn->imm);
504		return -EINVAL;
505	}
506
507	return 0;
508}
509#else
510static inline int emit_lse_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
511{
512	return -EINVAL;
513}
514#endif
515
516static int emit_ll_sc_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
517{
518	const u8 code = insn->code;
519	const u8 dst = bpf2a64[insn->dst_reg];
520	const u8 src = bpf2a64[insn->src_reg];
521	const u8 tmp = bpf2a64[TMP_REG_1];
522	const u8 tmp2 = bpf2a64[TMP_REG_2];
523	const u8 tmp3 = bpf2a64[TMP_REG_3];
524	const int i = insn - ctx->prog->insnsi;
525	const s32 imm = insn->imm;
526	const s16 off = insn->off;
527	const bool isdw = BPF_SIZE(code) == BPF_DW;
528	u8 reg;
529	s32 jmp_offset;
530
531	if (!off) {
532		reg = dst;
533	} else {
534		emit_a64_mov_i(1, tmp, off, ctx);
535		emit(A64_ADD(1, tmp, tmp, dst), ctx);
536		reg = tmp;
537	}
538
539	if (imm == BPF_ADD || imm == BPF_AND ||
540	    imm == BPF_OR || imm == BPF_XOR) {
541		/* lock *(u32/u64 *)(dst_reg + off) <op>= src_reg */
542		emit(A64_LDXR(isdw, tmp2, reg), ctx);
543		if (imm == BPF_ADD)
544			emit(A64_ADD(isdw, tmp2, tmp2, src), ctx);
545		else if (imm == BPF_AND)
546			emit(A64_AND(isdw, tmp2, tmp2, src), ctx);
547		else if (imm == BPF_OR)
548			emit(A64_ORR(isdw, tmp2, tmp2, src), ctx);
549		else
550			emit(A64_EOR(isdw, tmp2, tmp2, src), ctx);
551		emit(A64_STXR(isdw, tmp2, reg, tmp3), ctx);
552		jmp_offset = -3;
553		check_imm19(jmp_offset);
554		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
555	} else if (imm == (BPF_ADD | BPF_FETCH) ||
556		   imm == (BPF_AND | BPF_FETCH) ||
557		   imm == (BPF_OR | BPF_FETCH) ||
558		   imm == (BPF_XOR | BPF_FETCH)) {
559		/* src_reg = atomic_fetch_<op>(dst_reg + off, src_reg) */
560		const u8 ax = bpf2a64[BPF_REG_AX];
561
562		emit(A64_MOV(isdw, ax, src), ctx);
563		emit(A64_LDXR(isdw, src, reg), ctx);
564		if (imm == (BPF_ADD | BPF_FETCH))
565			emit(A64_ADD(isdw, tmp2, src, ax), ctx);
566		else if (imm == (BPF_AND | BPF_FETCH))
567			emit(A64_AND(isdw, tmp2, src, ax), ctx);
568		else if (imm == (BPF_OR | BPF_FETCH))
569			emit(A64_ORR(isdw, tmp2, src, ax), ctx);
570		else
571			emit(A64_EOR(isdw, tmp2, src, ax), ctx);
572		emit(A64_STLXR(isdw, tmp2, reg, tmp3), ctx);
573		jmp_offset = -3;
574		check_imm19(jmp_offset);
575		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
576		emit(A64_DMB_ISH, ctx);
577	} else if (imm == BPF_XCHG) {
578		/* src_reg = atomic_xchg(dst_reg + off, src_reg); */
579		emit(A64_MOV(isdw, tmp2, src), ctx);
580		emit(A64_LDXR(isdw, src, reg), ctx);
581		emit(A64_STLXR(isdw, tmp2, reg, tmp3), ctx);
582		jmp_offset = -2;
583		check_imm19(jmp_offset);
584		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
585		emit(A64_DMB_ISH, ctx);
586	} else if (imm == BPF_CMPXCHG) {
587		/* r0 = atomic_cmpxchg(dst_reg + off, r0, src_reg); */
588		const u8 r0 = bpf2a64[BPF_REG_0];
589
590		emit(A64_MOV(isdw, tmp2, r0), ctx);
591		emit(A64_LDXR(isdw, r0, reg), ctx);
592		emit(A64_EOR(isdw, tmp3, r0, tmp2), ctx);
593		jmp_offset = 4;
594		check_imm19(jmp_offset);
595		emit(A64_CBNZ(isdw, tmp3, jmp_offset), ctx);
596		emit(A64_STLXR(isdw, src, reg, tmp3), ctx);
597		jmp_offset = -4;
598		check_imm19(jmp_offset);
599		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
600		emit(A64_DMB_ISH, ctx);
601	} else {
602		pr_err_once("unknown atomic op code %02x\n", imm);
603		return -EINVAL;
604	}
605
606	return 0;
607}
608
609void dummy_tramp(void);
610
611asm (
612"	.pushsection .text, \"ax\", @progbits\n"
613"	.global dummy_tramp\n"
614"	.type dummy_tramp, %function\n"
615"dummy_tramp:"
616#if IS_ENABLED(CONFIG_ARM64_BTI_KERNEL)
617"	bti j\n" /* dummy_tramp is called via "br x10" */
618#endif
619"	mov x10, x30\n"
620"	mov x30, x9\n"
621"	ret x10\n"
622"	.size dummy_tramp, .-dummy_tramp\n"
623"	.popsection\n"
624);
625
626/* build a plt initialized like this:
627 *
628 * plt:
629 *      ldr tmp, target
630 *      br tmp
631 * target:
632 *      .quad dummy_tramp
633 *
634 * when a long jump trampoline is attached, target is filled with the
635 * trampoline address, and when the trampoline is removed, target is
636 * restored to dummy_tramp address.
637 */
638static void build_plt(struct jit_ctx *ctx)
639{
640	const u8 tmp = bpf2a64[TMP_REG_1];
641	struct bpf_plt *plt = NULL;
642
643	/* make sure target is 64-bit aligned */
644	if ((ctx->idx + PLT_TARGET_OFFSET / AARCH64_INSN_SIZE) % 2)
645		emit(A64_NOP, ctx);
646
647	plt = (struct bpf_plt *)(ctx->image + ctx->idx);
648	/* plt is called via bl, no BTI needed here */
649	emit(A64_LDR64LIT(tmp, 2 * AARCH64_INSN_SIZE), ctx);
650	emit(A64_BR(tmp), ctx);
651
652	if (ctx->image)
653		plt->target = (u64)&dummy_tramp;
654}
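
/*
 * Since PLT_TARGET_OFFSET / AARCH64_INSN_SIZE == 2, the test above is
 * simply "is ctx->idx odd": the optional nop keeps the 64-bit target word,
 * which sits two instructions after the ldr, on an 8-byte boundary so it
 * can later be rewritten with a single aligned 64-bit store.
 */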
655
656static void build_epilogue(struct jit_ctx *ctx)
657{
658	const u8 r0 = bpf2a64[BPF_REG_0];
659	const u8 r6 = bpf2a64[BPF_REG_6];
660	const u8 r7 = bpf2a64[BPF_REG_7];
661	const u8 r8 = bpf2a64[BPF_REG_8];
662	const u8 r9 = bpf2a64[BPF_REG_9];
663	const u8 fp = bpf2a64[BPF_REG_FP];
664	const u8 fpb = bpf2a64[FP_BOTTOM];
665
666	/* We're done with BPF stack */
667	emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
668
669	/* Restore x27 and x28 */
670	emit(A64_POP(fpb, A64_R(28), A64_SP), ctx);
671	/* Restore fp (x25) and x26 */
672	emit(A64_POP(fp, A64_R(26), A64_SP), ctx);
673
674	/* Restore callee-saved register */
675	emit(A64_POP(r8, r9, A64_SP), ctx);
676	emit(A64_POP(r6, r7, A64_SP), ctx);
677
678	/* Restore FP/LR registers */
679	emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
680
681	/* Set return value */
682	emit(A64_MOV(1, A64_R(0), r0), ctx);
683
684	/* Authenticate lr */
685	if (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL))
686		emit(A64_AUTIASP, ctx);
687
688	emit(A64_RET(A64_LR), ctx);
689}
690
691#define BPF_FIXUP_OFFSET_MASK	GENMASK(26, 0)
692#define BPF_FIXUP_REG_MASK	GENMASK(31, 27)
693
694bool ex_handler_bpf(const struct exception_table_entry *ex,
695		    struct pt_regs *regs)
696{
697	off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
698	int dst_reg = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
699
700	regs->regs[dst_reg] = 0;
701	regs->pc = (unsigned long)&ex->fixup - offset;
702	return true;
703}
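
/*
 * For example, a fixup word built as FIELD_PREP(BPF_FIXUP_OFFSET_MASK, 0x30) |
 * FIELD_PREP(BPF_FIXUP_REG_MASK, 5) makes the handler clear regs->regs[5]
 * and resume at 0x30 bytes before &ex->fixup, which add_exception_handler()
 * below arranges to be the instruction right after the faulting load.
 */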
704
705/* For accesses to BTF pointers, add an entry to the exception table */
706static int add_exception_handler(const struct bpf_insn *insn,
707				 struct jit_ctx *ctx,
708				 int dst_reg)
709{
710	off_t offset;
711	unsigned long pc;
712	struct exception_table_entry *ex;
713
714	if (!ctx->image)
715		/* First pass */
716		return 0;
717
718	if (BPF_MODE(insn->code) != BPF_PROBE_MEM &&
719		BPF_MODE(insn->code) != BPF_PROBE_MEMSX)
720		return 0;
721
722	if (!ctx->prog->aux->extable ||
723	    WARN_ON_ONCE(ctx->exentry_idx >= ctx->prog->aux->num_exentries))
724		return -EINVAL;
725
726	ex = &ctx->prog->aux->extable[ctx->exentry_idx];
727	pc = (unsigned long)&ctx->image[ctx->idx - 1];
728
729	offset = pc - (long)&ex->insn;
730	if (WARN_ON_ONCE(offset >= 0 || offset < INT_MIN))
731		return -ERANGE;
732	ex->insn = offset;
733
734	/*
735	 * Since the extable follows the program, the fixup offset is always
736	 * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
737	 * to keep things simple, and put the destination register in the upper
738	 * bits. We don't need to worry about buildtime or runtime sort
739	 * modifying the upper bits because the table is already sorted, and
740	 * isn't part of the main exception table.
741	 */
742	offset = (long)&ex->fixup - (pc + AARCH64_INSN_SIZE);
743	if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, offset))
744		return -ERANGE;
745
746	ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, offset) |
747		    FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
748
749	ex->type = EX_TYPE_BPF;
750
751	ctx->exentry_idx++;
752	return 0;
753}
754
755/* JITs an eBPF instruction.
756 * Returns:
757 * 0  - successfully JITed an 8-byte eBPF instruction.
758 * >0 - successfully JITed a 16-byte eBPF instruction.
759 * <0 - failed to JIT.
760 */
761static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
762		      bool extra_pass)
763{
764	const u8 code = insn->code;
765	const u8 dst = bpf2a64[insn->dst_reg];
766	const u8 src = bpf2a64[insn->src_reg];
767	const u8 tmp = bpf2a64[TMP_REG_1];
768	const u8 tmp2 = bpf2a64[TMP_REG_2];
769	const u8 fp = bpf2a64[BPF_REG_FP];
770	const u8 fpb = bpf2a64[FP_BOTTOM];
771	const s16 off = insn->off;
772	const s32 imm = insn->imm;
773	const int i = insn - ctx->prog->insnsi;
774	const bool is64 = BPF_CLASS(code) == BPF_ALU64 ||
775			  BPF_CLASS(code) == BPF_JMP;
776	u8 jmp_cond;
777	s32 jmp_offset;
778	u32 a64_insn;
779	u8 src_adj;
780	u8 dst_adj;
781	int off_adj;
782	int ret;
783	bool sign_extend;
784
785	switch (code) {
786	/* dst = src */
787	case BPF_ALU | BPF_MOV | BPF_X:
788	case BPF_ALU64 | BPF_MOV | BPF_X:
789		switch (insn->off) {
790		case 0:
791			emit(A64_MOV(is64, dst, src), ctx);
792			break;
793		case 8:
794			emit(A64_SXTB(is64, dst, src), ctx);
795			break;
796		case 16:
797			emit(A64_SXTH(is64, dst, src), ctx);
798			break;
799		case 32:
800			emit(A64_SXTW(is64, dst, src), ctx);
801			break;
802		}
803		break;
804	/* dst = dst OP src */
805	case BPF_ALU | BPF_ADD | BPF_X:
806	case BPF_ALU64 | BPF_ADD | BPF_X:
807		emit(A64_ADD(is64, dst, dst, src), ctx);
808		break;
809	case BPF_ALU | BPF_SUB | BPF_X:
810	case BPF_ALU64 | BPF_SUB | BPF_X:
811		emit(A64_SUB(is64, dst, dst, src), ctx);
812		break;
813	case BPF_ALU | BPF_AND | BPF_X:
814	case BPF_ALU64 | BPF_AND | BPF_X:
815		emit(A64_AND(is64, dst, dst, src), ctx);
816		break;
817	case BPF_ALU | BPF_OR | BPF_X:
818	case BPF_ALU64 | BPF_OR | BPF_X:
819		emit(A64_ORR(is64, dst, dst, src), ctx);
820		break;
821	case BPF_ALU | BPF_XOR | BPF_X:
822	case BPF_ALU64 | BPF_XOR | BPF_X:
823		emit(A64_EOR(is64, dst, dst, src), ctx);
824		break;
825	case BPF_ALU | BPF_MUL | BPF_X:
826	case BPF_ALU64 | BPF_MUL | BPF_X:
827		emit(A64_MUL(is64, dst, dst, src), ctx);
828		break;
829	case BPF_ALU | BPF_DIV | BPF_X:
830	case BPF_ALU64 | BPF_DIV | BPF_X:
831		if (!off)
832			emit(A64_UDIV(is64, dst, dst, src), ctx);
833		else
834			emit(A64_SDIV(is64, dst, dst, src), ctx);
835		break;
836	case BPF_ALU | BPF_MOD | BPF_X:
837	case BPF_ALU64 | BPF_MOD | BPF_X:
838		if (!off)
839			emit(A64_UDIV(is64, tmp, dst, src), ctx);
840		else
841			emit(A64_SDIV(is64, tmp, dst, src), ctx);
842		emit(A64_MSUB(is64, dst, dst, tmp, src), ctx);
843		break;
844	case BPF_ALU | BPF_LSH | BPF_X:
845	case BPF_ALU64 | BPF_LSH | BPF_X:
846		emit(A64_LSLV(is64, dst, dst, src), ctx);
847		break;
848	case BPF_ALU | BPF_RSH | BPF_X:
849	case BPF_ALU64 | BPF_RSH | BPF_X:
850		emit(A64_LSRV(is64, dst, dst, src), ctx);
851		break;
852	case BPF_ALU | BPF_ARSH | BPF_X:
853	case BPF_ALU64 | BPF_ARSH | BPF_X:
854		emit(A64_ASRV(is64, dst, dst, src), ctx);
855		break;
856	/* dst = -dst */
857	case BPF_ALU | BPF_NEG:
858	case BPF_ALU64 | BPF_NEG:
859		emit(A64_NEG(is64, dst, dst), ctx);
860		break;
861	/* dst = BSWAP##imm(dst) */
862	case BPF_ALU | BPF_END | BPF_FROM_LE:
863	case BPF_ALU | BPF_END | BPF_FROM_BE:
864	case BPF_ALU64 | BPF_END | BPF_FROM_LE:
865#ifdef CONFIG_CPU_BIG_ENDIAN
866		if (BPF_CLASS(code) == BPF_ALU && BPF_SRC(code) == BPF_FROM_BE)
867			goto emit_bswap_uxt;
868#else /* !CONFIG_CPU_BIG_ENDIAN */
869		if (BPF_CLASS(code) == BPF_ALU && BPF_SRC(code) == BPF_FROM_LE)
870			goto emit_bswap_uxt;
871#endif
872		switch (imm) {
873		case 16:
874			emit(A64_REV16(is64, dst, dst), ctx);
875			/* zero-extend 16 bits into 64 bits */
876			emit(A64_UXTH(is64, dst, dst), ctx);
877			break;
878		case 32:
879			emit(A64_REV32(is64, dst, dst), ctx);
880			/* upper 32 bits already cleared */
881			break;
882		case 64:
883			emit(A64_REV64(dst, dst), ctx);
884			break;
885		}
886		break;
887emit_bswap_uxt:
888		switch (imm) {
889		case 16:
890			/* zero-extend 16 bits into 64 bits */
891			emit(A64_UXTH(is64, dst, dst), ctx);
892			break;
893		case 32:
894			/* zero-extend 32 bits into 64 bits */
895			emit(A64_UXTW(is64, dst, dst), ctx);
896			break;
897		case 64:
898			/* nop */
899			break;
900		}
901		break;
902	/* dst = imm */
903	case BPF_ALU | BPF_MOV | BPF_K:
904	case BPF_ALU64 | BPF_MOV | BPF_K:
905		emit_a64_mov_i(is64, dst, imm, ctx);
906		break;
907	/* dst = dst OP imm */
908	case BPF_ALU | BPF_ADD | BPF_K:
909	case BPF_ALU64 | BPF_ADD | BPF_K:
910		if (is_addsub_imm(imm)) {
911			emit(A64_ADD_I(is64, dst, dst, imm), ctx);
912		} else if (is_addsub_imm(-imm)) {
913			emit(A64_SUB_I(is64, dst, dst, -imm), ctx);
914		} else {
915			emit_a64_mov_i(is64, tmp, imm, ctx);
916			emit(A64_ADD(is64, dst, dst, tmp), ctx);
917		}
918		break;
919	case BPF_ALU | BPF_SUB | BPF_K:
920	case BPF_ALU64 | BPF_SUB | BPF_K:
921		if (is_addsub_imm(imm)) {
922			emit(A64_SUB_I(is64, dst, dst, imm), ctx);
923		} else if (is_addsub_imm(-imm)) {
924			emit(A64_ADD_I(is64, dst, dst, -imm), ctx);
925		} else {
926			emit_a64_mov_i(is64, tmp, imm, ctx);
927			emit(A64_SUB(is64, dst, dst, tmp), ctx);
928		}
929		break;
930	case BPF_ALU | BPF_AND | BPF_K:
931	case BPF_ALU64 | BPF_AND | BPF_K:
932		a64_insn = A64_AND_I(is64, dst, dst, imm);
933		if (a64_insn != AARCH64_BREAK_FAULT) {
934			emit(a64_insn, ctx);
935		} else {
936			emit_a64_mov_i(is64, tmp, imm, ctx);
937			emit(A64_AND(is64, dst, dst, tmp), ctx);
938		}
939		break;
940	case BPF_ALU | BPF_OR | BPF_K:
941	case BPF_ALU64 | BPF_OR | BPF_K:
942		a64_insn = A64_ORR_I(is64, dst, dst, imm);
943		if (a64_insn != AARCH64_BREAK_FAULT) {
944			emit(a64_insn, ctx);
945		} else {
946			emit_a64_mov_i(is64, tmp, imm, ctx);
947			emit(A64_ORR(is64, dst, dst, tmp), ctx);
948		}
949		break;
950	case BPF_ALU | BPF_XOR | BPF_K:
951	case BPF_ALU64 | BPF_XOR | BPF_K:
952		a64_insn = A64_EOR_I(is64, dst, dst, imm);
953		if (a64_insn != AARCH64_BREAK_FAULT) {
954			emit(a64_insn, ctx);
955		} else {
956			emit_a64_mov_i(is64, tmp, imm, ctx);
957			emit(A64_EOR(is64, dst, dst, tmp), ctx);
958		}
959		break;
960	case BPF_ALU | BPF_MUL | BPF_K:
961	case BPF_ALU64 | BPF_MUL | BPF_K:
962		emit_a64_mov_i(is64, tmp, imm, ctx);
963		emit(A64_MUL(is64, dst, dst, tmp), ctx);
964		break;
965	case BPF_ALU | BPF_DIV | BPF_K:
966	case BPF_ALU64 | BPF_DIV | BPF_K:
967		emit_a64_mov_i(is64, tmp, imm, ctx);
968		if (!off)
969			emit(A64_UDIV(is64, dst, dst, tmp), ctx);
970		else
971			emit(A64_SDIV(is64, dst, dst, tmp), ctx);
972		break;
973	case BPF_ALU | BPF_MOD | BPF_K:
974	case BPF_ALU64 | BPF_MOD | BPF_K:
975		emit_a64_mov_i(is64, tmp2, imm, ctx);
976		if (!off)
977			emit(A64_UDIV(is64, tmp, dst, tmp2), ctx);
978		else
979			emit(A64_SDIV(is64, tmp, dst, tmp2), ctx);
980		emit(A64_MSUB(is64, dst, dst, tmp, tmp2), ctx);
981		break;
982	case BPF_ALU | BPF_LSH | BPF_K:
983	case BPF_ALU64 | BPF_LSH | BPF_K:
984		emit(A64_LSL(is64, dst, dst, imm), ctx);
985		break;
986	case BPF_ALU | BPF_RSH | BPF_K:
987	case BPF_ALU64 | BPF_RSH | BPF_K:
988		emit(A64_LSR(is64, dst, dst, imm), ctx);
989		break;
990	case BPF_ALU | BPF_ARSH | BPF_K:
991	case BPF_ALU64 | BPF_ARSH | BPF_K:
992		emit(A64_ASR(is64, dst, dst, imm), ctx);
993		break;
994
995	/* JUMP off */
996	case BPF_JMP | BPF_JA:
997	case BPF_JMP32 | BPF_JA:
998		if (BPF_CLASS(code) == BPF_JMP)
999			jmp_offset = bpf2a64_offset(i, off, ctx);
1000		else
1001			jmp_offset = bpf2a64_offset(i, imm, ctx);
1002		check_imm26(jmp_offset);
1003		emit(A64_B(jmp_offset), ctx);
1004		break;
1005	/* IF (dst COND src) JUMP off */
1006	case BPF_JMP | BPF_JEQ | BPF_X:
1007	case BPF_JMP | BPF_JGT | BPF_X:
1008	case BPF_JMP | BPF_JLT | BPF_X:
1009	case BPF_JMP | BPF_JGE | BPF_X:
1010	case BPF_JMP | BPF_JLE | BPF_X:
1011	case BPF_JMP | BPF_JNE | BPF_X:
1012	case BPF_JMP | BPF_JSGT | BPF_X:
1013	case BPF_JMP | BPF_JSLT | BPF_X:
1014	case BPF_JMP | BPF_JSGE | BPF_X:
1015	case BPF_JMP | BPF_JSLE | BPF_X:
1016	case BPF_JMP32 | BPF_JEQ | BPF_X:
1017	case BPF_JMP32 | BPF_JGT | BPF_X:
1018	case BPF_JMP32 | BPF_JLT | BPF_X:
1019	case BPF_JMP32 | BPF_JGE | BPF_X:
1020	case BPF_JMP32 | BPF_JLE | BPF_X:
1021	case BPF_JMP32 | BPF_JNE | BPF_X:
1022	case BPF_JMP32 | BPF_JSGT | BPF_X:
1023	case BPF_JMP32 | BPF_JSLT | BPF_X:
1024	case BPF_JMP32 | BPF_JSGE | BPF_X:
1025	case BPF_JMP32 | BPF_JSLE | BPF_X:
1026		emit(A64_CMP(is64, dst, src), ctx);
1027emit_cond_jmp:
1028		jmp_offset = bpf2a64_offset(i, off, ctx);
1029		check_imm19(jmp_offset);
1030		switch (BPF_OP(code)) {
1031		case BPF_JEQ:
1032			jmp_cond = A64_COND_EQ;
1033			break;
1034		case BPF_JGT:
1035			jmp_cond = A64_COND_HI;
1036			break;
1037		case BPF_JLT:
1038			jmp_cond = A64_COND_CC;
1039			break;
1040		case BPF_JGE:
1041			jmp_cond = A64_COND_CS;
1042			break;
1043		case BPF_JLE:
1044			jmp_cond = A64_COND_LS;
1045			break;
1046		case BPF_JSET:
1047		case BPF_JNE:
1048			jmp_cond = A64_COND_NE;
1049			break;
1050		case BPF_JSGT:
1051			jmp_cond = A64_COND_GT;
1052			break;
1053		case BPF_JSLT:
1054			jmp_cond = A64_COND_LT;
1055			break;
1056		case BPF_JSGE:
1057			jmp_cond = A64_COND_GE;
1058			break;
1059		case BPF_JSLE:
1060			jmp_cond = A64_COND_LE;
1061			break;
1062		default:
1063			return -EFAULT;
1064		}
1065		emit(A64_B_(jmp_cond, jmp_offset), ctx);
1066		break;
1067	case BPF_JMP | BPF_JSET | BPF_X:
1068	case BPF_JMP32 | BPF_JSET | BPF_X:
1069		emit(A64_TST(is64, dst, src), ctx);
1070		goto emit_cond_jmp;
1071	/* IF (dst COND imm) JUMP off */
1072	case BPF_JMP | BPF_JEQ | BPF_K:
1073	case BPF_JMP | BPF_JGT | BPF_K:
1074	case BPF_JMP | BPF_JLT | BPF_K:
1075	case BPF_JMP | BPF_JGE | BPF_K:
1076	case BPF_JMP | BPF_JLE | BPF_K:
1077	case BPF_JMP | BPF_JNE | BPF_K:
1078	case BPF_JMP | BPF_JSGT | BPF_K:
1079	case BPF_JMP | BPF_JSLT | BPF_K:
1080	case BPF_JMP | BPF_JSGE | BPF_K:
1081	case BPF_JMP | BPF_JSLE | BPF_K:
1082	case BPF_JMP32 | BPF_JEQ | BPF_K:
1083	case BPF_JMP32 | BPF_JGT | BPF_K:
1084	case BPF_JMP32 | BPF_JLT | BPF_K:
1085	case BPF_JMP32 | BPF_JGE | BPF_K:
1086	case BPF_JMP32 | BPF_JLE | BPF_K:
1087	case BPF_JMP32 | BPF_JNE | BPF_K:
1088	case BPF_JMP32 | BPF_JSGT | BPF_K:
1089	case BPF_JMP32 | BPF_JSLT | BPF_K:
1090	case BPF_JMP32 | BPF_JSGE | BPF_K:
1091	case BPF_JMP32 | BPF_JSLE | BPF_K:
1092		if (is_addsub_imm(imm)) {
1093			emit(A64_CMP_I(is64, dst, imm), ctx);
1094		} else if (is_addsub_imm(-imm)) {
1095			emit(A64_CMN_I(is64, dst, -imm), ctx);
1096		} else {
1097			emit_a64_mov_i(is64, tmp, imm, ctx);
1098			emit(A64_CMP(is64, dst, tmp), ctx);
1099		}
1100		goto emit_cond_jmp;
1101	case BPF_JMP | BPF_JSET | BPF_K:
1102	case BPF_JMP32 | BPF_JSET | BPF_K:
1103		a64_insn = A64_TST_I(is64, dst, imm);
1104		if (a64_insn != AARCH64_BREAK_FAULT) {
1105			emit(a64_insn, ctx);
1106		} else {
1107			emit_a64_mov_i(is64, tmp, imm, ctx);
1108			emit(A64_TST(is64, dst, tmp), ctx);
1109		}
1110		goto emit_cond_jmp;
1111	/* function call */
1112	case BPF_JMP | BPF_CALL:
1113	{
1114		const u8 r0 = bpf2a64[BPF_REG_0];
1115		bool func_addr_fixed;
1116		u64 func_addr;
1117
1118		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
1119					    &func_addr, &func_addr_fixed);
1120		if (ret < 0)
1121			return ret;
1122		emit_call(func_addr, ctx);
1123		emit(A64_MOV(1, r0, A64_R(0)), ctx);
1124		break;
1125	}
1126	/* tail call */
1127	case BPF_JMP | BPF_TAIL_CALL:
1128		if (emit_bpf_tail_call(ctx))
1129			return -EFAULT;
1130		break;
1131	/* function return */
1132	case BPF_JMP | BPF_EXIT:
1133		/* Optimization: when the last instruction is EXIT,
1134		   simply fall through to the epilogue. */
1135		if (i == ctx->prog->len - 1)
1136			break;
1137		jmp_offset = epilogue_offset(ctx);
1138		check_imm26(jmp_offset);
1139		emit(A64_B(jmp_offset), ctx);
1140		break;
1141
1142	/* dst = imm64 */
1143	case BPF_LD | BPF_IMM | BPF_DW:
1144	{
1145		const struct bpf_insn insn1 = insn[1];
1146		u64 imm64;
1147
1148		imm64 = (u64)insn1.imm << 32 | (u32)imm;
1149		if (bpf_pseudo_func(insn))
1150			emit_addr_mov_i64(dst, imm64, ctx);
1151		else
1152			emit_a64_mov_i64(dst, imm64, ctx);
1153
1154		return 1;
1155	}
1156
1157	/* LDX: dst = (u64)*(unsigned size *)(src + off) */
1158	case BPF_LDX | BPF_MEM | BPF_W:
1159	case BPF_LDX | BPF_MEM | BPF_H:
1160	case BPF_LDX | BPF_MEM | BPF_B:
1161	case BPF_LDX | BPF_MEM | BPF_DW:
1162	case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
1163	case BPF_LDX | BPF_PROBE_MEM | BPF_W:
1164	case BPF_LDX | BPF_PROBE_MEM | BPF_H:
1165	case BPF_LDX | BPF_PROBE_MEM | BPF_B:
1166	/* LDXS: dst_reg = (s64)*(signed size *)(src_reg + off) */
1167	case BPF_LDX | BPF_MEMSX | BPF_B:
1168	case BPF_LDX | BPF_MEMSX | BPF_H:
1169	case BPF_LDX | BPF_MEMSX | BPF_W:
1170	case BPF_LDX | BPF_PROBE_MEMSX | BPF_B:
1171	case BPF_LDX | BPF_PROBE_MEMSX | BPF_H:
1172	case BPF_LDX | BPF_PROBE_MEMSX | BPF_W:
1173		if (ctx->fpb_offset > 0 && src == fp) {
1174			src_adj = fpb;
1175			off_adj = off + ctx->fpb_offset;
1176		} else {
1177			src_adj = src;
1178			off_adj = off;
1179		}
1180		sign_extend = (BPF_MODE(insn->code) == BPF_MEMSX ||
1181				BPF_MODE(insn->code) == BPF_PROBE_MEMSX);
1182		switch (BPF_SIZE(code)) {
1183		case BPF_W:
1184			if (is_lsi_offset(off_adj, 2)) {
1185				if (sign_extend)
1186					emit(A64_LDRSWI(dst, src_adj, off_adj), ctx);
1187				else
1188					emit(A64_LDR32I(dst, src_adj, off_adj), ctx);
1189			} else {
1190				emit_a64_mov_i(1, tmp, off, ctx);
1191				if (sign_extend)
1192					emit(A64_LDRSW(dst, src, tmp), ctx);
1193				else
1194					emit(A64_LDR32(dst, src, tmp), ctx);
1195			}
1196			break;
1197		case BPF_H:
1198			if (is_lsi_offset(off_adj, 1)) {
1199				if (sign_extend)
1200					emit(A64_LDRSHI(dst, src_adj, off_adj), ctx);
1201				else
1202					emit(A64_LDRHI(dst, src_adj, off_adj), ctx);
1203			} else {
1204				emit_a64_mov_i(1, tmp, off, ctx);
1205				if (sign_extend)
1206					emit(A64_LDRSH(dst, src, tmp), ctx);
1207				else
1208					emit(A64_LDRH(dst, src, tmp), ctx);
1209			}
1210			break;
1211		case BPF_B:
1212			if (is_lsi_offset(off_adj, 0)) {
1213				if (sign_extend)
1214					emit(A64_LDRSBI(dst, src_adj, off_adj), ctx);
1215				else
1216					emit(A64_LDRBI(dst, src_adj, off_adj), ctx);
1217			} else {
1218				emit_a64_mov_i(1, tmp, off, ctx);
1219				if (sign_extend)
1220					emit(A64_LDRSB(dst, src, tmp), ctx);
1221				else
1222					emit(A64_LDRB(dst, src, tmp), ctx);
1223			}
1224			break;
1225		case BPF_DW:
1226			if (is_lsi_offset(off_adj, 3)) {
1227				emit(A64_LDR64I(dst, src_adj, off_adj), ctx);
1228			} else {
1229				emit_a64_mov_i(1, tmp, off, ctx);
1230				emit(A64_LDR64(dst, src, tmp), ctx);
1231			}
1232			break;
1233		}
1234
1235		ret = add_exception_handler(insn, ctx, dst);
1236		if (ret)
1237			return ret;
1238		break;
1239
1240	/* speculation barrier */
1241	case BPF_ST | BPF_NOSPEC:
1242		/*
1243		 * Nothing required here.
1244		 *
1245		 * In case of arm64, we rely on the firmware mitigation of
1246		 * Speculative Store Bypass as controlled via the ssbd kernel
1247		 * parameter. Whenever the mitigation is enabled, it works
1248		 * for all of the kernel code with no need to provide any
1249		 * additional instructions.
1250		 */
1251		break;
1252
1253	/* ST: *(size *)(dst + off) = imm */
1254	case BPF_ST | BPF_MEM | BPF_W:
1255	case BPF_ST | BPF_MEM | BPF_H:
1256	case BPF_ST | BPF_MEM | BPF_B:
1257	case BPF_ST | BPF_MEM | BPF_DW:
1258		if (ctx->fpb_offset > 0 && dst == fp) {
1259			dst_adj = fpb;
1260			off_adj = off + ctx->fpb_offset;
1261		} else {
1262			dst_adj = dst;
1263			off_adj = off;
1264		}
1265		/* Load imm to a register then store it */
1266		emit_a64_mov_i(1, tmp, imm, ctx);
1267		switch (BPF_SIZE(code)) {
1268		case BPF_W:
1269			if (is_lsi_offset(off_adj, 2)) {
1270				emit(A64_STR32I(tmp, dst_adj, off_adj), ctx);
1271			} else {
1272				emit_a64_mov_i(1, tmp2, off, ctx);
1273				emit(A64_STR32(tmp, dst, tmp2), ctx);
1274			}
1275			break;
1276		case BPF_H:
1277			if (is_lsi_offset(off_adj, 1)) {
1278				emit(A64_STRHI(tmp, dst_adj, off_adj), ctx);
1279			} else {
1280				emit_a64_mov_i(1, tmp2, off, ctx);
1281				emit(A64_STRH(tmp, dst, tmp2), ctx);
1282			}
1283			break;
1284		case BPF_B:
1285			if (is_lsi_offset(off_adj, 0)) {
1286				emit(A64_STRBI(tmp, dst_adj, off_adj), ctx);
1287			} else {
1288				emit_a64_mov_i(1, tmp2, off, ctx);
1289				emit(A64_STRB(tmp, dst, tmp2), ctx);
1290			}
1291			break;
1292		case BPF_DW:
1293			if (is_lsi_offset(off_adj, 3)) {
1294				emit(A64_STR64I(tmp, dst_adj, off_adj), ctx);
1295			} else {
1296				emit_a64_mov_i(1, tmp2, off, ctx);
1297				emit(A64_STR64(tmp, dst, tmp2), ctx);
1298			}
1299			break;
1300		}
1301		break;
1302
1303	/* STX: *(size *)(dst + off) = src */
1304	case BPF_STX | BPF_MEM | BPF_W:
1305	case BPF_STX | BPF_MEM | BPF_H:
1306	case BPF_STX | BPF_MEM | BPF_B:
1307	case BPF_STX | BPF_MEM | BPF_DW:
1308		if (ctx->fpb_offset > 0 && dst == fp) {
1309			dst_adj = fpb;
1310			off_adj = off + ctx->fpb_offset;
1311		} else {
1312			dst_adj = dst;
1313			off_adj = off;
1314		}
1315		switch (BPF_SIZE(code)) {
1316		case BPF_W:
1317			if (is_lsi_offset(off_adj, 2)) {
1318				emit(A64_STR32I(src, dst_adj, off_adj), ctx);
1319			} else {
1320				emit_a64_mov_i(1, tmp, off, ctx);
1321				emit(A64_STR32(src, dst, tmp), ctx);
1322			}
1323			break;
1324		case BPF_H:
1325			if (is_lsi_offset(off_adj, 1)) {
1326				emit(A64_STRHI(src, dst_adj, off_adj), ctx);
1327			} else {
1328				emit_a64_mov_i(1, tmp, off, ctx);
1329				emit(A64_STRH(src, dst, tmp), ctx);
1330			}
1331			break;
1332		case BPF_B:
1333			if (is_lsi_offset(off_adj, 0)) {
1334				emit(A64_STRBI(src, dst_adj, off_adj), ctx);
1335			} else {
1336				emit_a64_mov_i(1, tmp, off, ctx);
1337				emit(A64_STRB(src, dst, tmp), ctx);
1338			}
1339			break;
1340		case BPF_DW:
1341			if (is_lsi_offset(off_adj, 3)) {
1342				emit(A64_STR64I(src, dst_adj, off_adj), ctx);
1343			} else {
1344				emit_a64_mov_i(1, tmp, off, ctx);
1345				emit(A64_STR64(src, dst, tmp), ctx);
1346			}
1347			break;
1348		}
1349		break;
1350
1351	case BPF_STX | BPF_ATOMIC | BPF_W:
1352	case BPF_STX | BPF_ATOMIC | BPF_DW:
1353		if (cpus_have_cap(ARM64_HAS_LSE_ATOMICS))
1354			ret = emit_lse_atomic(insn, ctx);
1355		else
1356			ret = emit_ll_sc_atomic(insn, ctx);
1357		if (ret)
1358			return ret;
1359		break;
1360
1361	default:
1362		pr_err_once("unknown opcode %02x\n", code);
1363		return -EINVAL;
1364	}
1365
1366	return 0;
1367}
1368
1369/*
1370 * Return 0 if FP may change at runtime, otherwise find the minimum negative
1371 * offset to FP, convert it to a positive number, and align it down to 8 bytes.
1372 */
1373static int find_fpb_offset(struct bpf_prog *prog)
1374{
1375	int i;
1376	int offset = 0;
1377
1378	for (i = 0; i < prog->len; i++) {
1379		const struct bpf_insn *insn = &prog->insnsi[i];
1380		const u8 class = BPF_CLASS(insn->code);
1381		const u8 mode = BPF_MODE(insn->code);
1382		const u8 src = insn->src_reg;
1383		const u8 dst = insn->dst_reg;
1384		const s32 imm = insn->imm;
1385		const s16 off = insn->off;
1386
1387		switch (class) {
1388		case BPF_STX:
1389		case BPF_ST:
1390			/* fp holds atomic operation result */
1391			if (class == BPF_STX && mode == BPF_ATOMIC &&
1392			    ((imm == BPF_XCHG ||
1393			      imm == (BPF_FETCH | BPF_ADD) ||
1394			      imm == (BPF_FETCH | BPF_AND) ||
1395			      imm == (BPF_FETCH | BPF_XOR) ||
1396			      imm == (BPF_FETCH | BPF_OR)) &&
1397			     src == BPF_REG_FP))
1398				return 0;
1399
1400			if (mode == BPF_MEM && dst == BPF_REG_FP &&
1401			    off < offset)
1402				offset = insn->off;
1403			break;
1404
1405		case BPF_JMP32:
1406		case BPF_JMP:
1407			break;
1408
1409		case BPF_LDX:
1410		case BPF_LD:
1411			/* fp holds load result */
1412			if (dst == BPF_REG_FP)
1413				return 0;
1414
1415			if (class == BPF_LDX && mode == BPF_MEM &&
1416			    src == BPF_REG_FP && off < offset)
1417				offset = off;
1418			break;
1419
1420		case BPF_ALU:
1421		case BPF_ALU64:
1422		default:
1423			/* fp holds ALU result */
1424			if (dst == BPF_REG_FP)
1425				return 0;
1426		}
1427	}
1428
1429	if (offset < 0) {
1430		/*
1431		 * 'offset' can safely be converted to a positive 'int', since
1432		 * insn->off is an 's16'
1433		 */
1434		offset = -offset;
1435		/* align down to 8 bytes */
1436		offset = ALIGN_DOWN(offset, 8);
1437	}
1438
1439	return offset;
1440}
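
/*
 * For example, a program whose only FP-relative accesses are
 * *(u64 *)(r10 - 8) and *(u64 *)(r10 - 48) gets fpb_offset = 48, so those
 * accesses become fpb + 40 and fpb + 0 and each fits the single-instruction
 * "unsigned offset" load/store form instead of needing a mov plus a
 * register-offset access.
 */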
1441
1442static int build_body(struct jit_ctx *ctx, bool extra_pass)
1443{
1444	const struct bpf_prog *prog = ctx->prog;
1445	int i;
1446
1447	/*
1448	 * - offset[0] - offset of the end of prologue,
1449	 *   start of the 1st instruction
1450	 * - offset[1] - offset of the end of 1st instruction,
1451	 *   start of the 2nd instruction
1452	 * [....]
1453	 * - offset[3] - offset of the end of 3rd instruction,
1454	 *   start of 4th instruction
1455	 */
1456	for (i = 0; i < prog->len; i++) {
1457		const struct bpf_insn *insn = &prog->insnsi[i];
1458		int ret;
1459
1460		if (ctx->image == NULL)
1461			ctx->offset[i] = ctx->idx;
1462		ret = build_insn(insn, ctx, extra_pass);
1463		if (ret > 0) {
1464			i++;
1465			if (ctx->image == NULL)
1466				ctx->offset[i] = ctx->idx;
1467			continue;
1468		}
1469		if (ret)
1470			return ret;
1471	}
1472	/*
1473	 * offset is allocated with prog->len + 1 so fill in
1474	 * the last element with the offset after the last
1475	 * instruction (end of program)
1476	 */
1477	if (ctx->image == NULL)
1478		ctx->offset[i] = ctx->idx;
1479
1480	return 0;
1481}
1482
1483static int validate_code(struct jit_ctx *ctx)
1484{
1485	int i;
1486
1487	for (i = 0; i < ctx->idx; i++) {
1488		u32 a64_insn = le32_to_cpu(ctx->image[i]);
1489
1490		if (a64_insn == AARCH64_BREAK_FAULT)
1491			return -1;
1492	}
1493	return 0;
1494}
1495
1496static int validate_ctx(struct jit_ctx *ctx)
1497{
1498	if (validate_code(ctx))
1499		return -1;
1500
1501	if (WARN_ON_ONCE(ctx->exentry_idx != ctx->prog->aux->num_exentries))
1502		return -1;
1503
1504	return 0;
1505}
1506
1507static inline void bpf_flush_icache(void *start, void *end)
1508{
1509	flush_icache_range((unsigned long)start, (unsigned long)end);
1510}
1511
1512struct arm64_jit_data {
1513	struct bpf_binary_header *header;
1514	u8 *image;
1515	struct jit_ctx ctx;
1516};
1517
1518struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
1519{
1520	int image_size, prog_size, extable_size, extable_align, extable_offset;
1521	struct bpf_prog *tmp, *orig_prog = prog;
1522	struct bpf_binary_header *header;
1523	struct arm64_jit_data *jit_data;
1524	bool was_classic = bpf_prog_was_classic(prog);
1525	bool tmp_blinded = false;
1526	bool extra_pass = false;
1527	struct jit_ctx ctx;
1528	u8 *image_ptr;
1529
1530	if (!prog->jit_requested)
1531		return orig_prog;
1532
1533	tmp = bpf_jit_blind_constants(prog);
1534	/* If blinding was requested and we failed during blinding,
1535	 * we must fall back to the interpreter.
1536	 */
1537	if (IS_ERR(tmp))
1538		return orig_prog;
1539	if (tmp != prog) {
1540		tmp_blinded = true;
1541		prog = tmp;
1542	}
1543
1544	jit_data = prog->aux->jit_data;
1545	if (!jit_data) {
1546		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
1547		if (!jit_data) {
1548			prog = orig_prog;
1549			goto out;
1550		}
1551		prog->aux->jit_data = jit_data;
1552	}
1553	if (jit_data->ctx.offset) {
1554		ctx = jit_data->ctx;
1555		image_ptr = jit_data->image;
1556		header = jit_data->header;
1557		extra_pass = true;
1558		prog_size = sizeof(u32) * ctx.idx;
1559		goto skip_init_ctx;
1560	}
1561	memset(&ctx, 0, sizeof(ctx));
1562	ctx.prog = prog;
1563
1564	ctx.offset = kvcalloc(prog->len + 1, sizeof(int), GFP_KERNEL);
1565	if (ctx.offset == NULL) {
1566		prog = orig_prog;
1567		goto out_off;
1568	}
1569
1570	ctx.fpb_offset = find_fpb_offset(prog);
1571
1572	/*
1573	 * 1. Initial fake pass to compute ctx->idx and ctx->offset.
1574	 *
1575	 * BPF line info needs ctx->offset[i] to be the offset of
1576	 * instruction[i] in jited image, so build prologue first.
1577	 */
1578	if (build_prologue(&ctx, was_classic)) {
1579		prog = orig_prog;
1580		goto out_off;
1581	}
1582
1583	if (build_body(&ctx, extra_pass)) {
1584		prog = orig_prog;
1585		goto out_off;
1586	}
1587
1588	ctx.epilogue_offset = ctx.idx;
1589	build_epilogue(&ctx);
1590	build_plt(&ctx);
1591
1592	extable_align = __alignof__(struct exception_table_entry);
1593	extable_size = prog->aux->num_exentries *
1594		sizeof(struct exception_table_entry);
1595
1596	/* Now we know the actual image size. */
1597	prog_size = sizeof(u32) * ctx.idx;
1598	/* also allocate space for plt target */
1599	extable_offset = round_up(prog_size + PLT_TARGET_SIZE, extable_align);
1600	image_size = extable_offset + extable_size;
1601	header = bpf_jit_binary_alloc(image_size, &image_ptr,
1602				      sizeof(u32), jit_fill_hole);
1603	if (header == NULL) {
1604		prog = orig_prog;
1605		goto out_off;
1606	}
1607
1608	/* 2. Now, the actual pass. */
1609
1610	ctx.image = (__le32 *)image_ptr;
1611	if (extable_size)
1612		prog->aux->extable = (void *)image_ptr + extable_offset;
1613skip_init_ctx:
1614	ctx.idx = 0;
1615	ctx.exentry_idx = 0;
1616
1617	build_prologue(&ctx, was_classic);
1618
1619	if (build_body(&ctx, extra_pass)) {
1620		bpf_jit_binary_free(header);
1621		prog = orig_prog;
1622		goto out_off;
1623	}
1624
1625	build_epilogue(&ctx);
1626	build_plt(&ctx);
1627
1628	/* 3. Extra pass to validate JITed code. */
1629	if (validate_ctx(&ctx)) {
1630		bpf_jit_binary_free(header);
1631		prog = orig_prog;
1632		goto out_off;
1633	}
1634
1635	/* And we're done. */
1636	if (bpf_jit_enable > 1)
1637		bpf_jit_dump(prog->len, prog_size, 2, ctx.image);
1638
1639	bpf_flush_icache(header, ctx.image + ctx.idx);
1640
1641	if (!prog->is_func || extra_pass) {
1642		if (extra_pass && ctx.idx != jit_data->ctx.idx) {
1643			pr_err_once("multi-func JIT bug %d != %d\n",
1644				    ctx.idx, jit_data->ctx.idx);
1645			bpf_jit_binary_free(header);
1646			prog->bpf_func = NULL;
1647			prog->jited = 0;
1648			prog->jited_len = 0;
1649			goto out_off;
1650		}
1651		bpf_jit_binary_lock_ro(header);
1652	} else {
1653		jit_data->ctx = ctx;
1654		jit_data->image = image_ptr;
1655		jit_data->header = header;
1656	}
1657	prog->bpf_func = (void *)ctx.image;
1658	prog->jited = 1;
1659	prog->jited_len = prog_size;
1660
1661	if (!prog->is_func || extra_pass) {
1662		int i;
1663
1664		/* offset[prog->len] is the size of program */
1665		for (i = 0; i <= prog->len; i++)
1666			ctx.offset[i] *= AARCH64_INSN_SIZE;
1667		bpf_prog_fill_jited_linfo(prog, ctx.offset + 1);
1668out_off:
1669		kvfree(ctx.offset);
1670		kfree(jit_data);
1671		prog->aux->jit_data = NULL;
1672	}
1673out:
1674	if (tmp_blinded)
1675		bpf_jit_prog_release_other(prog, prog == orig_prog ?
1676					   tmp : orig_prog);
1677	return prog;
1678}
1679
1680bool bpf_jit_supports_kfunc_call(void)
1681{
1682	return true;
1683}
1684
1685u64 bpf_jit_alloc_exec_limit(void)
1686{
1687	return VMALLOC_END - VMALLOC_START;
1688}
1689
1690void *bpf_jit_alloc_exec(unsigned long size)
1691{
1692	/* Memory is intended to be executable, reset the pointer tag. */
1693	return kasan_reset_tag(vmalloc(size));
1694}
1695
1696void bpf_jit_free_exec(void *addr)
1697{
1698	return vfree(addr);
1699}
1700
1701/* Indicate the JIT backend supports mixing bpf2bpf and tailcalls. */
1702bool bpf_jit_supports_subprog_tailcalls(void)
1703{
1704	return true;
1705}
1706
1707static void invoke_bpf_prog(struct jit_ctx *ctx, struct bpf_tramp_link *l,
1708			    int args_off, int retval_off, int run_ctx_off,
1709			    bool save_ret)
1710{
1711	__le32 *branch;
1712	u64 enter_prog;
1713	u64 exit_prog;
1714	struct bpf_prog *p = l->link.prog;
1715	int cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie);
1716
1717	enter_prog = (u64)bpf_trampoline_enter(p);
1718	exit_prog = (u64)bpf_trampoline_exit(p);
1719
1720	if (l->cookie == 0) {
1721		/* if cookie is zero, one instruction is enough to store it */
1722		emit(A64_STR64I(A64_ZR, A64_SP, run_ctx_off + cookie_off), ctx);
1723	} else {
1724		emit_a64_mov_i64(A64_R(10), l->cookie, ctx);
1725		emit(A64_STR64I(A64_R(10), A64_SP, run_ctx_off + cookie_off),
1726		     ctx);
1727	}
1728
1729	/* save p to callee saved register x19 to avoid loading p with mov_i64
1730	 * each time.
1731	 */
1732	emit_addr_mov_i64(A64_R(19), (const u64)p, ctx);
1733
1734	/* arg1: prog */
1735	emit(A64_MOV(1, A64_R(0), A64_R(19)), ctx);
1736	/* arg2: &run_ctx */
1737	emit(A64_ADD_I(1, A64_R(1), A64_SP, run_ctx_off), ctx);
1738
1739	emit_call(enter_prog, ctx);
1740
1741	/* if (__bpf_prog_enter(prog) == 0)
1742	 *         goto skip_exec_of_prog;
1743	 */
1744	branch = ctx->image + ctx->idx;
1745	emit(A64_NOP, ctx);
1746
1747	/* save return value to callee saved register x20 */
1748	emit(A64_MOV(1, A64_R(20), A64_R(0)), ctx);
1749
1750	emit(A64_ADD_I(1, A64_R(0), A64_SP, args_off), ctx);
1751	if (!p->jited)
1752		emit_addr_mov_i64(A64_R(1), (const u64)p->insnsi, ctx);
1753
1754	emit_call((const u64)p->bpf_func, ctx);
1755
1756	if (save_ret)
1757		emit(A64_STR64I(A64_R(0), A64_SP, retval_off), ctx);
1758
1759	if (ctx->image) {
1760		int offset = &ctx->image[ctx->idx] - branch;
1761		*branch = cpu_to_le32(A64_CBZ(1, A64_R(0), offset));
1762	}
1763
1764	/* arg1: prog */
1765	emit(A64_MOV(1, A64_R(0), A64_R(19)), ctx);
1766	/* arg2: start time */
1767	emit(A64_MOV(1, A64_R(1), A64_R(20)), ctx);
1768	/* arg3: &run_ctx */
1769	emit(A64_ADD_I(1, A64_R(2), A64_SP, run_ctx_off), ctx);
1770
1771	emit_call(exit_prog, ctx);
1772}
1773
1774static void invoke_bpf_mod_ret(struct jit_ctx *ctx, struct bpf_tramp_links *tl,
1775			       int args_off, int retval_off, int run_ctx_off,
1776			       __le32 **branches)
1777{
1778	int i;
1779
1780	/* The first fmod_ret program will receive a garbage return value.
1781	 * Set this to 0 to avoid confusing the program.
1782	 */
1783	emit(A64_STR64I(A64_ZR, A64_SP, retval_off), ctx);
1784	for (i = 0; i < tl->nr_links; i++) {
1785		invoke_bpf_prog(ctx, tl->links[i], args_off, retval_off,
1786				run_ctx_off, true);
1787		/* if (*(u64 *)(sp + retval_off) !=  0)
1788		 *	goto do_fexit;
1789		 */
1790		emit(A64_LDR64I(A64_R(10), A64_SP, retval_off), ctx);
1791		/* Save the location of branch, and generate a nop.
1792		 * This nop will be replaced with a cbnz later.
1793		 */
1794		branches[i] = ctx->image + ctx->idx;
1795		emit(A64_NOP, ctx);
1796	}
1797}
1798
1799static void save_args(struct jit_ctx *ctx, int args_off, int nregs)
1800{
1801	int i;
1802
1803	for (i = 0; i < nregs; i++) {
1804		emit(A64_STR64I(i, A64_SP, args_off), ctx);
1805		args_off += 8;
1806	}
1807}
1808
1809static void restore_args(struct jit_ctx *ctx, int args_off, int nregs)
1810{
1811	int i;
1812
1813	for (i = 0; i < nregs; i++) {
1814		emit(A64_LDR64I(i, A64_SP, args_off), ctx);
1815		args_off += 8;
1816	}
1817}
1818
1819/* Based on the x86's implementation of arch_prepare_bpf_trampoline().
1820 *
1821 * bpf prog and function entry before bpf trampoline hooked:
1822 *   mov x9, lr
1823 *   nop
1824 *
1825 * bpf prog and function entry after bpf trampoline hooked:
1826 *   mov x9, lr
1827 *   bl  <bpf_trampoline or plt>
1828 *
1829 */
1830static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
1831			      struct bpf_tramp_links *tlinks, void *orig_call,
1832			      int nregs, u32 flags)
1833{
1834	int i;
1835	int stack_size;
1836	int retaddr_off;
1837	int regs_off;
1838	int retval_off;
1839	int args_off;
1840	int nregs_off;
1841	int ip_off;
1842	int run_ctx_off;
1843	struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
1844	struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
1845	struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
1846	bool save_ret;
1847	__le32 **branches = NULL;
1848
1849	/* trampoline stack layout:
1850	 *                  [ parent ip         ]
1851	 *                  [ FP                ]
1852	 * SP + retaddr_off [ self ip           ]
1853	 *                  [ FP                ]
1854	 *
1855	 *                  [ padding           ] align SP to multiples of 16
1856	 *
1857	 *                  [ x20               ] callee saved reg x20
1858	 * SP + regs_off    [ x19               ] callee saved reg x19
1859	 *
1860	 * SP + retval_off  [ return value      ] BPF_TRAMP_F_CALL_ORIG or
1861	 *                                        BPF_TRAMP_F_RET_FENTRY_RET
1862	 *
1863	 *                  [ arg reg N         ]
1864	 *                  [ ...               ]
1865	 * SP + args_off    [ arg reg 1         ]
1866	 *
1867	 * SP + nregs_off   [ arg regs count    ]
1868	 *
1869	 * SP + ip_off      [ traced function   ] BPF_TRAMP_F_IP_ARG flag
1870	 *
1871	 * SP + run_ctx_off [ bpf_tramp_run_ctx ]
1872	 */
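	/*
	 * As a rough sketch, for nregs = 2 with only BPF_TRAMP_F_CALL_ORIG
	 * set, the frame built below holds (from SP upward) the
	 * bpf_tramp_run_ctx, an 8-byte register count, two 8-byte argument
	 * slots, an 8-byte return value slot and 16 bytes for x19/x20,
	 * rounded up to a multiple of 16; the two frame records pushed
	 * before "sub sp" sit above that, which is why retaddr_off ends up
	 * at stack_size + 8.
	 */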
1873
1874	stack_size = 0;
1875	run_ctx_off = stack_size;
1876	/* room for bpf_tramp_run_ctx */
1877	stack_size += round_up(sizeof(struct bpf_tramp_run_ctx), 8);
1878
1879	ip_off = stack_size;
1880	/* room for IP address argument */
1881	if (flags & BPF_TRAMP_F_IP_ARG)
1882		stack_size += 8;
1883
1884	nregs_off = stack_size;
1885	/* room for args count */
1886	stack_size += 8;
1887
1888	args_off = stack_size;
1889	/* room for args */
1890	stack_size += nregs * 8;
1891
1892	/* room for return value */
1893	retval_off = stack_size;
1894	save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
1895	if (save_ret)
1896		stack_size += 8;
1897
1898	/* room for callee saved registers, currently x19 and x20 are used */
1899	regs_off = stack_size;
1900	stack_size += 16;
1901
1902	/* round up to a multiple of 16 to avoid SPAlignmentFault */
1903	stack_size = round_up(stack_size, 16);
1904
1905	/* the return address is located above FP */
1906	retaddr_off = stack_size + 8;
1907
1908	/* bpf trampoline may be invoked by 3 instruction types:
1909	 * 1. bl, attached to bpf prog or kernel function via short jump
1910	 * 2. br, attached to bpf prog or kernel function via long jump
1911	 * 3. blr, working as a function pointer, used by struct_ops.
1912	 * So BTI_JC should used here to support both br and blr.
1913	 */
1914	emit_bti(A64_BTI_JC, ctx);
1915
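	/* On entry, x9 holds the return address into the traced function's
	 * caller (saved by the "mov x9, lr" at the patchsite) and lr holds
	 * the address right after the patchsite inside the traced function,
	 * so the two frames below describe the caller and the traced
	 * (patched) function respectively.
	 */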
	/* frame for parent function */
	emit(A64_PUSH(A64_FP, A64_R(9), A64_SP), ctx);
	emit(A64_MOV(1, A64_FP, A64_SP), ctx);

	/* frame for patched function */
	emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
	emit(A64_MOV(1, A64_FP, A64_SP), ctx);

	/* allocate stack space */
	emit(A64_SUB_I(1, A64_SP, A64_SP, stack_size), ctx);

	if (flags & BPF_TRAMP_F_IP_ARG) {
		/* save ip address of the traced function */
		emit_addr_mov_i64(A64_R(10), (const u64)orig_call, ctx);
		emit(A64_STR64I(A64_R(10), A64_SP, ip_off), ctx);
	}

	/* save arg regs count */
	emit(A64_MOVZ(1, A64_R(10), nregs, 0), ctx);
	emit(A64_STR64I(A64_R(10), A64_SP, nregs_off), ctx);

	/* save arg regs */
	save_args(ctx, args_off, nregs);

	/* save callee saved registers */
	emit(A64_STR64I(A64_R(19), A64_SP, regs_off), ctx);
	emit(A64_STR64I(A64_R(20), A64_SP, regs_off + 8), ctx);

	if (flags & BPF_TRAMP_F_CALL_ORIG) {
		/* take a reference on the trampoline image so it is not
		 * freed while in use; paired with __bpf_tramp_exit in the
		 * epilogue
		 */
		emit_addr_mov_i64(A64_R(0), (const u64)im, ctx);
		emit_call((const u64)__bpf_tramp_enter, ctx);
	}

	for (i = 0; i < fentry->nr_links; i++)
		invoke_bpf_prog(ctx, fentry->links[i], args_off,
				retval_off, run_ctx_off,
				flags & BPF_TRAMP_F_RET_FENTRY_RET);

	if (fmod_ret->nr_links) {
		branches = kcalloc(fmod_ret->nr_links, sizeof(__le32 *),
				   GFP_KERNEL);
		if (!branches)
			return -ENOMEM;

		invoke_bpf_mod_ret(ctx, fmod_ret, args_off, retval_off,
				   run_ctx_off, branches);
	}

	if (flags & BPF_TRAMP_F_CALL_ORIG) {
		restore_args(ctx, args_off, nregs);
		/* call original func: x10 = the resume address saved as
		 * "self ip" (the instruction right after the patchsite),
		 * lr = the str below, so the original function returns
		 * straight to the store of its return value
		 */
		emit(A64_LDR64I(A64_R(10), A64_SP, retaddr_off), ctx);
		emit(A64_ADR(A64_LR, AARCH64_INSN_SIZE * 2), ctx);
		emit(A64_RET(A64_R(10)), ctx);
		/* store return value */
		emit(A64_STR64I(A64_R(0), A64_SP, retval_off), ctx);
		/* reserve a nop for bpf_tramp_image_put */
		im->ip_after_call = ctx->image + ctx->idx;
		emit(A64_NOP, ctx);
	}

	/* Update the branches saved in invoke_bpf_mod_ret with cbnz. Each
	 * reserved nop becomes "cbnz x10, <offset>", where x10 holds the
	 * fmod_ret program's return value, so a non-zero return value
	 * branches over the call to the original function above, straight
	 * to the fexit handling below.
	 */
	for (i = 0; i < fmod_ret->nr_links && ctx->image != NULL; i++) {
		int offset = &ctx->image[ctx->idx] - branches[i];
		*branches[i] = cpu_to_le32(A64_CBNZ(1, A64_R(10), offset));
	}

	for (i = 0; i < fexit->nr_links; i++)
		invoke_bpf_prog(ctx, fexit->links[i], args_off, retval_off,
				run_ctx_off, false);

	if (flags & BPF_TRAMP_F_CALL_ORIG) {
		im->ip_epilogue = ctx->image + ctx->idx;
		emit_addr_mov_i64(A64_R(0), (const u64)im, ctx);
		emit_call((const u64)__bpf_tramp_exit, ctx);
	}

	if (flags & BPF_TRAMP_F_RESTORE_REGS)
		restore_args(ctx, args_off, nregs);

	/* restore callee saved registers x19 and x20 */
	emit(A64_LDR64I(A64_R(19), A64_SP, regs_off), ctx);
	emit(A64_LDR64I(A64_R(20), A64_SP, regs_off + 8), ctx);

	if (save_ret)
		emit(A64_LDR64I(A64_R(0), A64_SP, retval_off), ctx);

	/* reset SP */
	emit(A64_MOV(1, A64_SP, A64_FP), ctx);

	/* pop frames */
	emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
	emit(A64_POP(A64_FP, A64_R(9), A64_SP), ctx);

	if (flags & BPF_TRAMP_F_SKIP_FRAME) {
		/* skip patched function, return to parent */
		emit(A64_MOV(1, A64_LR, A64_R(9)), ctx);
		emit(A64_RET(A64_R(9)), ctx);
	} else {
		/* return to patched function */
		emit(A64_MOV(1, A64_R(10), A64_LR), ctx);
		emit(A64_MOV(1, A64_LR, A64_R(9)), ctx);
		emit(A64_RET(A64_R(10)), ctx);
	}

	if (ctx->image)
		bpf_flush_icache(ctx->image, ctx->image + ctx->idx);

	kfree(branches);

	return ctx->idx;
}

int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image,
				void *image_end, const struct btf_func_model *m,
				u32 flags, struct bpf_tramp_links *tlinks,
				void *orig_call)
{
	int i, ret;
	int nregs = m->nr_args;
	int max_insns = ((long)image_end - (long)image) / AARCH64_INSN_SIZE;
	struct jit_ctx ctx = {
		.image = NULL,
		.idx = 0,
	};

	/* extra registers needed for struct arguments */
	for (i = 0; i < MAX_BPF_FUNC_ARGS; i++) {
		/* The arg_size is at most 16 bytes, enforced by the verifier,
		 * so a struct argument spans at most two registers.
		 */
		if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG)
			nregs += (m->arg_size[i] + 7) / 8 - 1;
	}

	/* the first 8 registers are used for arguments */
	if (nregs > 8)
		return -ENOTSUPP;

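	/* First pass: ctx.image is NULL, so emit() only advances ctx.idx;
	 * the return value is the number of instructions the trampoline
	 * will need.
	 */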
	ret = prepare_trampoline(&ctx, im, tlinks, orig_call, nregs, flags);
	if (ret < 0)
		return ret;

	if (ret > max_insns)
		return -EFBIG;

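	/* Second pass: generate the trampoline into the real image */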
	ctx.image = image;
	ctx.idx = 0;

	jit_fill_hole(image, (unsigned int)(image_end - image));
	ret = prepare_trampoline(&ctx, im, tlinks, orig_call, nregs, flags);

	if (ret > 0 && validate_code(&ctx) < 0)
		ret = -EINVAL;

	if (ret > 0)
		ret *= AARCH64_INSN_SIZE;

	return ret;
}

static bool is_long_jump(void *ip, void *target)
{
	long offset;

	/* NULL target means this is a NOP */
	if (!target)
		return false;

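	/* b/bl encode a signed 26-bit immediate in units of 4 bytes, so the
	 * direct branch range is [-128M, 128M) relative to the branch.
	 */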
	offset = (long)target - (long)ip;
	return offset < -SZ_128M || offset >= SZ_128M;
}

static int gen_branch_or_nop(enum aarch64_insn_branch_type type, void *ip,
			     void *addr, void *plt, u32 *insn)
{
	void *target;

	if (!addr) {
		*insn = aarch64_insn_gen_nop();
		return 0;
	}

	if (is_long_jump(ip, addr))
		target = plt;
	else
		target = addr;

	*insn = aarch64_insn_gen_branch_imm((unsigned long)ip,
					    (unsigned long)target,
					    type);

	return *insn != AARCH64_BREAK_FAULT ? 0 : -EFAULT;
}

/* Replace the branch instruction from @ip to @old_addr in a bpf prog or a bpf
 * trampoline with the branch instruction from @ip to @new_addr. If @old_addr
 * or @new_addr is NULL, the old or new instruction is a NOP.
 *
 * When @ip is the bpf prog entry, a bpf trampoline is being attached or
 * detached. Since bpf trampoline and bpf prog are allocated separately with
 * vmalloc, the address distance may exceed 128MB, the maximum branch range,
 * so long jumps have to be handled as well.
 *
 * When a bpf prog is constructed, a plt pointing to the empty trampoline
 * dummy_tramp is placed at the end:
 *
 *      bpf_prog:
 *              mov x9, lr
 *              nop // patchsite
 *              ...
 *              ret
 *
 *      plt:
 *              ldr x10, target
 *              br x10
 *      target:
 *              .quad dummy_tramp // plt target
 *
 * This is also the state when no trampoline is attached.
 *
 * When a short-jump bpf trampoline is attached, the patchsite is patched
 * to a bl instruction that jumps to the trampoline directly:
 *
 *      bpf_prog:
 *              mov x9, lr
 *              bl <short-jump bpf trampoline address> // patchsite
 *              ...
 *              ret
 *
 *      plt:
 *              ldr x10, target
 *              br x10
 *      target:
 *              .quad dummy_tramp // plt target
 *
 * When a long-jump bpf trampoline is attached, the plt target is filled with
 * the trampoline address and the patchsite is patched to a bl instruction to
 * the plt:
 *
 *      bpf_prog:
 *              mov x9, lr
 *              bl plt // patchsite
 *              ...
 *              ret
 *
 *      plt:
 *              ldr x10, target
 *              br x10
 *      target:
 *              .quad <long-jump bpf trampoline address> // plt target
 *
 * The dummy_tramp is used to prevent another CPU from jumping to unknown
 * locations during the patching process, which keeps the patching simple.
 */
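/* Note that when the plt target needs to change, it is updated before the
 * patchsite is patched below, so a CPU that still executes a stale "bl plt"
 * always lands on a valid target: either the new trampoline or dummy_tramp.
 */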
int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type,
		       void *old_addr, void *new_addr)
{
	int ret;
	u32 old_insn;
	u32 new_insn;
	u32 replaced;
	struct bpf_plt *plt = NULL;
	unsigned long size = 0UL;
	unsigned long offset = ~0UL;
	enum aarch64_insn_branch_type branch_type;
	char namebuf[KSYM_NAME_LEN];
	void *image = NULL;
	u64 plt_target = 0ULL;
	bool poking_bpf_entry;

	if (!__bpf_address_lookup((unsigned long)ip, &size, &offset, namebuf))
		/* Only poking bpf text is supported. Since kernel function
		 * entry is set up by ftrace, we rely on ftrace to poke kernel
		 * functions.
		 */
		return -ENOTSUPP;

	image = ip - offset;
	/* zero offset means we're poking bpf prog entry */
	poking_bpf_entry = (offset == 0UL);

	/* bpf prog entry, find plt and the real patchsite */
	if (poking_bpf_entry) {
		/* the plt is located at the end of the bpf prog */
		plt = image + size - PLT_TARGET_OFFSET;

		/* skip to the nop instruction in bpf prog entry:
		 * bti c // if BTI enabled
		 * mov x9, x30
		 * nop
		 */
		ip = image + POKE_OFFSET * AARCH64_INSN_SIZE;
	}


	/* long jump is only possible at bpf prog entry */
	if (WARN_ON((is_long_jump(ip, new_addr) || is_long_jump(ip, old_addr)) &&
		    !poking_bpf_entry))
		return -EINVAL;

	if (poke_type == BPF_MOD_CALL)
		branch_type = AARCH64_INSN_BRANCH_LINK;
	else
		branch_type = AARCH64_INSN_BRANCH_NOLINK;

	if (gen_branch_or_nop(branch_type, ip, old_addr, plt, &old_insn) < 0)
		return -EFAULT;

	if (gen_branch_or_nop(branch_type, ip, new_addr, plt, &new_insn) < 0)
		return -EFAULT;

	if (is_long_jump(ip, new_addr))
		plt_target = (u64)new_addr;
	else if (is_long_jump(ip, old_addr))
		/* if the old target is a long jump and the new target is not,
		 * restore the plt target to dummy_tramp, so there is always a
		 * legal and harmless address stored in plt target, and we'll
		 * never jump from plt to an unknown place.
		 */
		plt_target = (u64)&dummy_tramp;

	if (plt_target) {
		/* non-zero plt_target indicates we're patching a bpf prog,
		 * which is read only.
		 */
		if (set_memory_rw(PAGE_MASK & ((uintptr_t)&plt->target), 1))
			return -EFAULT;
		WRITE_ONCE(plt->target, plt_target);
		set_memory_ro(PAGE_MASK & ((uintptr_t)&plt->target), 1);
		/* since plt target points to either the new trampoline
		 * or dummy_tramp, even if another CPU reads the old plt
		 * target value before fetching the bl instruction to plt,
		 * it will be brought back by dummy_tramp, so no barrier is
		 * required here.
		 */
	}

	/* if the old target and the new target are both long jumps, no
	 * patching is required
	 */
	if (old_insn == new_insn)
		return 0;


	mutex_lock(&text_mutex);
	if (aarch64_insn_read(ip, &replaced)) {
		ret = -EFAULT;
		goto out;
	}

	if (replaced != old_insn) {
		ret = -EFAULT;
		goto out;
	}

	/* We call aarch64_insn_patch_text_nosync() to replace the instruction
	 * atomically, so no other CPU will fetch a half-new and half-old
	 * instruction. But there is a chance that another CPU still executes
	 * the old instruction after the patching operation finishes (e.g.,
	 * pipeline not flushed, or icache not synchronized yet).
	 *
	 * 1. when a new trampoline is attached, it is not a problem for
	 *    different CPUs to jump to different trampolines temporarily.
	 *
	 * 2. when an old trampoline is freed, we should wait for all other
	 *    CPUs to exit the trampoline and make sure it is no longer
	 *    reachable. Since bpf_tramp_image_put() already uses percpu_ref
	 *    and task-based RCU to do that synchronization, there is no need
	 *    to use the synchronous patching variant here; see
	 *    bpf_tramp_image_put() for details.
	 */
	ret = aarch64_insn_patch_text_nosync(ip, new_insn);
out:
	mutex_unlock(&text_mutex);

	return ret;
}
