// SPDX-License-Identifier: GPL-2.0
/* BPF JIT compiler for RV64G
 *
 * Copyright(c) 2019 Björn Töpel <bjorn.topel@gmail.com>
 *
 */

#include <linux/bpf.h>
#include <linux/filter.h>
#include "bpf_jit.h"

#define RV_REG_TCC RV_REG_A6
#define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program does calls */

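/* BPF to RISC-V register mapping. R0 lives in a5 and is moved to a0 on
 * exit, R1-R5 map to the argument registers a0-a4, R6-R9 and the frame
 * pointer map to callee-saved s-registers, and the verifier's scratch
 * register AX maps to t0.
 */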
static const int regmap[] = {
	[BPF_REG_0] =	RV_REG_A5,
	[BPF_REG_1] =	RV_REG_A0,
	[BPF_REG_2] =	RV_REG_A1,
	[BPF_REG_3] =	RV_REG_A2,
	[BPF_REG_4] =	RV_REG_A3,
	[BPF_REG_5] =	RV_REG_A4,
	[BPF_REG_6] =	RV_REG_S1,
	[BPF_REG_7] =	RV_REG_S2,
	[BPF_REG_8] =	RV_REG_S3,
	[BPF_REG_9] =	RV_REG_S4,
	[BPF_REG_FP] =	RV_REG_S5,
	[BPF_REG_AX] =	RV_REG_T0,
};

enum {
	RV_CTX_F_SEEN_TAIL_CALL =	0,
	RV_CTX_F_SEEN_CALL =		RV_REG_RA,
	RV_CTX_F_SEEN_S1 =		RV_REG_S1,
	RV_CTX_F_SEEN_S2 =		RV_REG_S2,
	RV_CTX_F_SEEN_S3 =		RV_REG_S3,
	RV_CTX_F_SEEN_S4 =		RV_REG_S4,
	RV_CTX_F_SEEN_S5 =		RV_REG_S5,
	RV_CTX_F_SEEN_S6 =		RV_REG_S6,
};

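/* Return the RISC-V register backing a BPF register. Uses of callee-saved
 * s-registers are recorded in ctx->flags so that the prologue and epilogue
 * only save and restore the registers the program actually touches.
 */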
static u8 bpf_to_rv_reg(int bpf_reg, struct rv_jit_context *ctx)
{
	u8 reg = regmap[bpf_reg];

	switch (reg) {
	case RV_CTX_F_SEEN_S1:
	case RV_CTX_F_SEEN_S2:
	case RV_CTX_F_SEEN_S3:
	case RV_CTX_F_SEEN_S4:
	case RV_CTX_F_SEEN_S5:
	case RV_CTX_F_SEEN_S6:
		__set_bit(reg, &ctx->flags);
	}
	return reg;
}

static bool seen_reg(int reg, struct rv_jit_context *ctx)
{
	switch (reg) {
	case RV_CTX_F_SEEN_CALL:
	case RV_CTX_F_SEEN_S1:
	case RV_CTX_F_SEEN_S2:
	case RV_CTX_F_SEEN_S3:
	case RV_CTX_F_SEEN_S4:
	case RV_CTX_F_SEEN_S5:
	case RV_CTX_F_SEEN_S6:
		return test_bit(reg, &ctx->flags);
	}
	return false;
}

static void mark_fp(struct rv_jit_context *ctx)
{
	__set_bit(RV_CTX_F_SEEN_S5, &ctx->flags);
}

static void mark_call(struct rv_jit_context *ctx)
{
	__set_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
}

static bool seen_call(struct rv_jit_context *ctx)
{
	return test_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
}

static void mark_tail_call(struct rv_jit_context *ctx)
{
	__set_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
}

static bool seen_tail_call(struct rv_jit_context *ctx)
{
	return test_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
}

static u8 rv_tail_call_reg(struct rv_jit_context *ctx)
{
	mark_tail_call(ctx);

	if (seen_call(ctx)) {
		__set_bit(RV_CTX_F_SEEN_S6, &ctx->flags);
		return RV_REG_S6;
	}
	return RV_REG_A6;
}

static bool is_32b_int(s64 val)
{
	return -(1L << 31) <= val && val < (1L << 31);
}

static bool in_auipc_jalr_range(s64 val)
{
	/*
	 * auipc+jalr can reach any signed PC-relative offset in the range
	 * [-2^31 - 2^11, 2^31 - 2^11).
	 */
	return (-(1L << 31) - (1L << 11)) <= val &&
		val < ((1L << 31) - (1L << 11));
}

static void emit_imm(u8 rd, s64 val, struct rv_jit_context *ctx)
{
	/* Note that the immediate from the add is sign-extended,
	 * which means that we need to compensate this by adding 2^12,
	 * when the 12th bit is set. A simpler way of doing this, and
	 * getting rid of the check, is to just add 2^11 before the
	 * shift. The "Loading a 32-Bit constant" example from the
	 * "Computer Organization and Design, RISC-V edition" book by
	 * Patterson/Hennessy highlights this fact.
	 *
	 * This also means that we need to process LSB to MSB.
	 */
	s64 upper = (val + (1 << 11)) >> 12;
	/* Sign-extend lower 12 bits to 64 bits since immediates for li, addiw,
	 * and addi are signed and RVC checks will perform signed comparisons.
	 */
	s64 lower = ((val & 0xfff) << 52) >> 52;
	int shift;

	if (is_32b_int(val)) {
		if (upper)
			emit_lui(rd, upper, ctx);

		if (!upper) {
			emit_li(rd, lower, ctx);
			return;
		}

		emit_addiw(rd, rd, lower, ctx);
		return;
	}

	shift = __ffs(upper);
	upper >>= shift;
	shift += 12;

	emit_imm(rd, upper, ctx);

	emit_slli(rd, rd, shift, ctx);
	if (lower)
		emit_addi(rd, rd, lower, ctx);
}

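/* Restore the registers spilled by the prologue, pop the stack frame and
 * leave the program. A normal exit moves the return value from a5 to a0
 * and returns via ra; a tail call jumps to the target program held in t3,
 * skipping over its 4-byte TCC initialization.
 */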
static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
{
	int stack_adjust = ctx->stack_size, store_offset = stack_adjust - 8;

	if (seen_reg(RV_REG_RA, ctx)) {
		emit_ld(RV_REG_RA, store_offset, RV_REG_SP, ctx);
		store_offset -= 8;
	}
	emit_ld(RV_REG_FP, store_offset, RV_REG_SP, ctx);
	store_offset -= 8;
	if (seen_reg(RV_REG_S1, ctx)) {
		emit_ld(RV_REG_S1, store_offset, RV_REG_SP, ctx);
		store_offset -= 8;
	}
	if (seen_reg(RV_REG_S2, ctx)) {
		emit_ld(RV_REG_S2, store_offset, RV_REG_SP, ctx);
		store_offset -= 8;
	}
	if (seen_reg(RV_REG_S3, ctx)) {
		emit_ld(RV_REG_S3, store_offset, RV_REG_SP, ctx);
		store_offset -= 8;
	}
	if (seen_reg(RV_REG_S4, ctx)) {
		emit_ld(RV_REG_S4, store_offset, RV_REG_SP, ctx);
		store_offset -= 8;
	}
	if (seen_reg(RV_REG_S5, ctx)) {
		emit_ld(RV_REG_S5, store_offset, RV_REG_SP, ctx);
		store_offset -= 8;
	}
	if (seen_reg(RV_REG_S6, ctx)) {
		emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
		store_offset -= 8;
	}

	emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
	/* Set return value. */
	if (!is_tail_call)
		emit_addiw(RV_REG_A0, RV_REG_A5, 0, ctx);
	emit_jalr(RV_REG_ZERO, is_tail_call ? RV_REG_T3 : RV_REG_RA,
		  is_tail_call ? 4 : 0, /* skip TCC init */
		  ctx);
}

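/* Emit a single conditional branch for a BPF condition. The rv_b*()
 * encodings take the offset pre-shifted by one, i.e. in units of two
 * bytes, hence rvoff >> 1. BPF_JSET has no direct RISC-V counterpart and
 * is handled by the callers.
 */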
static void emit_bcc(u8 cond, u8 rd, u8 rs, int rvoff,
		     struct rv_jit_context *ctx)
{
	switch (cond) {
	case BPF_JEQ:
		emit(rv_beq(rd, rs, rvoff >> 1), ctx);
		return;
	case BPF_JGT:
		emit(rv_bltu(rs, rd, rvoff >> 1), ctx);
		return;
	case BPF_JLT:
		emit(rv_bltu(rd, rs, rvoff >> 1), ctx);
		return;
	case BPF_JGE:
		emit(rv_bgeu(rd, rs, rvoff >> 1), ctx);
		return;
	case BPF_JLE:
		emit(rv_bgeu(rs, rd, rvoff >> 1), ctx);
		return;
	case BPF_JNE:
		emit(rv_bne(rd, rs, rvoff >> 1), ctx);
		return;
	case BPF_JSGT:
		emit(rv_blt(rs, rd, rvoff >> 1), ctx);
		return;
	case BPF_JSLT:
		emit(rv_blt(rd, rs, rvoff >> 1), ctx);
		return;
	case BPF_JSGE:
		emit(rv_bge(rd, rs, rvoff >> 1), ctx);
		return;
	case BPF_JSLE:
		emit(rv_bge(rs, rd, rvoff >> 1), ctx);
	}
}

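/* Emit a conditional branch that can reach any BPF jump target: a plain
 * branch when the offset fits in 13 bits, otherwise invert the condition
 * and branch over a jal (21-bit range) or an auipc+jalr pair (32-bit
 * range).
 */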
static void emit_branch(u8 cond, u8 rd, u8 rs, int rvoff,
			struct rv_jit_context *ctx)
{
	s64 upper, lower;

	if (is_13b_int(rvoff)) {
		emit_bcc(cond, rd, rs, rvoff, ctx);
		return;
	}

	/* Adjust for jal */
	rvoff -= 4;

	/* Transform, e.g.:
	 *   bne rd,rs,foo
	 * to
	 *   beq rd,rs,<.L1>
	 *   (auipc foo)
	 *   jal(r) foo
	 * .L1
	 */
	cond = invert_bpf_cond(cond);
	if (is_21b_int(rvoff)) {
		emit_bcc(cond, rd, rs, 8, ctx);
		emit(rv_jal(RV_REG_ZERO, rvoff >> 1), ctx);
		return;
	}

	/* 32b No need for an additional rvoff adjustment, since we
	 * get that from the auipc at PC', where PC = PC' + 4.
	 */
	upper = (rvoff + (1 << 11)) >> 12;
	lower = rvoff & 0xfff;

	emit_bcc(cond, rd, rs, 12, ctx);
	emit(rv_auipc(RV_REG_T1, upper), ctx);
	emit(rv_jalr(RV_REG_ZERO, RV_REG_T1, lower), ctx);
}

static void emit_zext_32(u8 reg, struct rv_jit_context *ctx)
{
	emit_slli(reg, reg, 32, ctx);
	emit_srli(reg, reg, 32, ctx);
}

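/* Emit the tail-call sequence: bounds-check the index, enforce the
 * tail-call count limit, load the target program and jump into it past
 * its TCC setup. Returns -1 if a structure offset does not fit in a
 * 12-bit immediate.
 */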
static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
{
	int tc_ninsn, off, start_insn = ctx->ninsns;
	u8 tcc = rv_tail_call_reg(ctx);

	/* a0: &ctx
	 * a1: &array
	 * a2: index
	 *
	 * if (index >= array->map.max_entries)
	 *	goto out;
	 */
	tc_ninsn = insn ? ctx->offset[insn] - ctx->offset[insn - 1] :
		   ctx->offset[0];
	emit_zext_32(RV_REG_A2, ctx);

	off = offsetof(struct bpf_array, map.max_entries);
	if (is_12b_check(off, insn))
		return -1;
	emit(rv_lwu(RV_REG_T1, off, RV_REG_A1), ctx);
	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
	emit_branch(BPF_JGE, RV_REG_A2, RV_REG_T1, off, ctx);

	/* if (TCC-- < 0)
	 *     goto out;
	 */
	emit_addi(RV_REG_T1, tcc, -1, ctx);
	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
	emit_branch(BPF_JSLT, tcc, RV_REG_ZERO, off, ctx);

	/* prog = array->ptrs[index];
	 * if (!prog)
	 *     goto out;
	 */
	emit_slli(RV_REG_T2, RV_REG_A2, 3, ctx);
	emit_add(RV_REG_T2, RV_REG_T2, RV_REG_A1, ctx);
	off = offsetof(struct bpf_array, ptrs);
	if (is_12b_check(off, insn))
		return -1;
	emit_ld(RV_REG_T2, off, RV_REG_T2, ctx);
	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
	emit_branch(BPF_JEQ, RV_REG_T2, RV_REG_ZERO, off, ctx);

	/* goto *(prog->bpf_func + 4); */
	off = offsetof(struct bpf_prog, bpf_func);
	if (is_12b_check(off, insn))
		return -1;
	emit_ld(RV_REG_T3, off, RV_REG_T2, ctx);
	emit_mv(RV_REG_TCC, RV_REG_T1, ctx);
	__build_epilogue(true, ctx);
	return 0;
}

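/* Resolve the destination and source BPF registers of an instruction to
 * RISC-V registers. Instructions without a destination (JA, CALL, EXIT,
 * TAIL_CALL) skip rd, and rs is only resolved for instructions that use a
 * source register.
 */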
static void init_regs(u8 *rd, u8 *rs, const struct bpf_insn *insn,
		      struct rv_jit_context *ctx)
{
	u8 code = insn->code;

	switch (code) {
	case BPF_JMP | BPF_JA:
	case BPF_JMP | BPF_CALL:
	case BPF_JMP | BPF_EXIT:
	case BPF_JMP | BPF_TAIL_CALL:
		break;
	default:
		*rd = bpf_to_rv_reg(insn->dst_reg, ctx);
	}

	if (code & (BPF_ALU | BPF_X) || code & (BPF_ALU64 | BPF_X) ||
	    code & (BPF_JMP | BPF_X) || code & (BPF_JMP32 | BPF_X) ||
	    code & BPF_LDX || code & BPF_STX)
		*rs = bpf_to_rv_reg(insn->src_reg, ctx);
}

static void emit_zext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
{
	emit_mv(RV_REG_T2, *rd, ctx);
	emit_zext_32(RV_REG_T2, ctx);
	emit_mv(RV_REG_T1, *rs, ctx);
	emit_zext_32(RV_REG_T1, ctx);
	*rd = RV_REG_T2;
	*rs = RV_REG_T1;
}

static void emit_sext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
{
	emit_addiw(RV_REG_T2, *rd, 0, ctx);
	emit_addiw(RV_REG_T1, *rs, 0, ctx);
	*rd = RV_REG_T2;
	*rs = RV_REG_T1;
}

static void emit_zext_32_rd_t1(u8 *rd, struct rv_jit_context *ctx)
{
	emit_mv(RV_REG_T2, *rd, ctx);
	emit_zext_32(RV_REG_T2, ctx);
	emit_zext_32(RV_REG_T1, ctx);
	*rd = RV_REG_T2;
}

static void emit_sext_32_rd(u8 *rd, struct rv_jit_context *ctx)
{
	emit_addiw(RV_REG_T2, *rd, 0, ctx);
	*rd = RV_REG_T2;
}

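/* Emit an unconditional jump-and-link. A fixed, non-zero target within
 * the 21-bit range uses a single jal; anything else within the
 * auipc+jalr range uses that pair via t1. Offsets outside both ranges
 * fail with -ERANGE.
 */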
static int emit_jump_and_link(u8 rd, s64 rvoff, bool fixed_addr,
			      struct rv_jit_context *ctx)
{
	s64 upper, lower;

	if (rvoff && fixed_addr && is_21b_int(rvoff)) {
		emit(rv_jal(rd, rvoff >> 1), ctx);
		return 0;
	} else if (in_auipc_jalr_range(rvoff)) {
		upper = (rvoff + (1 << 11)) >> 12;
		lower = rvoff & 0xfff;
		emit(rv_auipc(RV_REG_T1, upper), ctx);
		emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
		return 0;
	}

	pr_err("bpf-jit: target offset 0x%llx is out of range\n", rvoff);
	return -ERANGE;
}

static bool is_signed_bpf_cond(u8 cond)
{
	return cond == BPF_JSGT || cond == BPF_JSLT ||
		cond == BPF_JSGE || cond == BPF_JSLE;
}

static int emit_call(u64 addr, bool fixed_addr, struct rv_jit_context *ctx)
{
	s64 off = 0;
	u64 ip;

	if (addr && ctx->insns) {
		ip = (u64)(long)(ctx->insns + ctx->ninsns);
		off = addr - ip;
	}

	return emit_jump_and_link(RV_REG_RA, off, fixed_addr, ctx);
}

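/* Translate a single BPF instruction into RISC-V instructions. Returns a
 * negative error on failure, 1 if the following BPF instruction was also
 * consumed (64-bit immediate loads, and loads whose zero-extension is
 * already implied), and 0 otherwise.
 */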
int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
		      bool extra_pass)
{
	bool is64 = BPF_CLASS(insn->code) == BPF_ALU64 ||
		    BPF_CLASS(insn->code) == BPF_JMP;
	int s, e, rvoff, ret, i = insn - ctx->prog->insnsi;
	struct bpf_prog_aux *aux = ctx->prog->aux;
	u8 rd = -1, rs = -1, code = insn->code;
	s16 off = insn->off;
	s32 imm = insn->imm;

	init_regs(&rd, &rs, insn, ctx);

	switch (code) {
	/* dst = src */
	case BPF_ALU | BPF_MOV | BPF_X:
	case BPF_ALU64 | BPF_MOV | BPF_X:
		if (imm == 1) {
			/* Special mov32 for zext */
			emit_zext_32(rd, ctx);
			break;
		}
		emit_mv(rd, rs, ctx);
		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;

	/* dst = dst OP src */
	case BPF_ALU | BPF_ADD | BPF_X:
	case BPF_ALU64 | BPF_ADD | BPF_X:
		emit_add(rd, rd, rs, ctx);
		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;
	case BPF_ALU | BPF_SUB | BPF_X:
	case BPF_ALU64 | BPF_SUB | BPF_X:
		if (is64)
			emit_sub(rd, rd, rs, ctx);
		else
			emit_subw(rd, rd, rs, ctx);

		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;
	case BPF_ALU | BPF_AND | BPF_X:
	case BPF_ALU64 | BPF_AND | BPF_X:
		emit_and(rd, rd, rs, ctx);
		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;
	case BPF_ALU | BPF_OR | BPF_X:
	case BPF_ALU64 | BPF_OR | BPF_X:
		emit_or(rd, rd, rs, ctx);
		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;
	case BPF_ALU | BPF_XOR | BPF_X:
	case BPF_ALU64 | BPF_XOR | BPF_X:
		emit_xor(rd, rd, rs, ctx);
		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;
	case BPF_ALU | BPF_MUL | BPF_X:
	case BPF_ALU64 | BPF_MUL | BPF_X:
		emit(is64 ? rv_mul(rd, rd, rs) : rv_mulw(rd, rd, rs), ctx);
		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;
	case BPF_ALU | BPF_DIV | BPF_X:
	case BPF_ALU64 | BPF_DIV | BPF_X:
		emit(is64 ? rv_divu(rd, rd, rs) : rv_divuw(rd, rd, rs), ctx);
		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;
	case BPF_ALU | BPF_MOD | BPF_X:
	case BPF_ALU64 | BPF_MOD | BPF_X:
		emit(is64 ? rv_remu(rd, rd, rs) : rv_remuw(rd, rd, rs), ctx);
		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;
	case BPF_ALU | BPF_LSH | BPF_X:
	case BPF_ALU64 | BPF_LSH | BPF_X:
		emit(is64 ? rv_sll(rd, rd, rs) : rv_sllw(rd, rd, rs), ctx);
		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;
	case BPF_ALU | BPF_RSH | BPF_X:
	case BPF_ALU64 | BPF_RSH | BPF_X:
		emit(is64 ? rv_srl(rd, rd, rs) : rv_srlw(rd, rd, rs), ctx);
		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;
	case BPF_ALU | BPF_ARSH | BPF_X:
	case BPF_ALU64 | BPF_ARSH | BPF_X:
		emit(is64 ? rv_sra(rd, rd, rs) : rv_sraw(rd, rd, rs), ctx);
		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;

	/* dst = -dst */
	case BPF_ALU | BPF_NEG:
	case BPF_ALU64 | BPF_NEG:
		emit_sub(rd, RV_REG_ZERO, rd, ctx);
		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;

	/* dst = BSWAP##imm(dst) */
	case BPF_ALU | BPF_END | BPF_FROM_LE:
		switch (imm) {
		case 16:
			emit_slli(rd, rd, 48, ctx);
			emit_srli(rd, rd, 48, ctx);
			break;
		case 32:
			if (!aux->verifier_zext)
				emit_zext_32(rd, ctx);
			break;
		case 64:
			/* Do nothing */
			break;
		}
		break;

	case BPF_ALU | BPF_END | BPF_FROM_BE:
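		/* There is no byte-swap instruction in RV64G, so the swap is
		 * open coded: bytes are peeled off the low end of rd one at a
		 * time and accumulated into t2, which is shifted up between
		 * bytes.
		 */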
		emit_li(RV_REG_T2, 0, ctx);

		emit_andi(RV_REG_T1, rd, 0xff, ctx);
		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
		emit_srli(rd, rd, 8, ctx);
		if (imm == 16)
			goto out_be;

		emit_andi(RV_REG_T1, rd, 0xff, ctx);
		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
		emit_srli(rd, rd, 8, ctx);

		emit_andi(RV_REG_T1, rd, 0xff, ctx);
		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
		emit_srli(rd, rd, 8, ctx);
		if (imm == 32)
			goto out_be;

		emit_andi(RV_REG_T1, rd, 0xff, ctx);
		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
		emit_srli(rd, rd, 8, ctx);

		emit_andi(RV_REG_T1, rd, 0xff, ctx);
		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
		emit_srli(rd, rd, 8, ctx);

		emit_andi(RV_REG_T1, rd, 0xff, ctx);
		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
		emit_srli(rd, rd, 8, ctx);

		emit_andi(RV_REG_T1, rd, 0xff, ctx);
		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
		emit_srli(rd, rd, 8, ctx);
out_be:
		emit_andi(RV_REG_T1, rd, 0xff, ctx);
		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);

		emit_mv(rd, RV_REG_T2, ctx);
		break;

	/* dst = imm */
	case BPF_ALU | BPF_MOV | BPF_K:
	case BPF_ALU64 | BPF_MOV | BPF_K:
		emit_imm(rd, imm, ctx);
		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;

	/* dst = dst OP imm */
	case BPF_ALU | BPF_ADD | BPF_K:
	case BPF_ALU64 | BPF_ADD | BPF_K:
		if (is_12b_int(imm)) {
			emit_addi(rd, rd, imm, ctx);
		} else {
			emit_imm(RV_REG_T1, imm, ctx);
			emit_add(rd, rd, RV_REG_T1, ctx);
		}
		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;
	case BPF_ALU | BPF_SUB | BPF_K:
	case BPF_ALU64 | BPF_SUB | BPF_K:
		if (is_12b_int(-imm)) {
			emit_addi(rd, rd, -imm, ctx);
		} else {
			emit_imm(RV_REG_T1, imm, ctx);
			emit_sub(rd, rd, RV_REG_T1, ctx);
		}
		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;
	case BPF_ALU | BPF_AND | BPF_K:
	case BPF_ALU64 | BPF_AND | BPF_K:
		if (is_12b_int(imm)) {
			emit_andi(rd, rd, imm, ctx);
		} else {
			emit_imm(RV_REG_T1, imm, ctx);
			emit_and(rd, rd, RV_REG_T1, ctx);
		}
		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;
	case BPF_ALU | BPF_OR | BPF_K:
	case BPF_ALU64 | BPF_OR | BPF_K:
		if (is_12b_int(imm)) {
			emit(rv_ori(rd, rd, imm), ctx);
		} else {
			emit_imm(RV_REG_T1, imm, ctx);
			emit_or(rd, rd, RV_REG_T1, ctx);
		}
		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;
	case BPF_ALU | BPF_XOR | BPF_K:
	case BPF_ALU64 | BPF_XOR | BPF_K:
		if (is_12b_int(imm)) {
			emit(rv_xori(rd, rd, imm), ctx);
		} else {
			emit_imm(RV_REG_T1, imm, ctx);
			emit_xor(rd, rd, RV_REG_T1, ctx);
		}
		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;
	case BPF_ALU | BPF_MUL | BPF_K:
	case BPF_ALU64 | BPF_MUL | BPF_K:
		emit_imm(RV_REG_T1, imm, ctx);
		emit(is64 ? rv_mul(rd, rd, RV_REG_T1) :
		     rv_mulw(rd, rd, RV_REG_T1), ctx);
		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;
	case BPF_ALU | BPF_DIV | BPF_K:
	case BPF_ALU64 | BPF_DIV | BPF_K:
		emit_imm(RV_REG_T1, imm, ctx);
		emit(is64 ? rv_divu(rd, rd, RV_REG_T1) :
		     rv_divuw(rd, rd, RV_REG_T1), ctx);
		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;
	case BPF_ALU | BPF_MOD | BPF_K:
	case BPF_ALU64 | BPF_MOD | BPF_K:
		emit_imm(RV_REG_T1, imm, ctx);
		emit(is64 ? rv_remu(rd, rd, RV_REG_T1) :
		     rv_remuw(rd, rd, RV_REG_T1), ctx);
		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;
	case BPF_ALU | BPF_LSH | BPF_K:
	case BPF_ALU64 | BPF_LSH | BPF_K:
		emit_slli(rd, rd, imm, ctx);

		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;
	case BPF_ALU | BPF_RSH | BPF_K:
	case BPF_ALU64 | BPF_RSH | BPF_K:
		if (is64)
			emit_srli(rd, rd, imm, ctx);
		else
			emit(rv_srliw(rd, rd, imm), ctx);

		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;
	case BPF_ALU | BPF_ARSH | BPF_K:
	case BPF_ALU64 | BPF_ARSH | BPF_K:
		if (is64)
			emit_srai(rd, rd, imm, ctx);
		else
			emit(rv_sraiw(rd, rd, imm), ctx);

		if (!is64 && !aux->verifier_zext)
			emit_zext_32(rd, ctx);
		break;

	/* JUMP off */
	case BPF_JMP | BPF_JA:
		rvoff = rv_offset(i, off, ctx);
		ret = emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
		if (ret)
			return ret;
		break;

	/* IF (dst COND src) JUMP off */
	case BPF_JMP | BPF_JEQ | BPF_X:
	case BPF_JMP32 | BPF_JEQ | BPF_X:
	case BPF_JMP | BPF_JGT | BPF_X:
	case BPF_JMP32 | BPF_JGT | BPF_X:
	case BPF_JMP | BPF_JLT | BPF_X:
	case BPF_JMP32 | BPF_JLT | BPF_X:
	case BPF_JMP | BPF_JGE | BPF_X:
	case BPF_JMP32 | BPF_JGE | BPF_X:
	case BPF_JMP | BPF_JLE | BPF_X:
	case BPF_JMP32 | BPF_JLE | BPF_X:
	case BPF_JMP | BPF_JNE | BPF_X:
	case BPF_JMP32 | BPF_JNE | BPF_X:
	case BPF_JMP | BPF_JSGT | BPF_X:
	case BPF_JMP32 | BPF_JSGT | BPF_X:
	case BPF_JMP | BPF_JSLT | BPF_X:
	case BPF_JMP32 | BPF_JSLT | BPF_X:
	case BPF_JMP | BPF_JSGE | BPF_X:
	case BPF_JMP32 | BPF_JSGE | BPF_X:
	case BPF_JMP | BPF_JSLE | BPF_X:
	case BPF_JMP32 | BPF_JSLE | BPF_X:
	case BPF_JMP | BPF_JSET | BPF_X:
	case BPF_JMP32 | BPF_JSET | BPF_X:
		rvoff = rv_offset(i, off, ctx);
		if (!is64) {
			s = ctx->ninsns;
			if (is_signed_bpf_cond(BPF_OP(code)))
				emit_sext_32_rd_rs(&rd, &rs, ctx);
			else
				emit_zext_32_rd_rs(&rd, &rs, ctx);
			e = ctx->ninsns;

			/* Adjust for extra insns */
			rvoff -= ninsns_rvoff(e - s);
		}

		if (BPF_OP(code) == BPF_JSET) {
			/* Adjust for and */
			rvoff -= 4;
			emit_and(RV_REG_T1, rd, rs, ctx);
			emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff,
				    ctx);
		} else {
			emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
		}
		break;

	/* IF (dst COND imm) JUMP off */
	case BPF_JMP | BPF_JEQ | BPF_K:
	case BPF_JMP32 | BPF_JEQ | BPF_K:
	case BPF_JMP | BPF_JGT | BPF_K:
	case BPF_JMP32 | BPF_JGT | BPF_K:
	case BPF_JMP | BPF_JLT | BPF_K:
	case BPF_JMP32 | BPF_JLT | BPF_K:
	case BPF_JMP | BPF_JGE | BPF_K:
	case BPF_JMP32 | BPF_JGE | BPF_K:
	case BPF_JMP | BPF_JLE | BPF_K:
	case BPF_JMP32 | BPF_JLE | BPF_K:
	case BPF_JMP | BPF_JNE | BPF_K:
	case BPF_JMP32 | BPF_JNE | BPF_K:
	case BPF_JMP | BPF_JSGT | BPF_K:
	case BPF_JMP32 | BPF_JSGT | BPF_K:
	case BPF_JMP | BPF_JSLT | BPF_K:
	case BPF_JMP32 | BPF_JSLT | BPF_K:
	case BPF_JMP | BPF_JSGE | BPF_K:
	case BPF_JMP32 | BPF_JSGE | BPF_K:
	case BPF_JMP | BPF_JSLE | BPF_K:
	case BPF_JMP32 | BPF_JSLE | BPF_K:
		rvoff = rv_offset(i, off, ctx);
		s = ctx->ninsns;
		if (imm) {
			emit_imm(RV_REG_T1, imm, ctx);
			rs = RV_REG_T1;
		} else {
			/* If imm is 0, simply use zero register. */
			rs = RV_REG_ZERO;
		}
		if (!is64) {
			if (is_signed_bpf_cond(BPF_OP(code)))
				emit_sext_32_rd(&rd, ctx);
			else
				emit_zext_32_rd_t1(&rd, ctx);
		}
		e = ctx->ninsns;

		/* Adjust for extra insns */
		rvoff -= ninsns_rvoff(e - s);
		emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
		break;

	case BPF_JMP | BPF_JSET | BPF_K:
	case BPF_JMP32 | BPF_JSET | BPF_K:
		rvoff = rv_offset(i, off, ctx);
		s = ctx->ninsns;
		if (is_12b_int(imm)) {
			emit_andi(RV_REG_T1, rd, imm, ctx);
		} else {
			emit_imm(RV_REG_T1, imm, ctx);
			emit_and(RV_REG_T1, rd, RV_REG_T1, ctx);
		}
		/* For jset32, we should clear the upper 32 bits of t1, but
		 * sign-extension is sufficient here and saves one instruction,
		 * as t1 is used only in comparison against zero.
		 */
		if (!is64 && imm < 0)
			emit_addiw(RV_REG_T1, RV_REG_T1, 0, ctx);
		e = ctx->ninsns;
		rvoff -= ninsns_rvoff(e - s);
		emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, ctx);
		break;

	/* function call */
	case BPF_JMP | BPF_CALL:
	{
		bool fixed_addr;
		u64 addr;

		mark_call(ctx);
		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
					    &addr, &fixed_addr);
		if (ret < 0)
			return ret;

		ret = emit_call(addr, fixed_addr, ctx);
		if (ret)
			return ret;

		if (insn->src_reg != BPF_PSEUDO_CALL)
			emit_mv(bpf_to_rv_reg(BPF_REG_0, ctx), RV_REG_A0, ctx);
		break;
	}
	/* tail call */
	case BPF_JMP | BPF_TAIL_CALL:
		if (emit_bpf_tail_call(i, ctx))
			return -1;
		break;

	/* function return */
	case BPF_JMP | BPF_EXIT:
		if (i == ctx->prog->len - 1)
			break;

		rvoff = epilogue_offset(ctx);
		ret = emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
		if (ret)
			return ret;
		break;

	/* dst = imm64 */
	case BPF_LD | BPF_IMM | BPF_DW:
	{
		struct bpf_insn insn1 = insn[1];
		u64 imm64;

		imm64 = (u64)insn1.imm << 32 | (u32)imm;
		emit_imm(rd, imm64, ctx);
		return 1;
	}

	/* LDX: dst = *(size *)(src + off) */
	case BPF_LDX | BPF_MEM | BPF_B:
		if (is_12b_int(off)) {
			emit(rv_lbu(rd, off, rs), ctx);
			break;
		}

		emit_imm(RV_REG_T1, off, ctx);
		emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
		emit(rv_lbu(rd, 0, RV_REG_T1), ctx);
		if (insn_is_zext(&insn[1]))
			return 1;
		break;
	case BPF_LDX | BPF_MEM | BPF_H:
		if (is_12b_int(off)) {
			emit(rv_lhu(rd, off, rs), ctx);
			break;
		}

		emit_imm(RV_REG_T1, off, ctx);
		emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
		emit(rv_lhu(rd, 0, RV_REG_T1), ctx);
		if (insn_is_zext(&insn[1]))
			return 1;
		break;
	case BPF_LDX | BPF_MEM | BPF_W:
		if (is_12b_int(off)) {
			emit(rv_lwu(rd, off, rs), ctx);
			break;
		}

		emit_imm(RV_REG_T1, off, ctx);
		emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
		emit(rv_lwu(rd, 0, RV_REG_T1), ctx);
		if (insn_is_zext(&insn[1]))
			return 1;
		break;
	case BPF_LDX | BPF_MEM | BPF_DW:
		if (is_12b_int(off)) {
			emit_ld(rd, off, rs, ctx);
			break;
		}

		emit_imm(RV_REG_T1, off, ctx);
		emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
		emit_ld(rd, 0, RV_REG_T1, ctx);
		break;

	/* speculation barrier */
	case BPF_ST | BPF_NOSPEC:
		break;

	/* ST: *(size *)(dst + off) = imm */
	case BPF_ST | BPF_MEM | BPF_B:
		emit_imm(RV_REG_T1, imm, ctx);
		if (is_12b_int(off)) {
			emit(rv_sb(rd, off, RV_REG_T1), ctx);
			break;
		}

		emit_imm(RV_REG_T2, off, ctx);
		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
		emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx);
		break;

	case BPF_ST | BPF_MEM | BPF_H:
		emit_imm(RV_REG_T1, imm, ctx);
		if (is_12b_int(off)) {
			emit(rv_sh(rd, off, RV_REG_T1), ctx);
			break;
		}

		emit_imm(RV_REG_T2, off, ctx);
		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
		emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx);
		break;
	case BPF_ST | BPF_MEM | BPF_W:
		emit_imm(RV_REG_T1, imm, ctx);
		if (is_12b_int(off)) {
			emit_sw(rd, off, RV_REG_T1, ctx);
			break;
		}

		emit_imm(RV_REG_T2, off, ctx);
		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
		emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx);
		break;
	case BPF_ST | BPF_MEM | BPF_DW:
		emit_imm(RV_REG_T1, imm, ctx);
		if (is_12b_int(off)) {
			emit_sd(rd, off, RV_REG_T1, ctx);
			break;
		}

		emit_imm(RV_REG_T2, off, ctx);
		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
		emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
		break;

	/* STX: *(size *)(dst + off) = src */
	case BPF_STX | BPF_MEM | BPF_B:
		if (is_12b_int(off)) {
			emit(rv_sb(rd, off, rs), ctx);
			break;
		}

		emit_imm(RV_REG_T1, off, ctx);
		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
		emit(rv_sb(RV_REG_T1, 0, rs), ctx);
		break;
	case BPF_STX | BPF_MEM | BPF_H:
		if (is_12b_int(off)) {
			emit(rv_sh(rd, off, rs), ctx);
			break;
		}

		emit_imm(RV_REG_T1, off, ctx);
		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
		emit(rv_sh(RV_REG_T1, 0, rs), ctx);
		break;
	case BPF_STX | BPF_MEM | BPF_W:
		if (is_12b_int(off)) {
			emit_sw(rd, off, rs, ctx);
			break;
		}

		emit_imm(RV_REG_T1, off, ctx);
		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
		emit_sw(RV_REG_T1, 0, rs, ctx);
		break;
	case BPF_STX | BPF_MEM | BPF_DW:
		if (is_12b_int(off)) {
			emit_sd(rd, off, rs, ctx);
			break;
		}

		emit_imm(RV_REG_T1, off, ctx);
		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
		emit_sd(RV_REG_T1, 0, rs, ctx);
		break;
	/* STX XADD: lock *(u32 *)(dst + off) += src */
	case BPF_STX | BPF_XADD | BPF_W:
	/* STX XADD: lock *(u64 *)(dst + off) += src */
	case BPF_STX | BPF_XADD | BPF_DW:
		if (off) {
			if (is_12b_int(off)) {
				emit_addi(RV_REG_T1, rd, off, ctx);
			} else {
				emit_imm(RV_REG_T1, off, ctx);
				emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
			}

			rd = RV_REG_T1;
		}

		emit(BPF_SIZE(code) == BPF_W ?
		     rv_amoadd_w(RV_REG_ZERO, rs, rd, 0, 0) :
		     rv_amoadd_d(RV_REG_ZERO, rs, rd, 0, 0), ctx);
		break;
	default:
		pr_err("bpf-jit: unknown opcode %02x\n", code);
		return -EINVAL;
	}

	return 0;
}

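/* Build the program prologue: initialize the tail-call counter with a
 * non-compressible instruction (so that tail calls can skip it at a fixed
 * 4-byte offset), allocate the stack frame, spill ra, fp and the used
 * callee-saved registers, and set up the BPF frame pointer.
 */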
void bpf_jit_build_prologue(struct rv_jit_context *ctx)
{
	int stack_adjust = 0, store_offset, bpf_stack_adjust;

	bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
	if (bpf_stack_adjust)
		mark_fp(ctx);

	if (seen_reg(RV_REG_RA, ctx))
		stack_adjust += 8;
	stack_adjust += 8; /* RV_REG_FP */
	if (seen_reg(RV_REG_S1, ctx))
		stack_adjust += 8;
	if (seen_reg(RV_REG_S2, ctx))
		stack_adjust += 8;
	if (seen_reg(RV_REG_S3, ctx))
		stack_adjust += 8;
	if (seen_reg(RV_REG_S4, ctx))
		stack_adjust += 8;
	if (seen_reg(RV_REG_S5, ctx))
		stack_adjust += 8;
	if (seen_reg(RV_REG_S6, ctx))
		stack_adjust += 8;

	stack_adjust = round_up(stack_adjust, 16);
	stack_adjust += bpf_stack_adjust;

	store_offset = stack_adjust - 8;

	/* First instruction is always setting the tail-call-counter
	 * (TCC) register. This instruction is skipped for tail calls.
	 * Force using a 4-byte (non-compressed) instruction.
	 */
	emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);

	emit_addi(RV_REG_SP, RV_REG_SP, -stack_adjust, ctx);

	if (seen_reg(RV_REG_RA, ctx)) {
		emit_sd(RV_REG_SP, store_offset, RV_REG_RA, ctx);
		store_offset -= 8;
	}
	emit_sd(RV_REG_SP, store_offset, RV_REG_FP, ctx);
	store_offset -= 8;
	if (seen_reg(RV_REG_S1, ctx)) {
		emit_sd(RV_REG_SP, store_offset, RV_REG_S1, ctx);
		store_offset -= 8;
	}
	if (seen_reg(RV_REG_S2, ctx)) {
		emit_sd(RV_REG_SP, store_offset, RV_REG_S2, ctx);
		store_offset -= 8;
	}
	if (seen_reg(RV_REG_S3, ctx)) {
		emit_sd(RV_REG_SP, store_offset, RV_REG_S3, ctx);
		store_offset -= 8;
	}
	if (seen_reg(RV_REG_S4, ctx)) {
		emit_sd(RV_REG_SP, store_offset, RV_REG_S4, ctx);
		store_offset -= 8;
	}
	if (seen_reg(RV_REG_S5, ctx)) {
		emit_sd(RV_REG_SP, store_offset, RV_REG_S5, ctx);
		store_offset -= 8;
	}
	if (seen_reg(RV_REG_S6, ctx)) {
		emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx);
		store_offset -= 8;
	}

	emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx);

	if (bpf_stack_adjust)
		emit_addi(RV_REG_S5, RV_REG_SP, bpf_stack_adjust, ctx);

	/* The program contains calls and tail calls, so RV_REG_TCC needs
	 * to be saved across calls.
	 */
	if (seen_tail_call(ctx) && seen_call(ctx))
		emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx);

	ctx->stack_size = stack_adjust;
}

void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
{
	__build_epilogue(false, ctx);
}