// SPDX-License-Identifier: GPL-2.0-only
/*
 * BPF JIT compiler for LoongArch
 *
 * Copyright (C) 2021 Loongson Technology Corporation Limited
 */
#include "ebpf_jit.h"

#define TMP_REG_1       (MAX_BPF_JIT_REG + 0)
#define TMP_REG_2       (MAX_BPF_JIT_REG + 1)
#define TMP_REG_3       (MAX_BPF_JIT_REG + 2)
#define REG_TCC         (MAX_BPF_JIT_REG + 3)
#define TCC_SAVED       (MAX_BPF_JIT_REG + 4)

#define SAVE_RA         BIT(0)
#define SAVE_TCC        BIT(1)

static const int regmap[] = {
	/* return value from in-kernel functions, and exit value for the eBPF program */
	[BPF_REG_0] = LOONGARCH_GPR_A5,
	/* arguments from the eBPF program to in-kernel functions */
	[BPF_REG_1] = LOONGARCH_GPR_A0,
	[BPF_REG_2] = LOONGARCH_GPR_A1,
	[BPF_REG_3] = LOONGARCH_GPR_A2,
	[BPF_REG_4] = LOONGARCH_GPR_A3,
	[BPF_REG_5] = LOONGARCH_GPR_A4,
	/* callee-saved registers that in-kernel functions will preserve */
	[BPF_REG_6] = LOONGARCH_GPR_S0,
	[BPF_REG_7] = LOONGARCH_GPR_S1,
	[BPF_REG_8] = LOONGARCH_GPR_S2,
	[BPF_REG_9] = LOONGARCH_GPR_S3,
	/* read-only frame pointer to access the stack */
	[BPF_REG_FP] = LOONGARCH_GPR_S4,
	/* temporary register for blinding constants */
	[BPF_REG_AX] = LOONGARCH_GPR_T0,
	/* temporary registers for the internal BPF JIT */
	[TMP_REG_1] = LOONGARCH_GPR_T1,
	[TMP_REG_2] = LOONGARCH_GPR_T2,
	[TMP_REG_3] = LOONGARCH_GPR_T3,
	/* tail call count (TCC) */
	[REG_TCC] = LOONGARCH_GPR_A6,
	/* A6 is saved in S5 if the program makes calls */
	[TCC_SAVED] = LOONGARCH_GPR_S5,
};
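
/*
 * For illustration only (not emitted by the JIT itself): with the mapping
 * above, an eBPF instruction such as "r6 += r7" (BPF_ALU64 | BPF_ADD | BPF_X)
 * comes out as "add.d $s0, $s0, $s1", and the exit value in BPF_REG_0 lives
 * in $a5 until the epilogue copies it to $a0.
 */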

static void mark_call(struct jit_ctx *ctx)
{
	ctx->flags |= SAVE_RA;
}

static void mark_tail_call(struct jit_ctx *ctx)
{
	ctx->flags |= SAVE_TCC;
}

static bool seen_call(struct jit_ctx *ctx)
{
	return (ctx->flags & SAVE_RA);
}

static bool seen_tail_call(struct jit_ctx *ctx)
{
	return (ctx->flags & SAVE_TCC);
}

static u8 tail_call_reg(struct jit_ctx *ctx)
{
	if (seen_call(ctx))
		return regmap[TCC_SAVED];

	return regmap[REG_TCC];
}

/*
 * eBPF prog stack layout:
 *
 *                                        high
 * original $sp ------------> +-------------------------+ <--LOONGARCH_GPR_FP
 *                            |           $ra           |
 *                            +-------------------------+
 *                            |           $fp           |
 *                            +-------------------------+
 *                            |           $s0           |
 *                            +-------------------------+
 *                            |           $s1           |
 *                            +-------------------------+
 *                            |           $s2           |
 *                            +-------------------------+
 *                            |           $s3           |
 *                            +-------------------------+
 *                            |           $s4           |
 *                            +-------------------------+
 *                            |           $s5           |
 *                            +-------------------------+ <--BPF_REG_FP
 *                            |  prog->aux->stack_depth |
 *                            |        (optional)       |
 * current $sp -------------> +-------------------------+
 *                                        low
 */
static void build_prologue(struct jit_ctx *ctx)
{
	int stack_adjust = 0, store_offset, bpf_stack_adjust;

	bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);

	stack_adjust += sizeof(long); /* LOONGARCH_GPR_RA */
	stack_adjust += sizeof(long); /* LOONGARCH_GPR_FP */
	stack_adjust += sizeof(long); /* LOONGARCH_GPR_S0 */
	stack_adjust += sizeof(long); /* LOONGARCH_GPR_S1 */
	stack_adjust += sizeof(long); /* LOONGARCH_GPR_S2 */
	stack_adjust += sizeof(long); /* LOONGARCH_GPR_S3 */
	stack_adjust += sizeof(long); /* LOONGARCH_GPR_S4 */
	stack_adjust += sizeof(long); /* LOONGARCH_GPR_S5 */

	stack_adjust = round_up(stack_adjust, 16);
	stack_adjust += bpf_stack_adjust;
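
	/*
	 * Worked example (illustrative): for prog->aux->stack_depth == 24,
	 * bpf_stack_adjust = round_up(24, 16) = 32; the eight saved GPRs
	 * above contribute 8 * sizeof(long) = 64 bytes, which is already
	 * 16-byte aligned, so stack_adjust = 64 + 32 = 96.
	 */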

	/*
	 * First instruction initializes the tail call count (TCC).
	 * On tail call we skip this instruction, and the TCC is
	 * passed in REG_TCC from the caller.
	 */
	emit_insn(ctx, addid, regmap[REG_TCC], LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);

	emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, -stack_adjust);

	store_offset = stack_adjust - sizeof(long);
	emit_insn(ctx, std, LOONGARCH_GPR_RA, LOONGARCH_GPR_SP, store_offset);

	store_offset -= sizeof(long);
	emit_insn(ctx, std, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, store_offset);

	store_offset -= sizeof(long);
	emit_insn(ctx, std, LOONGARCH_GPR_S0, LOONGARCH_GPR_SP, store_offset);

	store_offset -= sizeof(long);
	emit_insn(ctx, std, LOONGARCH_GPR_S1, LOONGARCH_GPR_SP, store_offset);

	store_offset -= sizeof(long);
	emit_insn(ctx, std, LOONGARCH_GPR_S2, LOONGARCH_GPR_SP, store_offset);

	store_offset -= sizeof(long);
	emit_insn(ctx, std, LOONGARCH_GPR_S3, LOONGARCH_GPR_SP, store_offset);

	store_offset -= sizeof(long);
	emit_insn(ctx, std, LOONGARCH_GPR_S4, LOONGARCH_GPR_SP, store_offset);

	store_offset -= sizeof(long);
	emit_insn(ctx, std, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, store_offset);

	emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_adjust);

	if (bpf_stack_adjust)
		emit_insn(ctx, addid, regmap[BPF_REG_FP], LOONGARCH_GPR_SP, bpf_stack_adjust);

	/*
	 * The program contains both calls and tail calls, so REG_TCC
	 * needs to be saved across calls.
	 */
	if (seen_tail_call(ctx) && seen_call(ctx))
		move_reg(ctx, regmap[TCC_SAVED], regmap[REG_TCC]);

	ctx->stack_size = stack_adjust;
}
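
/*
 * A rough sketch of the emitted prologue for a program with stack_depth == 0
 * (so stack_adjust == 64), assuming the emit_insn() wrappers map 1:1 onto the
 * LoongArch instructions of the same name:
 *
 *   addi.d  $a6, $zero, MAX_TAIL_CALL_CNT  # TCC init, skipped on tail call
 *   addi.d  $sp, $sp, -64
 *   st.d    $ra, $sp, 56
 *   st.d    $fp, $sp, 48
 *   st.d    $s0, $sp, 40
 *   ...
 *   st.d    $s5, $sp, 0
 *   addi.d  $fp, $sp, 64
 */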

static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call)
{
	int stack_adjust = ctx->stack_size;
	int load_offset;

	load_offset = stack_adjust - sizeof(long);
	emit_insn(ctx, ldd, LOONGARCH_GPR_RA, LOONGARCH_GPR_SP, load_offset);

	load_offset -= sizeof(long);
	emit_insn(ctx, ldd, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, load_offset);

	load_offset -= sizeof(long);
	emit_insn(ctx, ldd, LOONGARCH_GPR_S0, LOONGARCH_GPR_SP, load_offset);

	load_offset -= sizeof(long);
	emit_insn(ctx, ldd, LOONGARCH_GPR_S1, LOONGARCH_GPR_SP, load_offset);

	load_offset -= sizeof(long);
	emit_insn(ctx, ldd, LOONGARCH_GPR_S2, LOONGARCH_GPR_SP, load_offset);

	load_offset -= sizeof(long);
	emit_insn(ctx, ldd, LOONGARCH_GPR_S3, LOONGARCH_GPR_SP, load_offset);

	load_offset -= sizeof(long);
	emit_insn(ctx, ldd, LOONGARCH_GPR_S4, LOONGARCH_GPR_SP, load_offset);

	load_offset -= sizeof(long);
	emit_insn(ctx, ldd, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, load_offset);

	emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, stack_adjust);

	if (!is_tail_call) {
		/* Set return value */
		move_reg(ctx, LOONGARCH_GPR_A0, regmap[BPF_REG_0]);
		/* Return to the caller */
		emit_insn(ctx, jirl, LOONGARCH_GPR_ZERO, LOONGARCH_GPR_RA, 0);
	} else {
		/*
		 * Call the next BPF prog and skip the first
		 * instruction, which initializes the TCC.
		 */
		emit_insn(ctx, jirl, LOONGARCH_GPR_ZERO, regmap[TMP_REG_3], 1);
	}
}
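
/*
 * Note on the tail-call branch above: the 16-bit offset of jirl is in units
 * of 4 bytes, so an offset of 1 targets prog->bpf_func + 4, i.e. the
 * instruction right after the TCC initialization emitted by build_prologue().
 */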

void build_epilogue(struct jit_ctx *ctx)
{
	__build_epilogue(ctx, false);
}

/* initialized on the first pass of build_body() */
static int out_offset = -1;
static int emit_bpf_tail_call(struct jit_ctx *ctx)
{
	int off;
	u8 tcc = tail_call_reg(ctx);
	u8 a1 = LOONGARCH_GPR_A1;
	u8 a2 = LOONGARCH_GPR_A2;
	u8 tmp1 = regmap[TMP_REG_1];
	u8 tmp2 = regmap[TMP_REG_2];
	u8 tmp3 = regmap[TMP_REG_3];
	const int idx0 = ctx->idx;

#define cur_offset (ctx->idx - idx0)
#define jmp_offset (out_offset - (cur_offset))

	/*
	 * a0: &ctx
	 * a1: &array
	 * a2: index
	 *
	 * if (index >= array->map.max_entries)
	 *	 goto out;
	 */
	off = offsetof(struct bpf_array, map.max_entries);
	emit_insn(ctx, ldwu, tmp1, a1, off);
	/* bgeu $a2, $t1, jmp_offset */
	emit_tailcall_jump(ctx, BPF_JGE, a2, tmp1, jmp_offset);

	/*
	 * if (TCC-- < 0)
	 *	 goto out;
	 */
	emit_insn(ctx, addid, tmp1, tcc, -1);
	emit_tailcall_jump(ctx, BPF_JSLT, tcc, LOONGARCH_GPR_ZERO, jmp_offset);

	/*
	 * prog = array->ptrs[index];
	 * if (!prog)
	 *	 goto out;
	 */
	emit_insn(ctx, sllid, tmp2, a2, 3);
	emit_insn(ctx, addd, tmp2, tmp2, a1);
	off = offsetof(struct bpf_array, ptrs);
	emit_insn(ctx, ldd, tmp2, tmp2, off);
	/* beq $t2, $zero, jmp_offset */
	emit_tailcall_jump(ctx, BPF_JEQ, tmp2, LOONGARCH_GPR_ZERO, jmp_offset);

	/* goto *(prog->bpf_func + 4); */
	off = offsetof(struct bpf_prog, bpf_func);
	emit_insn(ctx, ldd, tmp3, tmp2, off);
	move_reg(ctx, tcc, tmp1);
	__build_epilogue(ctx, true);

	/* out: */
	if (out_offset == -1)
		out_offset = cur_offset;
	if (cur_offset != out_offset) {
		pr_err_once("tail_call out_offset = %d, expected %d!\n",
			     cur_offset, out_offset);
		return -1;
	}

	return 0;
#undef cur_offset
#undef jmp_offset
}
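
/*
 * Rough shape of the sequence emitted above (illustrative only; tmp1/tmp2/tmp3
 * are $t1/$t2/$t3 as in regmap[], and "tcc" is $a6 or $s5 depending on
 * tail_call_reg()):
 *
 *   ld.wu  $t1, $a1, offsetof(struct bpf_array, map.max_entries)
 *   bgeu   $a2, $t1, out
 *   addi.d $t1, tcc, -1
 *   blt    tcc, $zero, out
 *   slli.d $t2, $a2, 3
 *   add.d  $t2, $t2, $a1
 *   ld.d   $t2, $t2, offsetof(struct bpf_array, ptrs)
 *   beq    $t2, $zero, out
 *   ld.d   $t3, $t2, offsetof(struct bpf_prog, bpf_func)
 *   move   tcc, $t1
 *   <epilogue restores registers and pops the stack>
 *   jirl   $zero, $t3, 1     # enter the callee past its TCC init
 * out:
 */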

static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool extra_pass)
{
	bool is32 = (BPF_CLASS(insn->code) == BPF_ALU);
	const u8 code = insn->code;
	const u8 cond = BPF_OP(code);
	const u8 dst = regmap[insn->dst_reg];
	const u8 src = regmap[insn->src_reg];
	const u8 tmp = regmap[TMP_REG_1];
	const u8 tmp2 = regmap[TMP_REG_2];
	const s16 off = insn->off;
	const s32 imm = insn->imm;
	int i = insn - ctx->prog->insnsi;
	int jmp_offset;
	bool func_addr_fixed;
	u64 func_addr;
	u64 imm64;
	int ret;

	switch (code) {
	/* dst = src */
	case BPF_ALU | BPF_MOV | BPF_X:
	case BPF_ALU64 | BPF_MOV | BPF_X:
		move_reg(ctx, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	/* dst = imm */
	case BPF_ALU | BPF_MOV | BPF_K:
	case BPF_ALU64 | BPF_MOV | BPF_K:
		move_imm32(ctx, dst, imm, is32);
		break;

	/* dst = dst + src */
	case BPF_ALU | BPF_ADD | BPF_X:
	case BPF_ALU64 | BPF_ADD | BPF_X:
		emit_insn(ctx, addd, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	/* dst = dst + imm */
	case BPF_ALU | BPF_ADD | BPF_K:
	case BPF_ALU64 | BPF_ADD | BPF_K:
		if (is_signed_imm12(imm)) {
			emit_insn(ctx, addid, dst, dst, imm);
		} else {
			move_imm32(ctx, tmp, imm, is32);
			emit_insn(ctx, addd, dst, dst, tmp);
		}
		emit_zext_32(ctx, dst, is32);
		break;

	/* dst = dst - src */
	case BPF_ALU | BPF_SUB | BPF_X:
	case BPF_ALU64 | BPF_SUB | BPF_X:
		emit_insn(ctx, subd, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	/* dst = dst - imm */
	case BPF_ALU | BPF_SUB | BPF_K:
	case BPF_ALU64 | BPF_SUB | BPF_K:
		if (is_signed_imm12(-imm)) {
			emit_insn(ctx, addid, dst, dst, -imm);
		} else {
			move_imm32(ctx, tmp, imm, is32);
			emit_insn(ctx, subd, dst, dst, tmp);
		}
		emit_zext_32(ctx, dst, is32);
		break;

	/* dst = dst * src */
	case BPF_ALU | BPF_MUL | BPF_X:
	case BPF_ALU64 | BPF_MUL | BPF_X:
		emit_insn(ctx, muld, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	/* dst = dst * imm */
	case BPF_ALU | BPF_MUL | BPF_K:
	case BPF_ALU64 | BPF_MUL | BPF_K:
		move_imm32(ctx, tmp, imm, is32);
		emit_insn(ctx, muld, dst, dst, tmp);
		emit_zext_32(ctx, dst, is32);
		break;

	/* dst = dst / src */
	case BPF_ALU | BPF_DIV | BPF_X:
	case BPF_ALU64 | BPF_DIV | BPF_X:
		emit_insn(ctx, divdu, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	/* dst = dst / imm */
	case BPF_ALU | BPF_DIV | BPF_K:
	case BPF_ALU64 | BPF_DIV | BPF_K:
		move_imm32(ctx, tmp, imm, is32);
		emit_insn(ctx, divdu, dst, dst, tmp);
		emit_zext_32(ctx, dst, is32);
		break;

	/* dst = dst % src */
	case BPF_ALU | BPF_MOD | BPF_X:
	case BPF_ALU64 | BPF_MOD | BPF_X:
		emit_insn(ctx, moddu, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	/* dst = dst % imm */
	case BPF_ALU | BPF_MOD | BPF_K:
	case BPF_ALU64 | BPF_MOD | BPF_K:
		move_imm32(ctx, tmp, imm, is32);
		emit_insn(ctx, moddu, dst, dst, tmp);
		emit_zext_32(ctx, dst, is32);
		break;

	/* dst = -dst */
	case BPF_ALU | BPF_NEG:
	case BPF_ALU64 | BPF_NEG:
		emit_insn(ctx, subd, dst, LOONGARCH_GPR_ZERO, dst);
		emit_zext_32(ctx, dst, is32);
		break;

	/* dst = dst & src */
	case BPF_ALU | BPF_AND | BPF_X:
	case BPF_ALU64 | BPF_AND | BPF_X:
		emit_insn(ctx, and, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	/* dst = dst & imm */
	case BPF_ALU | BPF_AND | BPF_K:
	case BPF_ALU64 | BPF_AND | BPF_K:
		if (is_unsigned_imm12(imm)) {
			emit_insn(ctx, andi, dst, dst, imm);
		} else {
			move_imm32(ctx, tmp, imm, is32);
			emit_insn(ctx, and, dst, dst, tmp);
		}
		emit_zext_32(ctx, dst, is32);
		break;

	/* dst = dst | src */
	case BPF_ALU | BPF_OR | BPF_X:
	case BPF_ALU64 | BPF_OR | BPF_X:
		emit_insn(ctx, or, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	/* dst = dst | imm */
	case BPF_ALU | BPF_OR | BPF_K:
	case BPF_ALU64 | BPF_OR | BPF_K:
		if (is_unsigned_imm12(imm)) {
			emit_insn(ctx, ori, dst, dst, imm);
		} else {
			move_imm32(ctx, tmp, imm, is32);
			emit_insn(ctx, or, dst, dst, tmp);
		}
		emit_zext_32(ctx, dst, is32);
		break;

	/* dst = dst ^ src */
	case BPF_ALU | BPF_XOR | BPF_X:
	case BPF_ALU64 | BPF_XOR | BPF_X:
		emit_insn(ctx, xor, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	/* dst = dst ^ imm */
	case BPF_ALU | BPF_XOR | BPF_K:
	case BPF_ALU64 | BPF_XOR | BPF_K:
		if (is_unsigned_imm12(imm)) {
			emit_insn(ctx, xori, dst, dst, imm);
		} else {
			move_imm32(ctx, tmp, imm, is32);
			emit_insn(ctx, xor, dst, dst, tmp);
		}
		emit_zext_32(ctx, dst, is32);
		break;

	/* dst = dst << src (logical) */
	case BPF_ALU | BPF_LSH | BPF_X:
		emit_insn(ctx, sllw, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	case BPF_ALU64 | BPF_LSH | BPF_X:
		emit_insn(ctx, slld, dst, dst, src);
		break;
	/* dst = dst << imm (logical) */
	case BPF_ALU | BPF_LSH | BPF_K:
		emit_insn(ctx, slliw, dst, dst, imm);
		emit_zext_32(ctx, dst, is32);
		break;
	case BPF_ALU64 | BPF_LSH | BPF_K:
		emit_insn(ctx, sllid, dst, dst, imm);
		break;

	/* dst = dst >> src (logical) */
	case BPF_ALU | BPF_RSH | BPF_X:
		emit_insn(ctx, srlw, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	case BPF_ALU64 | BPF_RSH | BPF_X:
		emit_insn(ctx, srld, dst, dst, src);
		break;
	/* dst = dst >> imm (logical) */
	case BPF_ALU | BPF_RSH | BPF_K:
		emit_insn(ctx, srliw, dst, dst, imm);
		emit_zext_32(ctx, dst, is32);
		break;
	case BPF_ALU64 | BPF_RSH | BPF_K:
		emit_insn(ctx, srlid, dst, dst, imm);
		break;

	/* dst = dst >> src (arithmetic) */
	case BPF_ALU | BPF_ARSH | BPF_X:
		emit_insn(ctx, sraw, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	case BPF_ALU64 | BPF_ARSH | BPF_X:
		emit_insn(ctx, srad, dst, dst, src);
		break;
	/* dst = dst >> imm (arithmetic) */
	case BPF_ALU | BPF_ARSH | BPF_K:
		emit_insn(ctx, sraiw, dst, dst, imm);
		emit_zext_32(ctx, dst, is32);
		break;
	case BPF_ALU64 | BPF_ARSH | BPF_K:
		emit_insn(ctx, sraid, dst, dst, imm);
		break;

	/* dst = BSWAP##imm(dst) */
	case BPF_ALU | BPF_END | BPF_FROM_LE:
		switch (imm) {
		case 16:
			/* zero-extend 16 bits into 64 bits */
			emit_insn(ctx, sllid, dst, dst, 48);
			emit_insn(ctx, srlid, dst, dst, 48);
			break;
		case 32:
			/* zero-extend 32 bits into 64 bits */
			emit_zext_32(ctx, dst, is32);
			break;
		case 64:
			/* do nothing */
			break;
		}
		break;
	case BPF_ALU | BPF_END | BPF_FROM_BE:
		switch (imm) {
		case 16:
			emit_insn(ctx, revb2h, dst, dst);
			/* zero-extend 16 bits into 64 bits */
			emit_insn(ctx, sllid, dst, dst, 48);
			emit_insn(ctx, srlid, dst, dst, 48);
			break;
		case 32:
			emit_insn(ctx, revb2w, dst, dst);
			/* zero-extend 32 bits into 64 bits */
			emit_zext_32(ctx, dst, is32);
			break;
		case 64:
			emit_insn(ctx, revbd, dst, dst);
			break;
		}
		break;
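
	/*
	 * Worked example (illustrative): for BPF_FROM_BE with imm == 16 and
	 * dst holding 0x1234 in its low halfword, revb.2h swaps the two bytes
	 * of that halfword to give 0x3412; the slli.d/srli.d pair by 48 then
	 * clears bits 63..16, leaving 0x3412 zero-extended to 64 bits.
	 */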

	/* PC += off if dst cond src */
	case BPF_JMP | BPF_JEQ | BPF_X:
	case BPF_JMP | BPF_JNE | BPF_X:
	case BPF_JMP | BPF_JGT | BPF_X:
	case BPF_JMP | BPF_JGE | BPF_X:
	case BPF_JMP | BPF_JLT | BPF_X:
	case BPF_JMP | BPF_JLE | BPF_X:
	case BPF_JMP | BPF_JSGT | BPF_X:
	case BPF_JMP | BPF_JSGE | BPF_X:
	case BPF_JMP | BPF_JSLT | BPF_X:
	case BPF_JMP | BPF_JSLE | BPF_X:
		jmp_offset = bpf2la_offset(i, off, ctx);
		emit_cond_jump(ctx, cond, dst, src, jmp_offset);
		break;

	/* PC += off if dst cond imm */
	case BPF_JMP | BPF_JEQ | BPF_K:
	case BPF_JMP | BPF_JNE | BPF_K:
	case BPF_JMP | BPF_JGT | BPF_K:
	case BPF_JMP | BPF_JGE | BPF_K:
	case BPF_JMP | BPF_JLT | BPF_K:
	case BPF_JMP | BPF_JLE | BPF_K:
	case BPF_JMP | BPF_JSGT | BPF_K:
	case BPF_JMP | BPF_JSGE | BPF_K:
	case BPF_JMP | BPF_JSLT | BPF_K:
	case BPF_JMP | BPF_JSLE | BPF_K:
		jmp_offset = bpf2la_offset(i, off, ctx);
		move_imm32(ctx, tmp, imm, is32);
		emit_cond_jump(ctx, cond, dst, tmp, jmp_offset);
		break;

	/* PC += off if dst & src */
	case BPF_JMP | BPF_JSET | BPF_X:
		jmp_offset = bpf2la_offset(i, off, ctx);
		emit_insn(ctx, and, tmp, dst, src);
		emit_cond_jump(ctx, cond, tmp, LOONGARCH_GPR_ZERO, jmp_offset);
		break;
	/* PC += off if dst & imm */
	case BPF_JMP | BPF_JSET | BPF_K:
		jmp_offset = bpf2la_offset(i, off, ctx);
		move_imm32(ctx, tmp, imm, is32);
		emit_insn(ctx, and, tmp, dst, tmp);
		emit_cond_jump(ctx, cond, tmp, LOONGARCH_GPR_ZERO, jmp_offset);
		break;

	/* PC += off */
	case BPF_JMP | BPF_JA:
		jmp_offset = bpf2la_offset(i, off, ctx);
		emit_uncond_jump(ctx, jmp_offset, false);
		break;

	/* function call */
	case BPF_JMP | BPF_CALL:
		mark_call(ctx);
		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
					    &func_addr, &func_addr_fixed);
		if (ret < 0)
			return ret;

		move_imm64(ctx, tmp, func_addr, is32);
		emit_insn(ctx, jirl, LOONGARCH_GPR_RA, tmp, 0);
		move_reg(ctx, regmap[BPF_REG_0], LOONGARCH_GPR_A0);
		break;

	/* tail call */
	case BPF_JMP | BPF_TAIL_CALL:
		mark_tail_call(ctx);
		if (emit_bpf_tail_call(ctx))
			return -EINVAL;
		break;

	/* function return */
	case BPF_JMP | BPF_EXIT:
		emit_sext_32(ctx, regmap[BPF_REG_0]);
		/*
		 * Optimization: when the last instruction is EXIT,
		 * simply fall through to the epilogue.
		 */
		if (i == ctx->prog->len - 1)
			break;

		jmp_offset = epilogue_offset(ctx);
		emit_uncond_jump(ctx, jmp_offset, true);
		break;

	/* dst = imm64 */
	case BPF_LD | BPF_IMM | BPF_DW:
		imm64 = (u64)(insn + 1)->imm << 32 | (u32)insn->imm;
		move_imm64(ctx, dst, imm64, is32);
		return 1;
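
	/*
	 * Illustrative note: a 64-bit immediate load occupies two BPF
	 * instruction slots; insn->imm carries the low 32 bits and
	 * (insn + 1)->imm the high 32 bits, e.g. imm64 == 0xdeadbeef00000001
	 * is encoded as insn->imm == 0x00000001 and (insn + 1)->imm ==
	 * 0xdeadbeef. Returning 1 tells build_body() to skip the second slot.
	 */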

	/* dst = *(size *)(src + off) */
	case BPF_LDX | BPF_MEM | BPF_B:
	case BPF_LDX | BPF_MEM | BPF_H:
	case BPF_LDX | BPF_MEM | BPF_W:
	case BPF_LDX | BPF_MEM | BPF_DW:
		if (is_signed_imm12(off)) {
			switch (BPF_SIZE(code)) {
			case BPF_B:
				emit_insn(ctx, ldbu, dst, src, off);
				break;
			case BPF_H:
				emit_insn(ctx, ldhu, dst, src, off);
				break;
			case BPF_W:
				emit_insn(ctx, ldwu, dst, src, off);
				break;
			case BPF_DW:
				emit_insn(ctx, ldd, dst, src, off);
				break;
			}
		} else {
			move_imm32(ctx, tmp, off, is32);
			switch (BPF_SIZE(code)) {
			case BPF_B:
				emit_insn(ctx, ldxbu, dst, src, tmp);
				break;
			case BPF_H:
				emit_insn(ctx, ldxhu, dst, src, tmp);
				break;
			case BPF_W:
				emit_insn(ctx, ldxwu, dst, src, tmp);
				break;
			case BPF_DW:
				emit_insn(ctx, ldxd, dst, src, tmp);
				break;
			}
		}
		break;

	/* *(size *)(dst + off) = imm */
	case BPF_ST | BPF_MEM | BPF_B:
	case BPF_ST | BPF_MEM | BPF_H:
	case BPF_ST | BPF_MEM | BPF_W:
	case BPF_ST | BPF_MEM | BPF_DW:
		move_imm32(ctx, tmp, imm, is32);
		if (is_signed_imm12(off)) {
			switch (BPF_SIZE(code)) {
			case BPF_B:
				emit_insn(ctx, stb, tmp, dst, off);
				break;
			case BPF_H:
				emit_insn(ctx, sth, tmp, dst, off);
				break;
			case BPF_W:
				emit_insn(ctx, stw, tmp, dst, off);
				break;
			case BPF_DW:
				emit_insn(ctx, std, tmp, dst, off);
				break;
			}
		} else {
			move_imm32(ctx, tmp2, off, is32);
			switch (BPF_SIZE(code)) {
			case BPF_B:
				emit_insn(ctx, stxb, tmp, dst, tmp2);
				break;
			case BPF_H:
				emit_insn(ctx, stxh, tmp, dst, tmp2);
				break;
			case BPF_W:
				emit_insn(ctx, stxw, tmp, dst, tmp2);
				break;
			case BPF_DW:
				emit_insn(ctx, stxd, tmp, dst, tmp2);
				break;
			}
		}
		break;

	/* *(size *)(dst + off) = src */
	case BPF_STX | BPF_MEM | BPF_B:
	case BPF_STX | BPF_MEM | BPF_H:
	case BPF_STX | BPF_MEM | BPF_W:
	case BPF_STX | BPF_MEM | BPF_DW:
		if (is_signed_imm12(off)) {
			switch (BPF_SIZE(code)) {
			case BPF_B:
				emit_insn(ctx, stb, src, dst, off);
				break;
			case BPF_H:
				emit_insn(ctx, sth, src, dst, off);
				break;
			case BPF_W:
				emit_insn(ctx, stw, src, dst, off);
				break;
			case BPF_DW:
				emit_insn(ctx, std, src, dst, off);
				break;
			}
		} else {
			move_imm32(ctx, tmp, off, is32);
			switch (BPF_SIZE(code)) {
			case BPF_B:
				emit_insn(ctx, stxb, src, dst, tmp);
				break;
			case BPF_H:
				emit_insn(ctx, stxh, src, dst, tmp);
				break;
			case BPF_W:
				emit_insn(ctx, stxw, src, dst, tmp);
				break;
			case BPF_DW:
				emit_insn(ctx, stxd, src, dst, tmp);
				break;
			}
		}
		break;

	/* atomic_add: lock *(size *)(dst + off) += src */
	case BPF_STX | BPF_XADD | BPF_W:
	case BPF_STX | BPF_XADD | BPF_DW:
		if (insn->imm != BPF_ADD) {
			pr_err_once("unknown atomic op code %02x\n", insn->imm);
			return -EINVAL;
		}

		move_imm32(ctx, tmp, off, is32);
		emit_insn(ctx, addd, tmp, dst, tmp);
		switch (BPF_SIZE(insn->code)) {
		case BPF_W:
			emit_insn(ctx, amaddw, tmp2, src, tmp);
			break;
		case BPF_DW:
			emit_insn(ctx, amaddd, tmp2, src, tmp);
			break;
		}
		break;
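
	/*
	 * Illustrative sketch of the atomic add above, for BPF_W with
	 * off == 8 (register names assume tmp/tmp2 = $t1/$t2):
	 *
	 *   addi.d  $t1, $zero, 8      # via move_imm32(); exact sequence
	 *                              # may differ
	 *   add.d   $t1, dst, $t1
	 *   amadd.w $t2, src, $t1      # *(u32 *)(dst + 8) += src; old value
	 *                              # lands in $t2 and is discarded
	 */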

	default:
		pr_err("bpf_jit: unknown opcode %02x\n", code);
		return -EINVAL;
	}

	return 0;
}

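/*
 * Illustrative note on ctx->offset (the branch-resolution helpers live in
 * ebpf_jit.h and are not shown here): on the first pass, offset[i] records
 * ctx->idx, the index of the first native instruction emitted for BPF
 * instruction i. E.g. if BPF insn 0 expands to 3 native instructions and
 * insn 1 to 2, then offset[0] = 0, offset[1] = 3 and offset[2] = 5;
 * bpf2la_offset() presumably uses these to turn BPF branch offsets into
 * native ones.
 */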
static int build_body(struct jit_ctx *ctx, bool extra_pass)
{
	const struct bpf_prog *prog = ctx->prog;
	int i;

	for (i = 0; i < prog->len; i++) {
		const struct bpf_insn *insn = &prog->insnsi[i];
		int ret;

		if (!ctx->image)
			ctx->offset[i] = ctx->idx;

		ret = build_insn(insn, ctx, extra_pass);
		if (ret > 0) {
			i++;
			if (!ctx->image)
				ctx->offset[i] = ctx->idx;
			continue;
		}
		if (ret)
			return ret;
	}

	if (!ctx->image)
		ctx->offset[i] = ctx->idx;

	return 0;
}

static inline void bpf_flush_icache(void *start, void *end)
{
	flush_icache_range((unsigned long)start, (unsigned long)end);
}

/* Fill space with illegal instructions */
static void jit_fill_hole(void *area, unsigned int size)
{
	u32 *ptr;

	/* We are guaranteed to have aligned memory */
	for (ptr = area; size >= sizeof(u32); size -= sizeof(u32))
		*ptr++ = INSN_BREAK;
}

static int validate_code(struct jit_ctx *ctx)
{
	int i;
	union loongarch_instruction insn;

	for (i = 0; i < ctx->idx; i++) {
		insn = ctx->image[i];
		/* Check INSN_BREAK */
		if (insn.word == INSN_BREAK)
			return -1;
	}

	return 0;
}

struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
{
	struct bpf_prog *tmp, *orig_prog = prog;
	struct bpf_binary_header *header;
	struct jit_data *jit_data;
	struct jit_ctx ctx;
	bool tmp_blinded = false;
	bool extra_pass = false;
	int image_size;
	u8 *image_ptr;

	/*
	 * If BPF JIT was not enabled then we must fall back to
	 * the interpreter.
	 */
	if (!prog->jit_requested)
		return orig_prog;

	tmp = bpf_jit_blind_constants(prog);
	/*
	 * If blinding was requested and we failed during blinding,
	 * we must fall back to the interpreter. Otherwise, we save
	 * the new JITed code.
	 */
	if (IS_ERR(tmp))
		return orig_prog;
	if (tmp != prog) {
		tmp_blinded = true;
		prog = tmp;
	}

	jit_data = prog->aux->jit_data;
	if (!jit_data) {
		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
		if (!jit_data) {
			prog = orig_prog;
			goto out;
		}
		prog->aux->jit_data = jit_data;
	}
	if (jit_data->ctx.offset) {
		ctx = jit_data->ctx;
		image_ptr = jit_data->image;
		header = jit_data->header;
		extra_pass = true;
		image_size = sizeof(u32) * ctx.idx;
		goto skip_init_ctx;
	}

	memset(&ctx, 0, sizeof(ctx));
	ctx.prog = prog;

	ctx.offset = kcalloc(prog->len + 1, sizeof(*ctx.offset), GFP_KERNEL);
	if (!ctx.offset) {
		prog = orig_prog;
		goto out_off;
	}

	/* 1. Initial fake pass to compute ctx->idx and set ctx->flags */
	if (build_body(&ctx, extra_pass)) {
		prog = orig_prog;
		goto out_off;
	}
	build_prologue(&ctx);
	ctx.epilogue_offset = ctx.idx;
	build_epilogue(&ctx);

	/*
	 * Now we know the actual image size.
	 * As each LoongArch instruction is 32 bits long, we translate
	 * the number of JITed instructions into the size required to
	 * store the JITed code.
	 */
	image_size = sizeof(u32) * ctx.idx;
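	/* E.g. 200 emitted instructions -> image_size = 200 * 4 = 800 bytes. */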
	/* Now we know the size of the structure to make */
	header = bpf_jit_binary_alloc(image_size, &image_ptr,
				      sizeof(u32), jit_fill_hole);
	if (!header) {
		prog = orig_prog;
		goto out_off;
	}

	/* 2. Now, the actual pass to generate final JIT code */
	ctx.image = (union loongarch_instruction *)image_ptr;
skip_init_ctx:
	ctx.idx = 0;

	build_prologue(&ctx);
	if (build_body(&ctx, extra_pass)) {
		bpf_jit_binary_free(header);
		prog = orig_prog;
		goto out_off;
	}
	build_epilogue(&ctx);

	/* 3. Extra pass to validate JITed code */
	if (validate_code(&ctx)) {
		bpf_jit_binary_free(header);
		prog = orig_prog;
		goto out_off;
	}

	/* And we're done */
	if (bpf_jit_enable > 1)
		bpf_jit_dump(prog->len, image_size, 2, ctx.image);

	/* Update the icache */
	bpf_flush_icache(header, ctx.image + ctx.idx);

	if (!prog->is_func || extra_pass) {
		if (extra_pass && ctx.idx != jit_data->ctx.idx) {
			pr_err_once("multi-func JIT bug %d != %d\n",
				    ctx.idx, jit_data->ctx.idx);
			bpf_jit_binary_free(header);
			prog->bpf_func = NULL;
			prog->jited = 0;
			goto out_off;
		}
		bpf_jit_binary_lock_ro(header);
	} else {
		jit_data->ctx = ctx;
		jit_data->image = image_ptr;
		jit_data->header = header;
	}
	prog->bpf_func = (void *)ctx.image;
	prog->jited = 1;
	prog->jited_len = image_size;

	if (!prog->is_func || extra_pass) {
out_off:
		kfree(ctx.offset);
		kfree(jit_data);
		prog->aux->jit_data = NULL;
	}
out:
	if (tmp_blinded)
		bpf_jit_prog_release_other(prog, prog == orig_prog ?
					   tmp : orig_prog);

	out_offset = -1;
	return prog;
}
