// SPDX-License-Identifier: GPL-2.0-only
/*
 * BPF JIT compiler for LoongArch
 *
 * Copyright (C) 2021 Loongson Technology Corporation Limited
 */
#include "ebpf_jit.h"

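/*
 * Pseudo registers used only inside the JIT: they are numbered beyond
 * MAX_BPF_JIT_REG so they can never clash with a register index coming
 * from the verifier, and regmap[] below assigns them LoongArch GPRs.
 */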
#define TMP_REG_1       (MAX_BPF_JIT_REG + 0)
#define TMP_REG_2       (MAX_BPF_JIT_REG + 1)
#define TMP_REG_3       (MAX_BPF_JIT_REG + 2)
#define REG_TCC         (MAX_BPF_JIT_REG + 3)
#define TCC_SAVED       (MAX_BPF_JIT_REG + 4)

#define SAVE_RA         BIT(0)
#define SAVE_TCC        BIT(1)

static const int regmap[] = {
	/* return value from in-kernel function, and exit value for eBPF program */
	[BPF_REG_0] = LOONGARCH_GPR_A5,
	/* arguments from eBPF program to in-kernel function */
	[BPF_REG_1] = LOONGARCH_GPR_A0,
	[BPF_REG_2] = LOONGARCH_GPR_A1,
	[BPF_REG_3] = LOONGARCH_GPR_A2,
	[BPF_REG_4] = LOONGARCH_GPR_A3,
	[BPF_REG_5] = LOONGARCH_GPR_A4,
	/* callee-saved registers that the in-kernel function will preserve */
	[BPF_REG_6] = LOONGARCH_GPR_S0,
	[BPF_REG_7] = LOONGARCH_GPR_S1,
	[BPF_REG_8] = LOONGARCH_GPR_S2,
	[BPF_REG_9] = LOONGARCH_GPR_S3,
	/* read-only frame pointer to access stack */
	[BPF_REG_FP] = LOONGARCH_GPR_S4,
	/* temporary register for blinding constants */
	[BPF_REG_AX] = LOONGARCH_GPR_T0,
	/* temporary registers for internal BPF JIT */
	[TMP_REG_1] = LOONGARCH_GPR_T1,
	[TMP_REG_2] = LOONGARCH_GPR_T2,
	[TMP_REG_3] = LOONGARCH_GPR_T3,
	/* tail call count (TCC) */
	[REG_TCC] = LOONGARCH_GPR_A6,
	/* TCC is moved from A6 to S5 if the program does calls */
	[TCC_SAVED] = LOONGARCH_GPR_S5,
};

static void mark_call(struct jit_ctx *ctx)
{
	ctx->flags |= SAVE_RA;
}

static void mark_tail_call(struct jit_ctx *ctx)
{
	ctx->flags |= SAVE_TCC;
}

static bool seen_call(struct jit_ctx *ctx)
{
	return (ctx->flags & SAVE_RA);
}

static bool seen_tail_call(struct jit_ctx *ctx)
{
	return (ctx->flags & SAVE_TCC);
}

static u8 tail_call_reg(struct jit_ctx *ctx)
{
	if (seen_call(ctx))
		return regmap[TCC_SAVED];

	return regmap[REG_TCC];
}

/*
 * eBPF prog stack layout:
 *
 *                                        high
 * original $sp ------------> +-------------------------+ <--LOONGARCH_GPR_FP
 *                            |           $ra           |
 *                            +-------------------------+
 *                            |           $fp           |
 *                            +-------------------------+
 *                            |           $s0           |
 *                            +-------------------------+
 *                            |           $s1           |
 *                            +-------------------------+
 *                            |           $s2           |
 *                            +-------------------------+
 *                            |           $s3           |
 *                            +-------------------------+
 *                            |           $s4           |
 *                            +-------------------------+
 *                            |           $s5           |
 *                            +-------------------------+ <--BPF_REG_FP
 *                            |  prog->aux->stack_depth |
 *                            |        (optional)       |
 * current $sp -------------> +-------------------------+
 *                                        low
 */
static void build_prologue(struct jit_ctx *ctx)
{
	int stack_adjust = 0, store_offset, bpf_stack_adjust;

	bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);

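	/*
	 * Eight 8-byte slots for the saved registers ($ra, $fp, $s0-$s5),
	 * rounded up to the 16-byte stack alignment the ABI requires, plus
	 * the BPF program's own stack area (see the layout diagram above).
	 */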
	stack_adjust += sizeof(long); /* LOONGARCH_GPR_RA */
	stack_adjust += sizeof(long); /* LOONGARCH_GPR_FP */
	stack_adjust += sizeof(long); /* LOONGARCH_GPR_S0 */
	stack_adjust += sizeof(long); /* LOONGARCH_GPR_S1 */
	stack_adjust += sizeof(long); /* LOONGARCH_GPR_S2 */
	stack_adjust += sizeof(long); /* LOONGARCH_GPR_S3 */
	stack_adjust += sizeof(long); /* LOONGARCH_GPR_S4 */
	stack_adjust += sizeof(long); /* LOONGARCH_GPR_S5 */

	stack_adjust = round_up(stack_adjust, 16);
	stack_adjust += bpf_stack_adjust;

	/*
	 * The first instruction initializes the tail call count (TCC).
	 * On a tail call we skip this instruction, and the TCC is
	 * passed in REG_TCC from the caller.
	 */
	emit_insn(ctx, addid, regmap[REG_TCC], LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);

	emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, -stack_adjust);

	store_offset = stack_adjust - sizeof(long);
	emit_insn(ctx, std, LOONGARCH_GPR_RA, LOONGARCH_GPR_SP, store_offset);

	store_offset -= sizeof(long);
	emit_insn(ctx, std, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, store_offset);

	store_offset -= sizeof(long);
	emit_insn(ctx, std, LOONGARCH_GPR_S0, LOONGARCH_GPR_SP, store_offset);

	store_offset -= sizeof(long);
	emit_insn(ctx, std, LOONGARCH_GPR_S1, LOONGARCH_GPR_SP, store_offset);

	store_offset -= sizeof(long);
	emit_insn(ctx, std, LOONGARCH_GPR_S2, LOONGARCH_GPR_SP, store_offset);

	store_offset -= sizeof(long);
	emit_insn(ctx, std, LOONGARCH_GPR_S3, LOONGARCH_GPR_SP, store_offset);

	store_offset -= sizeof(long);
	emit_insn(ctx, std, LOONGARCH_GPR_S4, LOONGARCH_GPR_SP, store_offset);

	store_offset -= sizeof(long);
	emit_insn(ctx, std, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, store_offset);

	emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_adjust);

	if (bpf_stack_adjust)
		emit_insn(ctx, addid, regmap[BPF_REG_FP], LOONGARCH_GPR_SP, bpf_stack_adjust);

	/*
	 * Program contains calls and tail calls, so REG_TCC needs
	 * to be saved across calls.
	 */
	if (seen_tail_call(ctx) && seen_call(ctx))
		move_reg(ctx, regmap[TCC_SAVED], regmap[REG_TCC]);

	ctx->stack_size = stack_adjust;
}

static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call)
{
	int stack_adjust = ctx->stack_size;
	int load_offset;

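	/*
	 * Restore the callee-saved registers in the same top-down order the
	 * prologue stored them, then release the whole frame in one go.
	 */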
	load_offset = stack_adjust - sizeof(long);
	emit_insn(ctx, ldd, LOONGARCH_GPR_RA, LOONGARCH_GPR_SP, load_offset);

	load_offset -= sizeof(long);
	emit_insn(ctx, ldd, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, load_offset);

	load_offset -= sizeof(long);
	emit_insn(ctx, ldd, LOONGARCH_GPR_S0, LOONGARCH_GPR_SP, load_offset);

	load_offset -= sizeof(long);
	emit_insn(ctx, ldd, LOONGARCH_GPR_S1, LOONGARCH_GPR_SP, load_offset);

	load_offset -= sizeof(long);
	emit_insn(ctx, ldd, LOONGARCH_GPR_S2, LOONGARCH_GPR_SP, load_offset);

	load_offset -= sizeof(long);
	emit_insn(ctx, ldd, LOONGARCH_GPR_S3, LOONGARCH_GPR_SP, load_offset);

	load_offset -= sizeof(long);
	emit_insn(ctx, ldd, LOONGARCH_GPR_S4, LOONGARCH_GPR_SP, load_offset);

	load_offset -= sizeof(long);
	emit_insn(ctx, ldd, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, load_offset);

	emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, stack_adjust);

	if (!is_tail_call) {
		/* Set return value */
		move_reg(ctx, LOONGARCH_GPR_A0, regmap[BPF_REG_0]);
		/* Return to the caller */
		emit_insn(ctx, jirl, LOONGARCH_GPR_ZERO, LOONGARCH_GPR_RA, 0);
	} else {
		/*
		 * Jump into the next BPF program, skipping its first
		 * instruction (the TCC initialization).
		 */
		emit_insn(ctx, jirl, LOONGARCH_GPR_ZERO, regmap[TMP_REG_3], 1);
	}
}

void build_epilogue(struct jit_ctx *ctx)
{
	__build_epilogue(ctx, false);
}

/* initialized on the first pass of build_body() */
static int out_offset = -1;
static int emit_bpf_tail_call(struct jit_ctx *ctx)
{
	int off;
	u8 tcc = tail_call_reg(ctx);
	u8 a1 = LOONGARCH_GPR_A1;
	u8 a2 = LOONGARCH_GPR_A2;
	u8 tmp1 = regmap[TMP_REG_1];
	u8 tmp2 = regmap[TMP_REG_2];
	u8 tmp3 = regmap[TMP_REG_3];
	const int idx0 = ctx->idx;

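/*
 * cur_offset is the number of instructions emitted so far for this tail
 * call sequence; jmp_offset is the remaining distance to the "out" label
 * below, measured in instructions.
 */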
#define cur_offset (ctx->idx - idx0)
#define jmp_offset (out_offset - (cur_offset))

	/*
	 * a0: &ctx
	 * a1: &array
	 * a2: index
	 *
	 * if (index >= array->map.max_entries)
	 *	 goto out;
	 */
	off = offsetof(struct bpf_array, map.max_entries);
	emit_insn(ctx, ldwu, tmp1, a1, off);
	/* bgeu $a2, $t1, jmp_offset */
	emit_tailcall_jump(ctx, BPF_JGE, a2, tmp1, jmp_offset);

	/*
	 * if (TCC-- < 0)
	 *	 goto out;
	 */
	emit_insn(ctx, addid, tmp1, tcc, -1);
	emit_tailcall_jump(ctx, BPF_JSLT, tcc, LOONGARCH_GPR_ZERO, jmp_offset);

	/*
	 * prog = array->ptrs[index];
	 * if (!prog)
	 *	 goto out;
	 */
	emit_insn(ctx, sllid, tmp2, a2, 3);
	emit_insn(ctx, addd, tmp2, tmp2, a1);
	off = offsetof(struct bpf_array, ptrs);
	emit_insn(ctx, ldd, tmp2, tmp2, off);
	/* beq $t2, $zero, jmp_offset */
	emit_tailcall_jump(ctx, BPF_JEQ, tmp2, LOONGARCH_GPR_ZERO, jmp_offset);

	/* goto *(prog->bpf_func + 4); */
	off = offsetof(struct bpf_prog, bpf_func);
	emit_insn(ctx, ldd, tmp3, tmp2, off);
	move_reg(ctx, tcc, tmp1);
	__build_epilogue(ctx, true);

	/* out: */
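	/*
	 * out_offset is latched on the first pass; every subsequent pass
	 * must emit exactly the same number of instructions above, or the
	 * jmp_offset values already emitted would no longer land on "out".
	 */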
	if (out_offset == -1)
		out_offset = cur_offset;
	if (cur_offset != out_offset) {
		pr_err_once("tail_call out_offset = %d, expected %d!\n",
			     cur_offset, out_offset);
		return -1;
	}

	return 0;
#undef cur_offset
#undef jmp_offset
}

static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool extra_pass)
{
	bool is32 = (BPF_CLASS(insn->code) == BPF_ALU);
	const u8 code = insn->code;
	const u8 cond = BPF_OP(code);
	const u8 dst = regmap[insn->dst_reg];
	const u8 src = regmap[insn->src_reg];
	const u8 tmp = regmap[TMP_REG_1];
	const u8 tmp2 = regmap[TMP_REG_2];
	const s16 off = insn->off;
	const s32 imm = insn->imm;
	int i = insn - ctx->prog->insnsi;
	int jmp_offset;
	bool func_addr_fixed;
	u64 func_addr;
	u64 imm64;
	int ret;

	switch (code) {
	/* dst = src */
	case BPF_ALU | BPF_MOV | BPF_X:
	case BPF_ALU64 | BPF_MOV | BPF_X:
		move_reg(ctx, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	/* dst = imm */
	case BPF_ALU | BPF_MOV | BPF_K:
	case BPF_ALU64 | BPF_MOV | BPF_K:
		move_imm32(ctx, dst, imm, is32);
		break;

	/* dst = dst + src */
	case BPF_ALU | BPF_ADD | BPF_X:
	case BPF_ALU64 | BPF_ADD | BPF_X:
		emit_insn(ctx, addd, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	/* dst = dst + imm */
	case BPF_ALU | BPF_ADD | BPF_K:
	case BPF_ALU64 | BPF_ADD | BPF_K:
		if (is_signed_imm12(imm)) {
			emit_insn(ctx, addid, dst, dst, imm);
		} else {
			move_imm32(ctx, tmp, imm, is32);
			emit_insn(ctx, addd, dst, dst, tmp);
		}
		emit_zext_32(ctx, dst, is32);
		break;

	/* dst = dst - src */
	case BPF_ALU | BPF_SUB | BPF_X:
	case BPF_ALU64 | BPF_SUB | BPF_X:
		emit_insn(ctx, subd, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	/* dst = dst - imm */
	case BPF_ALU | BPF_SUB | BPF_K:
	case BPF_ALU64 | BPF_SUB | BPF_K:
		if (is_signed_imm12(-imm)) {
			emit_insn(ctx, addid, dst, dst, -imm);
		} else {
			move_imm32(ctx, tmp, imm, is32);
			emit_insn(ctx, subd, dst, dst, tmp);
		}
		emit_zext_32(ctx, dst, is32);
		break;

	/* dst = dst * src */
	case BPF_ALU | BPF_MUL | BPF_X:
	case BPF_ALU64 | BPF_MUL | BPF_X:
		emit_insn(ctx, muld, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	/* dst = dst * imm */
	case BPF_ALU | BPF_MUL | BPF_K:
	case BPF_ALU64 | BPF_MUL | BPF_K:
		move_imm32(ctx, tmp, imm, is32);
		emit_insn(ctx, muld, dst, dst, tmp);
		emit_zext_32(ctx, dst, is32);
		break;

	/* dst = dst / src */
	case BPF_ALU | BPF_DIV | BPF_X:
	case BPF_ALU64 | BPF_DIV | BPF_X:
		emit_insn(ctx, divdu, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	/* dst = dst / imm */
	case BPF_ALU | BPF_DIV | BPF_K:
	case BPF_ALU64 | BPF_DIV | BPF_K:
		move_imm32(ctx, tmp, imm, is32);
		emit_insn(ctx, divdu, dst, dst, tmp);
		emit_zext_32(ctx, dst, is32);
		break;

	/* dst = dst % src */
	case BPF_ALU | BPF_MOD | BPF_X:
	case BPF_ALU64 | BPF_MOD | BPF_X:
		emit_insn(ctx, moddu, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	/* dst = dst % imm */
	case BPF_ALU | BPF_MOD | BPF_K:
	case BPF_ALU64 | BPF_MOD | BPF_K:
		move_imm32(ctx, tmp, imm, is32);
		emit_insn(ctx, moddu, dst, dst, tmp);
		emit_zext_32(ctx, dst, is32);
		break;

	/* dst = -dst */
	case BPF_ALU | BPF_NEG:
	case BPF_ALU64 | BPF_NEG:
		move_imm32(ctx, tmp, imm, is32);
		emit_insn(ctx, subd, dst, LOONGARCH_GPR_ZERO, dst);
		emit_zext_32(ctx, dst, is32);
		break;

	/* dst = dst & src */
	case BPF_ALU | BPF_AND | BPF_X:
	case BPF_ALU64 | BPF_AND | BPF_X:
		emit_insn(ctx, and, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	/* dst = dst & imm */
	case BPF_ALU | BPF_AND | BPF_K:
	case BPF_ALU64 | BPF_AND | BPF_K:
		if (is_unsigned_imm12(imm)) {
			emit_insn(ctx, andi, dst, dst, imm);
		} else {
			move_imm32(ctx, tmp, imm, is32);
			emit_insn(ctx, and, dst, dst, tmp);
		}
		emit_zext_32(ctx, dst, is32);
		break;

	/* dst = dst | src */
	case BPF_ALU | BPF_OR | BPF_X:
	case BPF_ALU64 | BPF_OR | BPF_X:
		emit_insn(ctx, or, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	/* dst = dst | imm */
	case BPF_ALU | BPF_OR | BPF_K:
	case BPF_ALU64 | BPF_OR | BPF_K:
		if (is_unsigned_imm12(imm)) {
			emit_insn(ctx, ori, dst, dst, imm);
		} else {
			move_imm32(ctx, tmp, imm, is32);
			emit_insn(ctx, or, dst, dst, tmp);
		}
		emit_zext_32(ctx, dst, is32);
		break;

	/* dst = dst ^ src */
	case BPF_ALU | BPF_XOR | BPF_X:
	case BPF_ALU64 | BPF_XOR | BPF_X:
		emit_insn(ctx, xor, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	/* dst = dst ^ imm */
	case BPF_ALU | BPF_XOR | BPF_K:
	case BPF_ALU64 | BPF_XOR | BPF_K:
		if (is_unsigned_imm12(imm)) {
			emit_insn(ctx, xori, dst, dst, imm);
		} else {
			move_imm32(ctx, tmp, imm, is32);
			emit_insn(ctx, xor, dst, dst, tmp);
		}
		emit_zext_32(ctx, dst, is32);
		break;

	/* dst = dst << src (logical) */
	case BPF_ALU | BPF_LSH | BPF_X:
		emit_insn(ctx, sllw, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	case BPF_ALU64 | BPF_LSH | BPF_X:
		emit_insn(ctx, slld, dst, dst, src);
		break;
	/* dst = dst << imm (logical) */
	case BPF_ALU | BPF_LSH | BPF_K:
		emit_insn(ctx, slliw, dst, dst, imm);
		emit_zext_32(ctx, dst, is32);
		break;
	case BPF_ALU64 | BPF_LSH | BPF_K:
		emit_insn(ctx, sllid, dst, dst, imm);
		break;

	/* dst = dst >> src (logical) */
	case BPF_ALU | BPF_RSH | BPF_X:
		emit_insn(ctx, srlw, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	case BPF_ALU64 | BPF_RSH | BPF_X:
		emit_insn(ctx, srld, dst, dst, src);
		break;
	/* dst = dst >> imm (logical) */
	case BPF_ALU | BPF_RSH | BPF_K:
		emit_insn(ctx, srliw, dst, dst, imm);
		emit_zext_32(ctx, dst, is32);
		break;
	case BPF_ALU64 | BPF_RSH | BPF_K:
		emit_insn(ctx, srlid, dst, dst, imm);
		break;

	/* dst = dst >> src (arithmetic) */
	case BPF_ALU | BPF_ARSH | BPF_X:
		emit_insn(ctx, sraw, dst, dst, src);
		emit_zext_32(ctx, dst, is32);
		break;
	case BPF_ALU64 | BPF_ARSH | BPF_X:
		emit_insn(ctx, srad, dst, dst, src);
		break;
	/* dst = dst >> imm (arithmetic) */
	case BPF_ALU | BPF_ARSH | BPF_K:
		emit_insn(ctx, sraiw, dst, dst, imm);
		emit_zext_32(ctx, dst, is32);
		break;
	case BPF_ALU64 | BPF_ARSH | BPF_K:
		emit_insn(ctx, sraid, dst, dst, imm);
		break;

	/* dst = BSWAP##imm(dst) */
	case BPF_ALU | BPF_END | BPF_FROM_LE:
		switch (imm) {
		case 16:
			/* zero-extend 16 bits into 64 bits */
			emit_insn(ctx, sllid, dst, dst, 48);
			emit_insn(ctx, srlid, dst, dst, 48);
			break;
		case 32:
			/* zero-extend 32 bits into 64 bits */
			emit_zext_32(ctx, dst, is32);
			break;
		case 64:
			/* do nothing */
			break;
		}
		break;
	case BPF_ALU | BPF_END | BPF_FROM_BE:
		switch (imm) {
		case 16:
			emit_insn(ctx, revb2h, dst, dst);
			/* zero-extend 16 bits into 64 bits */
			emit_insn(ctx, sllid, dst, dst, 48);
			emit_insn(ctx, srlid, dst, dst, 48);
			break;
		case 32:
			emit_insn(ctx, revb2w, dst, dst);
			/* zero-extend 32 bits into 64 bits */
			emit_zext_32(ctx, dst, is32);
			break;
		case 64:
			emit_insn(ctx, revbd, dst, dst);
			break;
		}
		break;

	/* PC += off if dst cond src */
	case BPF_JMP | BPF_JEQ | BPF_X:
	case BPF_JMP | BPF_JNE | BPF_X:
	case BPF_JMP | BPF_JGT | BPF_X:
	case BPF_JMP | BPF_JGE | BPF_X:
	case BPF_JMP | BPF_JLT | BPF_X:
	case BPF_JMP | BPF_JLE | BPF_X:
	case BPF_JMP | BPF_JSGT | BPF_X:
	case BPF_JMP | BPF_JSGE | BPF_X:
	case BPF_JMP | BPF_JSLT | BPF_X:
	case BPF_JMP | BPF_JSLE | BPF_X:
		jmp_offset = bpf2la_offset(i, off, ctx);
		emit_cond_jump(ctx, cond, dst, src, jmp_offset);
		break;

	/* PC += off if dst cond imm */
	case BPF_JMP | BPF_JEQ | BPF_K:
	case BPF_JMP | BPF_JNE | BPF_K:
	case BPF_JMP | BPF_JGT | BPF_K:
	case BPF_JMP | BPF_JGE | BPF_K:
	case BPF_JMP | BPF_JLT | BPF_K:
	case BPF_JMP | BPF_JLE | BPF_K:
	case BPF_JMP | BPF_JSGT | BPF_K:
	case BPF_JMP | BPF_JSGE | BPF_K:
	case BPF_JMP | BPF_JSLT | BPF_K:
	case BPF_JMP | BPF_JSLE | BPF_K:
		jmp_offset = bpf2la_offset(i, off, ctx);
		move_imm32(ctx, tmp, imm, is32);
		emit_cond_jump(ctx, cond, dst, tmp, jmp_offset);
		break;

	/* PC += off if dst & src */
	case BPF_JMP | BPF_JSET | BPF_X:
		jmp_offset = bpf2la_offset(i, off, ctx);
		emit_insn(ctx, and, tmp, dst, src);
		emit_cond_jump(ctx, cond, tmp, LOONGARCH_GPR_ZERO, jmp_offset);
		break;
	/* PC += off if dst & imm */
	case BPF_JMP | BPF_JSET | BPF_K:
		jmp_offset = bpf2la_offset(i, off, ctx);
		move_imm32(ctx, tmp, imm, is32);
		emit_insn(ctx, and, tmp, dst, tmp);
		emit_cond_jump(ctx, cond, tmp, LOONGARCH_GPR_ZERO, jmp_offset);
		break;

	/* PC += off */
	case BPF_JMP | BPF_JA:
		jmp_offset = bpf2la_offset(i, off, ctx);
		emit_uncond_jump(ctx, jmp_offset, false);
		break;

	/* function call */
	case BPF_JMP | BPF_CALL:
		mark_call(ctx);
		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
					    &func_addr, &func_addr_fixed);
		if (ret < 0)
			return ret;

		move_imm64(ctx, tmp, func_addr, is32);
		emit_insn(ctx, jirl, LOONGARCH_GPR_RA, tmp, 0);
		move_reg(ctx, regmap[BPF_REG_0], LOONGARCH_GPR_A0);
		break;

	/* tail call */
	case BPF_JMP | BPF_TAIL_CALL:
		mark_tail_call(ctx);
		if (emit_bpf_tail_call(ctx))
			return -EINVAL;
		break;

	/* function return */
	case BPF_JMP | BPF_EXIT:
		emit_sext_32(ctx, regmap[BPF_REG_0]);
		/*
		 * Optimization: when the last instruction is EXIT,
		 * simply fall through to the epilogue.
		 */
		if (i == ctx->prog->len - 1)
			break;

		jmp_offset = epilogue_offset(ctx);
		emit_uncond_jump(ctx, jmp_offset, true);
		break;

	/* dst = imm64 */
	case BPF_LD | BPF_IMM | BPF_DW:
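		/*
		 * A 64-bit immediate load occupies two BPF instructions:
		 * the low 32 bits live in this insn, the high 32 bits in
		 * the next one. Returning 1 tells build_body() to skip
		 * that second slot.
		 */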
		imm64 = (u64)(insn + 1)->imm << 32 | (u32)insn->imm;
		move_imm64(ctx, dst, imm64, is32);
		return 1;

	/* dst = *(size *)(src + off) */
	case BPF_LDX | BPF_MEM | BPF_B:
	case BPF_LDX | BPF_MEM | BPF_H:
	case BPF_LDX | BPF_MEM | BPF_W:
	case BPF_LDX | BPF_MEM | BPF_DW:
		if (is_signed_imm12(off)) {
			switch (BPF_SIZE(code)) {
			case BPF_B:
				emit_insn(ctx, ldbu, dst, src, off);
				break;
			case BPF_H:
				emit_insn(ctx, ldhu, dst, src, off);
				break;
			case BPF_W:
				emit_insn(ctx, ldwu, dst, src, off);
				break;
			case BPF_DW:
				emit_insn(ctx, ldd, dst, src, off);
				break;
			}
		} else {
			move_imm32(ctx, tmp, off, is32);
			switch (BPF_SIZE(code)) {
			case BPF_B:
				emit_insn(ctx, ldxbu, dst, src, tmp);
				break;
			case BPF_H:
				emit_insn(ctx, ldxhu, dst, src, tmp);
				break;
			case BPF_W:
				emit_insn(ctx, ldxwu, dst, src, tmp);
				break;
			case BPF_DW:
				emit_insn(ctx, ldxd, dst, src, tmp);
				break;
			}
		}
		break;

	/* *(size *)(dst + off) = imm */
	case BPF_ST | BPF_MEM | BPF_B:
	case BPF_ST | BPF_MEM | BPF_H:
	case BPF_ST | BPF_MEM | BPF_W:
	case BPF_ST | BPF_MEM | BPF_DW:
		move_imm32(ctx, tmp, imm, is32);
		if (is_signed_imm12(off)) {
			switch (BPF_SIZE(code)) {
			case BPF_B:
				emit_insn(ctx, stb, tmp, dst, off);
				break;
			case BPF_H:
				emit_insn(ctx, sth, tmp, dst, off);
				break;
			case BPF_W:
				emit_insn(ctx, stw, tmp, dst, off);
				break;
			case BPF_DW:
				emit_insn(ctx, std, tmp, dst, off);
				break;
			}
		} else {
			move_imm32(ctx, tmp2, off, is32);
			switch (BPF_SIZE(code)) {
			case BPF_B:
				emit_insn(ctx, stxb, tmp, dst, tmp2);
				break;
			case BPF_H:
				emit_insn(ctx, stxh, tmp, dst, tmp2);
				break;
			case BPF_W:
				emit_insn(ctx, stxw, tmp, dst, tmp2);
				break;
			case BPF_DW:
				emit_insn(ctx, stxd, tmp, dst, tmp2);
				break;
			}
		}
		break;

	/* *(size *)(dst + off) = src */
	case BPF_STX | BPF_MEM | BPF_B:
	case BPF_STX | BPF_MEM | BPF_H:
	case BPF_STX | BPF_MEM | BPF_W:
	case BPF_STX | BPF_MEM | BPF_DW:
		if (is_signed_imm12(off)) {
			switch (BPF_SIZE(code)) {
			case BPF_B:
				emit_insn(ctx, stb, src, dst, off);
				break;
			case BPF_H:
				emit_insn(ctx, sth, src, dst, off);
				break;
			case BPF_W:
				emit_insn(ctx, stw, src, dst, off);
				break;
			case BPF_DW:
				emit_insn(ctx, std, src, dst, off);
				break;
			}
		} else {
			move_imm32(ctx, tmp, off, is32);
			switch (BPF_SIZE(code)) {
			case BPF_B:
				emit_insn(ctx, stxb, src, dst, tmp);
				break;
			case BPF_H:
				emit_insn(ctx, stxh, src, dst, tmp);
				break;
			case BPF_W:
				emit_insn(ctx, stxw, src, dst, tmp);
				break;
			case BPF_DW:
				emit_insn(ctx, stxd, src, dst, tmp);
				break;
			}
		}
		break;

	/* atomic_add: lock *(size *)(dst + off) += src */
	case BPF_STX | BPF_XADD | BPF_W:
	case BPF_STX | BPF_XADD | BPF_DW:
		if (insn->imm != BPF_ADD) {
			pr_err_once("unknown atomic op code %02x\n", insn->imm);
			return -EINVAL;
		}

		move_imm32(ctx, tmp, off, is32);
		emit_insn(ctx, addd, tmp, dst, tmp);
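		/*
		 * AMADD.{W,D} atomically adds src to the memory location
		 * addressed by tmp; the old value is returned in tmp2 and
		 * discarded, since BPF_XADD has no fetch semantics.
		 */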
		switch (BPF_SIZE(insn->code)) {
		case BPF_W:
			emit_insn(ctx, amaddw, tmp2, src, tmp);
			break;
		case BPF_DW:
			emit_insn(ctx, amaddd, tmp2, src, tmp);
			break;
		}
		break;

	default:
		pr_err("bpf_jit: unknown opcode %02x\n", code);
		return -EINVAL;
	}

	return 0;
}

static int build_body(struct jit_ctx *ctx, bool extra_pass)
{
	const struct bpf_prog *prog = ctx->prog;
	int i;

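	/*
	 * ctx->offset[i] records the index of the first JITed instruction
	 * for BPF insn i; it is filled in on the first (image-less) pass
	 * and later used when branch targets are resolved (see
	 * bpf2la_offset()).
	 */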
	for (i = 0; i < prog->len; i++) {
		const struct bpf_insn *insn = &prog->insnsi[i];
		int ret;

		if (!ctx->image)
			ctx->offset[i] = ctx->idx;

		ret = build_insn(insn, ctx, extra_pass);
		if (ret > 0) {
			i++;
			if (!ctx->image)
				ctx->offset[i] = ctx->idx;
			continue;
		}
		if (ret)
			return ret;
	}

	if (!ctx->image)
		ctx->offset[i] = ctx->idx;

	return 0;
}

static inline void bpf_flush_icache(void *start, void *end)
{
	flush_icache_range((unsigned long)start, (unsigned long)end);
}

/* Fill space with illegal instructions */
static void jit_fill_hole(void *area, unsigned int size)
{
	u32 *ptr;

	/* We are guaranteed to have aligned memory */
	for (ptr = area; size >= sizeof(u32); size -= sizeof(u32))
		*ptr++ = INSN_BREAK;
}

static int validate_code(struct jit_ctx *ctx)
{
	int i;
	union loongarch_instruction insn;

	for (i = 0; i < ctx->idx; i++) {
		insn = ctx->image[i];
		/* Check INSN_BREAK */
		if (insn.word == INSN_BREAK)
			return -1;
	}

	return 0;
}

struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
{
	struct bpf_prog *tmp, *orig_prog = prog;
	struct bpf_binary_header *header;
	struct jit_data *jit_data;
	struct jit_ctx ctx;
	bool tmp_blinded = false;
	bool extra_pass = false;
	int image_size;
	u8 *image_ptr;

	/*
	 * If BPF JIT was not enabled then we must fall back to
	 * the interpreter.
	 */
	if (!prog->jit_requested)
		return orig_prog;

	tmp = bpf_jit_blind_constants(prog);
	/*
	 * If blinding was requested and we failed during blinding,
	 * we must fall back to the interpreter. Otherwise, we save
	 * the new JITed code.
	 */
	if (IS_ERR(tmp))
		return orig_prog;
	if (tmp != prog) {
		tmp_blinded = true;
		prog = tmp;
	}

	jit_data = prog->aux->jit_data;
	if (!jit_data) {
		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
		if (!jit_data) {
			prog = orig_prog;
			goto out;
		}
		prog->aux->jit_data = jit_data;
	}
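	/*
	 * For a multi-function program the JIT is invoked again once all
	 * subprog addresses are known: reuse the context saved by the
	 * first invocation and go straight to code generation.
	 */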
	if (jit_data->ctx.offset) {
		ctx = jit_data->ctx;
		image_ptr = jit_data->image;
		header = jit_data->header;
		extra_pass = true;
		image_size = sizeof(u32) * ctx.idx;
		goto skip_init_ctx;
	}

	memset(&ctx, 0, sizeof(ctx));
	ctx.prog = prog;

	ctx.offset = kcalloc(prog->len + 1, sizeof(*ctx.offset), GFP_KERNEL);
	if (!ctx.offset) {
		prog = orig_prog;
		goto out_off;
	}

	/* 1. Initial fake pass to compute ctx->idx and set ctx->flags */
	if (build_body(&ctx, extra_pass)) {
		prog = orig_prog;
		goto out_off;
	}
	build_prologue(&ctx);
	ctx.epilogue_offset = ctx.idx;
	build_epilogue(&ctx);

	/*
	 * Now we know the actual image size.
	 * As each LoongArch instruction is 32 bits wide, we translate the
	 * number of JITed instructions into the size required to store
	 * the JITed code.
	 */
	image_size = sizeof(u32) * ctx.idx;
	/* Allocate the JIT binary header and image area */
	header = bpf_jit_binary_alloc(image_size, &image_ptr,
				      sizeof(u32), jit_fill_hole);
	if (!header) {
		prog = orig_prog;
		goto out_off;
	}

	/* 2. Now, the actual pass to generate final JIT code */
	ctx.image = (union loongarch_instruction *)image_ptr;
skip_init_ctx:
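	/* Restart instruction counting; ctx.image is set, so code is now written out */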
	ctx.idx = 0;

	build_prologue(&ctx);
	if (build_body(&ctx, extra_pass)) {
		bpf_jit_binary_free(header);
		prog = orig_prog;
		goto out_off;
	}
	build_epilogue(&ctx);

	/* 3. Extra pass to validate JITed code */
	if (validate_code(&ctx)) {
		bpf_jit_binary_free(header);
		prog = orig_prog;
		goto out_off;
	}

	/* And we're done */
	if (bpf_jit_enable > 1)
		bpf_jit_dump(prog->len, image_size, 2, ctx.image);

	/* Update the icache */
	bpf_flush_icache(header, ctx.image + ctx.idx);

	if (!prog->is_func || extra_pass) {
		if (extra_pass && ctx.idx != jit_data->ctx.idx) {
			pr_err_once("multi-func JIT bug %d != %d\n",
				    ctx.idx, jit_data->ctx.idx);
			bpf_jit_binary_free(header);
			prog->bpf_func = NULL;
			prog->jited = 0;
			goto out_off;
		}
		bpf_jit_binary_lock_ro(header);
	} else {
		jit_data->ctx = ctx;
		jit_data->image = image_ptr;
		jit_data->header = header;
	}
	prog->bpf_func = (void *)ctx.image;
	prog->jited = 1;
	prog->jited_len = image_size;

	if (!prog->is_func || extra_pass) {
out_off:
		kfree(ctx.offset);
		kfree(jit_data);
		prog->aux->jit_data = NULL;
	}
out:
	if (tmp_blinded)
		bpf_jit_prog_release_other(prog, prog == orig_prog ?
					   tmp : orig_prog);

	out_offset = -1;
	return prog;
}