1// SPDX-License-Identifier: GPL-2.0
2/* BPF JIT compiler for RV64G
3 *
4 * Copyright(c) 2019 Björn Töpel <bjorn.topel@gmail.com>
5 *
6 */
7
8#include <linux/bitfield.h>
9#include <linux/bpf.h>
10#include <linux/filter.h>
11#include <linux/memory.h>
12#include <linux/stop_machine.h>
13#include <asm/patch.h>
14#include "bpf_jit.h"
15
/*
 * Number of 4-byte instructions in the patchable site at the start of a
 * JITed program / traced function: either two nops or an auipc+jalr pair
 * (see gen_jump_or_nops() and bpf_arch_text_poke()).
 */
#define RV_FENTRY_NINSNS 2

/* Register holding the BPF tail-call counter (TCC) on entry. */
#define RV_REG_TCC RV_REG_A6
#define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if the program makes calls */
20
/*
 * BPF register -> RV64 register mapping.
 *
 * BPF R0 lives in A5; the epilogue copies it into A0 as the native return
 * value. R1-R5 (arguments) use A0-A4. R6-R9 and the frame pointer use
 * callee-saved S1-S5. The verifier's auxiliary register AX uses T0.
 */
static const int regmap[] = {
	[BPF_REG_0] =	RV_REG_A5,
	[BPF_REG_1] =	RV_REG_A0,
	[BPF_REG_2] =	RV_REG_A1,
	[BPF_REG_3] =	RV_REG_A2,
	[BPF_REG_4] =	RV_REG_A3,
	[BPF_REG_5] =	RV_REG_A4,
	[BPF_REG_6] =	RV_REG_S1,
	[BPF_REG_7] =	RV_REG_S2,
	[BPF_REG_8] =	RV_REG_S3,
	[BPF_REG_9] =	RV_REG_S4,
	[BPF_REG_FP] =	RV_REG_S5,
	[BPF_REG_AX] =	RV_REG_T0,
};
35
/*
 * Byte offset within struct pt_regs of every RV register that regmap can
 * produce. ex_handler_bpf() uses it to zero the destination register of a
 * faulting BPF_PROBE_MEM load. Registers not listed here default to offset 0.
 */
static const int pt_regmap[] = {
	[RV_REG_A0] = offsetof(struct pt_regs, a0),
	[RV_REG_A1] = offsetof(struct pt_regs, a1),
	[RV_REG_A2] = offsetof(struct pt_regs, a2),
	[RV_REG_A3] = offsetof(struct pt_regs, a3),
	[RV_REG_A4] = offsetof(struct pt_regs, a4),
	[RV_REG_A5] = offsetof(struct pt_regs, a5),
	[RV_REG_S1] = offsetof(struct pt_regs, s1),
	[RV_REG_S2] = offsetof(struct pt_regs, s2),
	[RV_REG_S3] = offsetof(struct pt_regs, s3),
	[RV_REG_S4] = offsetof(struct pt_regs, s4),
	[RV_REG_S5] = offsetof(struct pt_regs, s5),
	[RV_REG_T0] = offsetof(struct pt_regs, t0),
};
50
/*
 * Bit numbers used in ctx->flags.
 *
 * Most flags reuse the RV register number itself, so bpf_to_rv_reg() and
 * seen_reg() can __set_bit()/test_bit() the register number directly.
 * Bit 0 (register x0 never needs saving) is repurposed for "tail call seen".
 * The RA bit means "program performs calls".
 */
enum {
	RV_CTX_F_SEEN_TAIL_CALL =	0,
	RV_CTX_F_SEEN_CALL =		RV_REG_RA,
	RV_CTX_F_SEEN_S1 =		RV_REG_S1,
	RV_CTX_F_SEEN_S2 =		RV_REG_S2,
	RV_CTX_F_SEEN_S3 =		RV_REG_S3,
	RV_CTX_F_SEEN_S4 =		RV_REG_S4,
	RV_CTX_F_SEEN_S5 =		RV_REG_S5,
	RV_CTX_F_SEEN_S6 =		RV_REG_S6,
};
61
62static u8 bpf_to_rv_reg(int bpf_reg, struct rv_jit_context *ctx)
63{
64	u8 reg = regmap[bpf_reg];
65
66	switch (reg) {
67	case RV_CTX_F_SEEN_S1:
68	case RV_CTX_F_SEEN_S2:
69	case RV_CTX_F_SEEN_S3:
70	case RV_CTX_F_SEEN_S4:
71	case RV_CTX_F_SEEN_S5:
72	case RV_CTX_F_SEEN_S6:
73		__set_bit(reg, &ctx->flags);
74	}
75	return reg;
76};
77
78static bool seen_reg(int reg, struct rv_jit_context *ctx)
79{
80	switch (reg) {
81	case RV_CTX_F_SEEN_CALL:
82	case RV_CTX_F_SEEN_S1:
83	case RV_CTX_F_SEEN_S2:
84	case RV_CTX_F_SEEN_S3:
85	case RV_CTX_F_SEEN_S4:
86	case RV_CTX_F_SEEN_S5:
87	case RV_CTX_F_SEEN_S6:
88		return test_bit(reg, &ctx->flags);
89	}
90	return false;
91}
92
93static void mark_fp(struct rv_jit_context *ctx)
94{
95	__set_bit(RV_CTX_F_SEEN_S5, &ctx->flags);
96}
97
98static void mark_call(struct rv_jit_context *ctx)
99{
100	__set_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
101}
102
103static bool seen_call(struct rv_jit_context *ctx)
104{
105	return test_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
106}
107
108static void mark_tail_call(struct rv_jit_context *ctx)
109{
110	__set_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
111}
112
113static bool seen_tail_call(struct rv_jit_context *ctx)
114{
115	return test_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
116}
117
118static u8 rv_tail_call_reg(struct rv_jit_context *ctx)
119{
120	mark_tail_call(ctx);
121
122	if (seen_call(ctx)) {
123		__set_bit(RV_CTX_F_SEEN_S6, &ctx->flags);
124		return RV_REG_S6;
125	}
126	return RV_REG_A6;
127}
128
129static bool is_32b_int(s64 val)
130{
131	return -(1L << 31) <= val && val < (1L << 31);
132}
133
134static bool in_auipc_jalr_range(s64 val)
135{
136	/*
137	 * auipc+jalr can reach any signed PC-relative offset in the range
138	 * [-2^31 - 2^11, 2^31 - 2^11).
139	 */
140	return (-(1L << 31) - (1L << 11)) <= val &&
141		val < ((1L << 31) - (1L << 11));
142}
143
144/* Emit fixed-length instructions for address */
145static int emit_addr(u8 rd, u64 addr, bool extra_pass, struct rv_jit_context *ctx)
146{
147	/*
148	 * Use the ro_insns(RX) to calculate the offset as the BPF program will
149	 * finally run from this memory region.
150	 */
151	u64 ip = (u64)(ctx->ro_insns + ctx->ninsns);
152	s64 off = addr - ip;
153	s64 upper = (off + (1 << 11)) >> 12;
154	s64 lower = off & 0xfff;
155
156	if (extra_pass && !in_auipc_jalr_range(off)) {
157		pr_err("bpf-jit: target offset 0x%llx is out of range\n", off);
158		return -ERANGE;
159	}
160
161	emit(rv_auipc(rd, upper), ctx);
162	emit(rv_addi(rd, rd, lower), ctx);
163	return 0;
164}
165
166/* Emit variable-length instructions for 32-bit and 64-bit imm */
167static void emit_imm(u8 rd, s64 val, struct rv_jit_context *ctx)
168{
169	/* Note that the immediate from the add is sign-extended,
170	 * which means that we need to compensate this by adding 2^12,
171	 * when the 12th bit is set. A simpler way of doing this, and
172	 * getting rid of the check, is to just add 2**11 before the
173	 * shift. The "Loading a 32-Bit constant" example from the
174	 * "Computer Organization and Design, RISC-V edition" book by
175	 * Patterson/Hennessy highlights this fact.
176	 *
177	 * This also means that we need to process LSB to MSB.
178	 */
179	s64 upper = (val + (1 << 11)) >> 12;
180	/* Sign-extend lower 12 bits to 64 bits since immediates for li, addiw,
181	 * and addi are signed and RVC checks will perform signed comparisons.
182	 */
183	s64 lower = ((val & 0xfff) << 52) >> 52;
184	int shift;
185
186	if (is_32b_int(val)) {
187		if (upper)
188			emit_lui(rd, upper, ctx);
189
190		if (!upper) {
191			emit_li(rd, lower, ctx);
192			return;
193		}
194
195		emit_addiw(rd, rd, lower, ctx);
196		return;
197	}
198
199	shift = __ffs(upper);
200	upper >>= shift;
201	shift += 12;
202
203	emit_imm(rd, upper, ctx);
204
205	emit_slli(rd, rd, shift, ctx);
206	if (lower)
207		emit_addi(rd, rd, lower, ctx);
208}
209
210static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
211{
212	int stack_adjust = ctx->stack_size, store_offset = stack_adjust - 8;
213
214	if (seen_reg(RV_REG_RA, ctx)) {
215		emit_ld(RV_REG_RA, store_offset, RV_REG_SP, ctx);
216		store_offset -= 8;
217	}
218	emit_ld(RV_REG_FP, store_offset, RV_REG_SP, ctx);
219	store_offset -= 8;
220	if (seen_reg(RV_REG_S1, ctx)) {
221		emit_ld(RV_REG_S1, store_offset, RV_REG_SP, ctx);
222		store_offset -= 8;
223	}
224	if (seen_reg(RV_REG_S2, ctx)) {
225		emit_ld(RV_REG_S2, store_offset, RV_REG_SP, ctx);
226		store_offset -= 8;
227	}
228	if (seen_reg(RV_REG_S3, ctx)) {
229		emit_ld(RV_REG_S3, store_offset, RV_REG_SP, ctx);
230		store_offset -= 8;
231	}
232	if (seen_reg(RV_REG_S4, ctx)) {
233		emit_ld(RV_REG_S4, store_offset, RV_REG_SP, ctx);
234		store_offset -= 8;
235	}
236	if (seen_reg(RV_REG_S5, ctx)) {
237		emit_ld(RV_REG_S5, store_offset, RV_REG_SP, ctx);
238		store_offset -= 8;
239	}
240	if (seen_reg(RV_REG_S6, ctx)) {
241		emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
242		store_offset -= 8;
243	}
244
245	emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
246	/* Set return value. */
247	if (!is_tail_call)
248		emit_addiw(RV_REG_A0, RV_REG_A5, 0, ctx);
249	emit_jalr(RV_REG_ZERO, is_tail_call ? RV_REG_T3 : RV_REG_RA,
250		  is_tail_call ? (RV_FENTRY_NINSNS + 1) * 4 : 0, /* skip reserved nops and TCC init */
251		  ctx);
252}
253
254static void emit_bcc(u8 cond, u8 rd, u8 rs, int rvoff,
255		     struct rv_jit_context *ctx)
256{
257	switch (cond) {
258	case BPF_JEQ:
259		emit(rv_beq(rd, rs, rvoff >> 1), ctx);
260		return;
261	case BPF_JGT:
262		emit(rv_bltu(rs, rd, rvoff >> 1), ctx);
263		return;
264	case BPF_JLT:
265		emit(rv_bltu(rd, rs, rvoff >> 1), ctx);
266		return;
267	case BPF_JGE:
268		emit(rv_bgeu(rd, rs, rvoff >> 1), ctx);
269		return;
270	case BPF_JLE:
271		emit(rv_bgeu(rs, rd, rvoff >> 1), ctx);
272		return;
273	case BPF_JNE:
274		emit(rv_bne(rd, rs, rvoff >> 1), ctx);
275		return;
276	case BPF_JSGT:
277		emit(rv_blt(rs, rd, rvoff >> 1), ctx);
278		return;
279	case BPF_JSLT:
280		emit(rv_blt(rd, rs, rvoff >> 1), ctx);
281		return;
282	case BPF_JSGE:
283		emit(rv_bge(rd, rs, rvoff >> 1), ctx);
284		return;
285	case BPF_JSLE:
286		emit(rv_bge(rs, rd, rvoff >> 1), ctx);
287	}
288}
289
290static void emit_branch(u8 cond, u8 rd, u8 rs, int rvoff,
291			struct rv_jit_context *ctx)
292{
293	s64 upper, lower;
294
295	if (is_13b_int(rvoff)) {
296		emit_bcc(cond, rd, rs, rvoff, ctx);
297		return;
298	}
299
300	/* Adjust for jal */
301	rvoff -= 4;
302
303	/* Transform, e.g.:
304	 *   bne rd,rs,foo
305	 * to
306	 *   beq rd,rs,<.L1>
307	 *   (auipc foo)
308	 *   jal(r) foo
309	 * .L1
310	 */
311	cond = invert_bpf_cond(cond);
312	if (is_21b_int(rvoff)) {
313		emit_bcc(cond, rd, rs, 8, ctx);
314		emit(rv_jal(RV_REG_ZERO, rvoff >> 1), ctx);
315		return;
316	}
317
318	/* 32b No need for an additional rvoff adjustment, since we
319	 * get that from the auipc at PC', where PC = PC' + 4.
320	 */
321	upper = (rvoff + (1 << 11)) >> 12;
322	lower = rvoff & 0xfff;
323
324	emit_bcc(cond, rd, rs, 12, ctx);
325	emit(rv_auipc(RV_REG_T1, upper), ctx);
326	emit(rv_jalr(RV_REG_ZERO, RV_REG_T1, lower), ctx);
327}
328
329static void emit_zext_32(u8 reg, struct rv_jit_context *ctx)
330{
331	emit_slli(reg, reg, 32, ctx);
332	emit_srli(reg, reg, 32, ctx);
333}
334
/*
 * Emit the BPF tail-call sequence for BPF instruction index @insn.
 *
 * Every "goto out" branch targets the end of this instruction's JITed code.
 * Its length (tc_ninsn) is read from ctx->offset[], presumably filled in by
 * an earlier pass. Each branch offset subtracts what has already been
 * emitted since start_insn.
 *
 * Returns 0 on success, or -1 when is_12b_check() rejects a struct field
 * offset (presumably one that does not fit a 12-bit load immediate).
 */
static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
{
	int tc_ninsn, off, start_insn = ctx->ninsns;
	/* Register currently holding the tail-call counter (A6, or S6 once calls are seen). */
	u8 tcc = rv_tail_call_reg(ctx);

	/* a0: &ctx
	 * a1: &array
	 * a2: index
	 *
	 * if (index >= array->map.max_entries)
	 *	goto out;
	 */
	tc_ninsn = insn ? ctx->offset[insn] - ctx->offset[insn - 1] :
		   ctx->offset[0];
	/* The index is compared as a 32-bit unsigned value. */
	emit_zext_32(RV_REG_A2, ctx);

	off = offsetof(struct bpf_array, map.max_entries);
	if (is_12b_check(off, insn))
		return -1;
	emit(rv_lwu(RV_REG_T1, off, RV_REG_A1), ctx);
	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
	emit_branch(BPF_JGE, RV_REG_A2, RV_REG_T1, off, ctx);

	/* if (--TCC < 0)
	 *     goto out;
	 */
	emit_addi(RV_REG_TCC, tcc, -1, ctx);
	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
	emit_branch(BPF_JSLT, RV_REG_TCC, RV_REG_ZERO, off, ctx);

	/* prog = array->ptrs[index];
	 * if (!prog)
	 *     goto out;
	 */
	emit_slli(RV_REG_T2, RV_REG_A2, 3, ctx);
	emit_add(RV_REG_T2, RV_REG_T2, RV_REG_A1, ctx);
	off = offsetof(struct bpf_array, ptrs);
	if (is_12b_check(off, insn))
		return -1;
	emit_ld(RV_REG_T2, off, RV_REG_T2, ctx);
	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
	emit_branch(BPF_JEQ, RV_REG_T2, RV_REG_ZERO, off, ctx);

	/* goto *(prog->bpf_func + (RV_FENTRY_NINSNS + 1) * 4);
	 *
	 * T3 = prog->bpf_func. The tail-call epilogue tears down this frame and
	 * jumps to T3 plus an offset that skips the target's reserved fentry nops
	 * and its TCC initialisation.
	 */
	off = offsetof(struct bpf_prog, bpf_func);
	if (is_12b_check(off, insn))
		return -1;
	emit_ld(RV_REG_T3, off, RV_REG_T2, ctx);
	__build_epilogue(true, ctx);
	return 0;
}
386
/*
 * Resolve the RV registers for @insn's BPF dst/src operands into *rd / *rs.
 * Resolution goes through bpf_to_rv_reg(), so callee-saved ones are also
 * marked as used in ctx->flags.
 *
 * *rd is left untouched for JA/CALL/EXIT/TAIL_CALL, which have no dst operand.
 *
 * NOTE(review): the *rs condition is much broader than it reads. Assuming
 * the UAPI encodings (BPF_X = 0x08, classes LDX = 0x01 .. ALU64 = 0x07),
 * the tested masks OR together to 0x0f. The condition is therefore true
 * whenever any of the opcode's low four bits is set, i.e. for almost every
 * instruction, including BPF_K forms. src_reg is thus mapped (and marked as
 * used) conservatively. It is left as-is because it decides which
 * callee-saved registers get saved.
 */
static void init_regs(u8 *rd, u8 *rs, const struct bpf_insn *insn,
		      struct rv_jit_context *ctx)
{
	u8 code = insn->code;

	switch (code) {
	case BPF_JMP | BPF_JA:
	case BPF_JMP | BPF_CALL:
	case BPF_JMP | BPF_EXIT:
	case BPF_JMP | BPF_TAIL_CALL:
		break;
	default:
		*rd = bpf_to_rv_reg(insn->dst_reg, ctx);
	}

	if (code & (BPF_ALU | BPF_X) || code & (BPF_ALU64 | BPF_X) ||
	    code & (BPF_JMP | BPF_X) || code & (BPF_JMP32 | BPF_X) ||
	    code & BPF_LDX || code & BPF_STX)
		*rs = bpf_to_rv_reg(insn->src_reg, ctx);
}
407
408static void emit_zext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
409{
410	emit_mv(RV_REG_T2, *rd, ctx);
411	emit_zext_32(RV_REG_T2, ctx);
412	emit_mv(RV_REG_T1, *rs, ctx);
413	emit_zext_32(RV_REG_T1, ctx);
414	*rd = RV_REG_T2;
415	*rs = RV_REG_T1;
416}
417
418static void emit_sext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
419{
420	emit_addiw(RV_REG_T2, *rd, 0, ctx);
421	emit_addiw(RV_REG_T1, *rs, 0, ctx);
422	*rd = RV_REG_T2;
423	*rs = RV_REG_T1;
424}
425
426static void emit_zext_32_rd_t1(u8 *rd, struct rv_jit_context *ctx)
427{
428	emit_mv(RV_REG_T2, *rd, ctx);
429	emit_zext_32(RV_REG_T2, ctx);
430	emit_zext_32(RV_REG_T1, ctx);
431	*rd = RV_REG_T2;
432}
433
434static void emit_sext_32_rd(u8 *rd, struct rv_jit_context *ctx)
435{
436	emit_addiw(RV_REG_T2, *rd, 0, ctx);
437	*rd = RV_REG_T2;
438}
439
440static int emit_jump_and_link(u8 rd, s64 rvoff, bool fixed_addr,
441			      struct rv_jit_context *ctx)
442{
443	s64 upper, lower;
444
445	if (rvoff && fixed_addr && is_21b_int(rvoff)) {
446		emit(rv_jal(rd, rvoff >> 1), ctx);
447		return 0;
448	} else if (in_auipc_jalr_range(rvoff)) {
449		upper = (rvoff + (1 << 11)) >> 12;
450		lower = rvoff & 0xfff;
451		emit(rv_auipc(RV_REG_T1, upper), ctx);
452		emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
453		return 0;
454	}
455
456	pr_err("bpf-jit: target offset 0x%llx is out of range\n", rvoff);
457	return -ERANGE;
458}
459
460static bool is_signed_bpf_cond(u8 cond)
461{
462	return cond == BPF_JSGT || cond == BPF_JSLT ||
463		cond == BPF_JSGE || cond == BPF_JSLE;
464}
465
466static int emit_call(u64 addr, bool fixed_addr, struct rv_jit_context *ctx)
467{
468	s64 off = 0;
469	u64 ip;
470
471	if (addr && ctx->insns && ctx->ro_insns) {
472		/*
473		 * Use the ro_insns(RX) to calculate the offset as the BPF
474		 * program will finally run from this memory region.
475		 */
476		ip = (u64)(long)(ctx->ro_insns + ctx->ninsns);
477		off = addr - ip;
478	}
479
480	return emit_jump_and_link(RV_REG_RA, off, fixed_addr, ctx);
481}
482
483static void emit_atomic(u8 rd, u8 rs, s16 off, s32 imm, bool is64,
484			struct rv_jit_context *ctx)
485{
486	u8 r0;
487	int jmp_offset;
488
489	if (off) {
490		if (is_12b_int(off)) {
491			emit_addi(RV_REG_T1, rd, off, ctx);
492		} else {
493			emit_imm(RV_REG_T1, off, ctx);
494			emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
495		}
496		rd = RV_REG_T1;
497	}
498
499	switch (imm) {
500	/* lock *(u32/u64 *)(dst_reg + off16) <op>= src_reg */
501	case BPF_ADD:
502		emit(is64 ? rv_amoadd_d(RV_REG_ZERO, rs, rd, 0, 0) :
503		     rv_amoadd_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
504		break;
505	case BPF_AND:
506		emit(is64 ? rv_amoand_d(RV_REG_ZERO, rs, rd, 0, 0) :
507		     rv_amoand_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
508		break;
509	case BPF_OR:
510		emit(is64 ? rv_amoor_d(RV_REG_ZERO, rs, rd, 0, 0) :
511		     rv_amoor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
512		break;
513	case BPF_XOR:
514		emit(is64 ? rv_amoxor_d(RV_REG_ZERO, rs, rd, 0, 0) :
515		     rv_amoxor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
516		break;
517	/* src_reg = atomic_fetch_<op>(dst_reg + off16, src_reg) */
518	case BPF_ADD | BPF_FETCH:
519		emit(is64 ? rv_amoadd_d(rs, rs, rd, 0, 0) :
520		     rv_amoadd_w(rs, rs, rd, 0, 0), ctx);
521		if (!is64)
522			emit_zext_32(rs, ctx);
523		break;
524	case BPF_AND | BPF_FETCH:
525		emit(is64 ? rv_amoand_d(rs, rs, rd, 0, 0) :
526		     rv_amoand_w(rs, rs, rd, 0, 0), ctx);
527		if (!is64)
528			emit_zext_32(rs, ctx);
529		break;
530	case BPF_OR | BPF_FETCH:
531		emit(is64 ? rv_amoor_d(rs, rs, rd, 0, 0) :
532		     rv_amoor_w(rs, rs, rd, 0, 0), ctx);
533		if (!is64)
534			emit_zext_32(rs, ctx);
535		break;
536	case BPF_XOR | BPF_FETCH:
537		emit(is64 ? rv_amoxor_d(rs, rs, rd, 0, 0) :
538		     rv_amoxor_w(rs, rs, rd, 0, 0), ctx);
539		if (!is64)
540			emit_zext_32(rs, ctx);
541		break;
542	/* src_reg = atomic_xchg(dst_reg + off16, src_reg); */
543	case BPF_XCHG:
544		emit(is64 ? rv_amoswap_d(rs, rs, rd, 0, 0) :
545		     rv_amoswap_w(rs, rs, rd, 0, 0), ctx);
546		if (!is64)
547			emit_zext_32(rs, ctx);
548		break;
549	/* r0 = atomic_cmpxchg(dst_reg + off16, r0, src_reg); */
550	case BPF_CMPXCHG:
551		r0 = bpf_to_rv_reg(BPF_REG_0, ctx);
552		emit(is64 ? rv_addi(RV_REG_T2, r0, 0) :
553		     rv_addiw(RV_REG_T2, r0, 0), ctx);
554		emit(is64 ? rv_lr_d(r0, 0, rd, 0, 0) :
555		     rv_lr_w(r0, 0, rd, 0, 0), ctx);
556		jmp_offset = ninsns_rvoff(8);
557		emit(rv_bne(RV_REG_T2, r0, jmp_offset >> 1), ctx);
558		emit(is64 ? rv_sc_d(RV_REG_T3, rs, rd, 0, 0) :
559		     rv_sc_w(RV_REG_T3, rs, rd, 0, 0), ctx);
560		jmp_offset = ninsns_rvoff(-6);
561		emit(rv_bne(RV_REG_T3, 0, jmp_offset >> 1), ctx);
562		emit(rv_fence(0x3, 0x3), ctx);
563		break;
564	}
565}
566
567#define BPF_FIXUP_OFFSET_MASK   GENMASK(26, 0)
568#define BPF_FIXUP_REG_MASK      GENMASK(31, 27)
569
570bool ex_handler_bpf(const struct exception_table_entry *ex,
571		    struct pt_regs *regs)
572{
573	off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
574	int regs_offset = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
575
576	*(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0;
577	regs->epc = (unsigned long)&ex->fixup - offset;
578
579	return true;
580}
581
/*
 * For accesses to BTF pointers, add an entry to the exception table.
 *
 * @insn:    the BPF instruction just JITed; only BPF_PROBE_MEM/MEMSX loads
 *           get an entry
 * @dst_reg: RV destination register, zeroed by ex_handler_bpf() on a fault
 * @insn_len: length of the possibly-faulting instruction in 16-bit units
 *           (1 for a compressed instruction, 2 otherwise); it is the last
 *           thing emitted into ctx
 *
 * Does nothing (returns 0) when no buffers or extable are set up yet.
 * Returns -EINVAL on table overflow or an impossible length, and -ERANGE
 * when an offset does not fit its encoding.
 */
static int add_exception_handler(const struct bpf_insn *insn,
				 struct rv_jit_context *ctx,
				 int dst_reg, int insn_len)
{
	struct exception_table_entry *ex;
	unsigned long pc;
	off_t ins_offset;
	off_t fixup_offset;

	if (!ctx->insns || !ctx->ro_insns || !ctx->prog->aux->extable ||
	    (BPF_MODE(insn->code) != BPF_PROBE_MEM && BPF_MODE(insn->code) != BPF_PROBE_MEMSX))
		return 0;

	if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries))
		return -EINVAL;

	if (WARN_ON_ONCE(insn_len > ctx->ninsns))
		return -EINVAL;

	/* A 16-bit (compressed) instruction is impossible without RVC. */
	if (WARN_ON_ONCE(!rvc_enabled() && insn_len == 1))
		return -EINVAL;

	/* Addresses are computed in the RX image, where the code will run. */
	ex = &ctx->prog->aux->extable[ctx->nexentries];
	pc = (unsigned long)&ctx->ro_insns[ctx->ninsns - insn_len];

	/*
	 * This is the relative offset of the instruction that may fault from
	 * the exception table itself. This will be written to the exception
	 * table and if this instruction faults, the destination register will
	 * be set to '0' and the execution will jump to the next instruction.
	 */
	ins_offset = pc - (long)&ex->insn;
	if (WARN_ON_ONCE(ins_offset >= 0 || ins_offset < INT_MIN))
		return -ERANGE;

	/*
	 * Since the extable follows the program, the fixup offset is always
	 * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
	 * to keep things simple, and put the destination register in the upper
	 * bits. We don't need to worry about buildtime or runtime sort
	 * modifying the upper bits because the table is already sorted, and
	 * isn't part of the main exception table.
	 *
	 * The fixup_offset is set to the next instruction from the instruction
	 * that may fault. The execution will jump to this after handling the
	 * fault.
	 */
	fixup_offset = (long)&ex->fixup - (pc + insn_len * sizeof(u16));
	if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, fixup_offset))
		return -ERANGE;

	/*
	 * The offsets above have been calculated using the RO buffer but we
	 * need to use the R/W buffer for writes.
	 * switch ex to rw buffer for writing.
	 */
	ex = (void *)ctx->insns + ((void *)ex - (void *)ctx->ro_insns);

	ex->insn = ins_offset;

	ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, fixup_offset) |
		FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
	ex->type = EX_TYPE_BPF;

	ctx->nexentries++;
	return 0;
}
650
651static int gen_jump_or_nops(void *target, void *ip, u32 *insns, bool is_call)
652{
653	s64 rvoff;
654	struct rv_jit_context ctx;
655
656	ctx.ninsns = 0;
657	ctx.insns = (u16 *)insns;
658
659	if (!target) {
660		emit(rv_nop(), &ctx);
661		emit(rv_nop(), &ctx);
662		return 0;
663	}
664
665	rvoff = (s64)(target - ip);
666	return emit_jump_and_link(is_call ? RV_REG_T0 : RV_REG_ZERO, rvoff, false, &ctx);
667}
668
669int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type,
670		       void *old_addr, void *new_addr)
671{
672	u32 old_insns[RV_FENTRY_NINSNS], new_insns[RV_FENTRY_NINSNS];
673	bool is_call = poke_type == BPF_MOD_CALL;
674	int ret;
675
676	if (!is_kernel_text((unsigned long)ip) &&
677	    !is_bpf_text_address((unsigned long)ip))
678		return -ENOTSUPP;
679
680	ret = gen_jump_or_nops(old_addr, ip, old_insns, is_call);
681	if (ret)
682		return ret;
683
684	if (memcmp(ip, old_insns, RV_FENTRY_NINSNS * 4))
685		return -EFAULT;
686
687	ret = gen_jump_or_nops(new_addr, ip, new_insns, is_call);
688	if (ret)
689		return ret;
690
691	cpus_read_lock();
692	mutex_lock(&text_mutex);
693	if (memcmp(ip, new_insns, RV_FENTRY_NINSNS * 4))
694		ret = patch_text(ip, new_insns, RV_FENTRY_NINSNS);
695	mutex_unlock(&text_mutex);
696	cpus_read_unlock();
697
698	return ret;
699}
700
701static void store_args(int nregs, int args_off, struct rv_jit_context *ctx)
702{
703	int i;
704
705	for (i = 0; i < nregs; i++) {
706		emit_sd(RV_REG_FP, -args_off, RV_REG_A0 + i, ctx);
707		args_off -= 8;
708	}
709}
710
711static void restore_args(int nregs, int args_off, struct rv_jit_context *ctx)
712{
713	int i;
714
715	for (i = 0; i < nregs; i++) {
716		emit_ld(RV_REG_A0 + i, -args_off, RV_REG_FP, ctx);
717		args_off -= 8;
718	}
719}
720
/*
 * Emit, inside a trampoline, the call sequence for one attached program:
 *
 *   run_ctx.bpf_cookie = l->cookie;
 *   start = bpf_trampoline_enter(p)(p, &run_ctx);
 *   if (start) {
 *           ret = p->bpf_func(&args[, p->insnsi]);
 *           if (save_ret) store A0 and BPF R0 (A5) into the retval slots;
 *   }
 *   bpf_trampoline_exit(p)(p, start, &run_ctx);
 *
 * The start value is kept in callee-saved S1 across the program call.
 * Offsets are positive distances below FP. Returns 0 or an error from
 * emit_call().
 */
static int invoke_bpf_prog(struct bpf_tramp_link *l, int args_off, int retval_off,
			   int run_ctx_off, bool save_ret, struct rv_jit_context *ctx)
{
	int ret, branch_off;
	struct bpf_prog *p = l->link.prog;
	int cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie);

	if (l->cookie) {
		emit_imm(RV_REG_T1, l->cookie, ctx);
		emit_sd(RV_REG_FP, -run_ctx_off + cookie_off, RV_REG_T1, ctx);
	} else {
		emit_sd(RV_REG_FP, -run_ctx_off + cookie_off, RV_REG_ZERO, ctx);
	}

	/* arg1: prog */
	emit_imm(RV_REG_A0, (const s64)p, ctx);
	/* arg2: &run_ctx */
	emit_addi(RV_REG_A1, RV_REG_FP, -run_ctx_off, ctx);
	ret = emit_call((const u64)bpf_trampoline_enter(p), true, ctx);
	if (ret)
		return ret;

	/* if (__bpf_prog_enter(prog) == 0)
	 *	goto skip_exec_of_prog;
	 */
	branch_off = ctx->ninsns;
	/* nop reserved for conditional jump; patched below once the target is known */
	emit(rv_nop(), ctx);

	/* store prog start time */
	emit_mv(RV_REG_S1, RV_REG_A0, ctx);

	/* arg1: &args_off */
	emit_addi(RV_REG_A0, RV_REG_FP, -args_off, ctx);
	if (!p->jited)
		/* arg2: progs[i]->insnsi for interpreter */
		emit_imm(RV_REG_A1, (const s64)p->insnsi, ctx);
	ret = emit_call((const u64)p->bpf_func, true, ctx);
	if (ret)
		return ret;

	if (save_ret) {
		emit_sd(RV_REG_FP, -retval_off, RV_REG_A0, ctx);
		emit_sd(RV_REG_FP, -(retval_off - 8), regmap[BPF_REG_0], ctx);
	}

	/* update branch with beqz (only when actually writing, not when sizing) */
	if (ctx->insns) {
		int offset = ninsns_rvoff(ctx->ninsns - branch_off);
		u32 insn = rv_beq(RV_REG_A0, RV_REG_ZERO, offset >> 1);
		*(u32 *)(ctx->insns + branch_off) = insn;
	}

	/* arg1: prog */
	emit_imm(RV_REG_A0, (const s64)p, ctx);
	/* arg2: prog start time */
	emit_mv(RV_REG_A1, RV_REG_S1, ctx);
	/* arg3: &run_ctx */
	emit_addi(RV_REG_A2, RV_REG_FP, -run_ctx_off, ctx);
	ret = emit_call((const u64)bpf_trampoline_exit(p), true, ctx);

	return ret;
}
784
785static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
786					 const struct btf_func_model *m,
787					 struct bpf_tramp_links *tlinks,
788					 void *func_addr, u32 flags,
789					 struct rv_jit_context *ctx)
790{
791	int i, ret, offset;
792	int *branches_off = NULL;
793	int stack_size = 0, nregs = m->nr_args;
794	int retval_off, args_off, nregs_off, ip_off, run_ctx_off, sreg_off;
795	struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
796	struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
797	struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
798	void *orig_call = func_addr;
799	bool save_ret;
800	u32 insn;
801
802	/* Two types of generated trampoline stack layout:
803	 *
804	 * 1. trampoline called from function entry
805	 * --------------------------------------
806	 * FP + 8	    [ RA to parent func	] return address to parent
807	 *					  function
808	 * FP + 0	    [ FP of parent func ] frame pointer of parent
809	 *					  function
810	 * FP - 8           [ T0 to traced func ] return address of traced
811	 *					  function
812	 * FP - 16	    [ FP of traced func ] frame pointer of traced
813	 *					  function
814	 * --------------------------------------
815	 *
816	 * 2. trampoline called directly
817	 * --------------------------------------
818	 * FP - 8	    [ RA to caller func ] return address to caller
819	 *					  function
820	 * FP - 16	    [ FP of caller func	] frame pointer of caller
821	 *					  function
822	 * --------------------------------------
823	 *
824	 * FP - retval_off  [ return value      ] BPF_TRAMP_F_CALL_ORIG or
825	 *					  BPF_TRAMP_F_RET_FENTRY_RET
826	 *                  [ argN              ]
827	 *                  [ ...               ]
828	 * FP - args_off    [ arg1              ]
829	 *
830	 * FP - nregs_off   [ regs count        ]
831	 *
832	 * FP - ip_off      [ traced func	] BPF_TRAMP_F_IP_ARG
833	 *
834	 * FP - run_ctx_off [ bpf_tramp_run_ctx ]
835	 *
836	 * FP - sreg_off    [ callee saved reg	]
837	 *
838	 *		    [ pads              ] pads for 16 bytes alignment
839	 */
840
841	if (flags & (BPF_TRAMP_F_ORIG_STACK | BPF_TRAMP_F_SHARE_IPMODIFY))
842		return -ENOTSUPP;
843
844	/* extra regiters for struct arguments */
845	for (i = 0; i < m->nr_args; i++)
846		if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG)
847			nregs += round_up(m->arg_size[i], 8) / 8 - 1;
848
849	/* 8 arguments passed by registers */
850	if (nregs > 8)
851		return -ENOTSUPP;
852
853	/* room of trampoline frame to store return address and frame pointer */
854	stack_size += 16;
855
856	save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
857	if (save_ret) {
858		stack_size += 16; /* Save both A5 (BPF R0) and A0 */
859		retval_off = stack_size;
860	}
861
862	stack_size += nregs * 8;
863	args_off = stack_size;
864
865	stack_size += 8;
866	nregs_off = stack_size;
867
868	if (flags & BPF_TRAMP_F_IP_ARG) {
869		stack_size += 8;
870		ip_off = stack_size;
871	}
872
873	stack_size += round_up(sizeof(struct bpf_tramp_run_ctx), 8);
874	run_ctx_off = stack_size;
875
876	stack_size += 8;
877	sreg_off = stack_size;
878
879	stack_size = round_up(stack_size, 16);
880
881	if (func_addr) {
882		/* For the trampoline called from function entry,
883		 * the frame of traced function and the frame of
884		 * trampoline need to be considered.
885		 */
886		emit_addi(RV_REG_SP, RV_REG_SP, -16, ctx);
887		emit_sd(RV_REG_SP, 8, RV_REG_RA, ctx);
888		emit_sd(RV_REG_SP, 0, RV_REG_FP, ctx);
889		emit_addi(RV_REG_FP, RV_REG_SP, 16, ctx);
890
891		emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
892		emit_sd(RV_REG_SP, stack_size - 8, RV_REG_T0, ctx);
893		emit_sd(RV_REG_SP, stack_size - 16, RV_REG_FP, ctx);
894		emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
895	} else {
896		/* For the trampoline called directly, just handle
897		 * the frame of trampoline.
898		 */
899		emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
900		emit_sd(RV_REG_SP, stack_size - 8, RV_REG_RA, ctx);
901		emit_sd(RV_REG_SP, stack_size - 16, RV_REG_FP, ctx);
902		emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
903	}
904
905	/* callee saved register S1 to pass start time */
906	emit_sd(RV_REG_FP, -sreg_off, RV_REG_S1, ctx);
907
908	/* store ip address of the traced function */
909	if (flags & BPF_TRAMP_F_IP_ARG) {
910		emit_imm(RV_REG_T1, (const s64)func_addr, ctx);
911		emit_sd(RV_REG_FP, -ip_off, RV_REG_T1, ctx);
912	}
913
914	emit_li(RV_REG_T1, nregs, ctx);
915	emit_sd(RV_REG_FP, -nregs_off, RV_REG_T1, ctx);
916
917	store_args(nregs, args_off, ctx);
918
919	/* skip to actual body of traced function */
920	if (flags & BPF_TRAMP_F_SKIP_FRAME)
921		orig_call += RV_FENTRY_NINSNS * 4;
922
923	if (flags & BPF_TRAMP_F_CALL_ORIG) {
924		emit_imm(RV_REG_A0, (const s64)im, ctx);
925		ret = emit_call((const u64)__bpf_tramp_enter, true, ctx);
926		if (ret)
927			return ret;
928	}
929
930	for (i = 0; i < fentry->nr_links; i++) {
931		ret = invoke_bpf_prog(fentry->links[i], args_off, retval_off, run_ctx_off,
932				      flags & BPF_TRAMP_F_RET_FENTRY_RET, ctx);
933		if (ret)
934			return ret;
935	}
936
937	if (fmod_ret->nr_links) {
938		branches_off = kcalloc(fmod_ret->nr_links, sizeof(int), GFP_KERNEL);
939		if (!branches_off)
940			return -ENOMEM;
941
942		/* cleanup to avoid garbage return value confusion */
943		emit_sd(RV_REG_FP, -retval_off, RV_REG_ZERO, ctx);
944		for (i = 0; i < fmod_ret->nr_links; i++) {
945			ret = invoke_bpf_prog(fmod_ret->links[i], args_off, retval_off,
946					      run_ctx_off, true, ctx);
947			if (ret)
948				goto out;
949			emit_ld(RV_REG_T1, -retval_off, RV_REG_FP, ctx);
950			branches_off[i] = ctx->ninsns;
951			/* nop reserved for conditional jump */
952			emit(rv_nop(), ctx);
953		}
954	}
955
956	if (flags & BPF_TRAMP_F_CALL_ORIG) {
957		restore_args(nregs, args_off, ctx);
958		ret = emit_call((const u64)orig_call, true, ctx);
959		if (ret)
960			goto out;
961		emit_sd(RV_REG_FP, -retval_off, RV_REG_A0, ctx);
962		emit_sd(RV_REG_FP, -(retval_off - 8), regmap[BPF_REG_0], ctx);
963		im->ip_after_call = ctx->insns + ctx->ninsns;
964		/* 2 nops reserved for auipc+jalr pair */
965		emit(rv_nop(), ctx);
966		emit(rv_nop(), ctx);
967	}
968
969	/* update branches saved in invoke_bpf_mod_ret with bnez */
970	for (i = 0; ctx->insns && i < fmod_ret->nr_links; i++) {
971		offset = ninsns_rvoff(ctx->ninsns - branches_off[i]);
972		insn = rv_bne(RV_REG_T1, RV_REG_ZERO, offset >> 1);
973		*(u32 *)(ctx->insns + branches_off[i]) = insn;
974	}
975
976	for (i = 0; i < fexit->nr_links; i++) {
977		ret = invoke_bpf_prog(fexit->links[i], args_off, retval_off,
978				      run_ctx_off, false, ctx);
979		if (ret)
980			goto out;
981	}
982
983	if (flags & BPF_TRAMP_F_CALL_ORIG) {
984		im->ip_epilogue = ctx->insns + ctx->ninsns;
985		emit_imm(RV_REG_A0, (const s64)im, ctx);
986		ret = emit_call((const u64)__bpf_tramp_exit, true, ctx);
987		if (ret)
988			goto out;
989	}
990
991	if (flags & BPF_TRAMP_F_RESTORE_REGS)
992		restore_args(nregs, args_off, ctx);
993
994	if (save_ret) {
995		emit_ld(RV_REG_A0, -retval_off, RV_REG_FP, ctx);
996		emit_ld(regmap[BPF_REG_0], -(retval_off - 8), RV_REG_FP, ctx);
997	}
998
999	emit_ld(RV_REG_S1, -sreg_off, RV_REG_FP, ctx);
1000
1001	if (func_addr) {
1002		/* trampoline called from function entry */
1003		emit_ld(RV_REG_T0, stack_size - 8, RV_REG_SP, ctx);
1004		emit_ld(RV_REG_FP, stack_size - 16, RV_REG_SP, ctx);
1005		emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
1006
1007		emit_ld(RV_REG_RA, 8, RV_REG_SP, ctx);
1008		emit_ld(RV_REG_FP, 0, RV_REG_SP, ctx);
1009		emit_addi(RV_REG_SP, RV_REG_SP, 16, ctx);
1010
1011		if (flags & BPF_TRAMP_F_SKIP_FRAME)
1012			/* return to parent function */
1013			emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
1014		else
1015			/* return to traced function */
1016			emit_jalr(RV_REG_ZERO, RV_REG_T0, 0, ctx);
1017	} else {
1018		/* trampoline called directly */
1019		emit_ld(RV_REG_RA, stack_size - 8, RV_REG_SP, ctx);
1020		emit_ld(RV_REG_FP, stack_size - 16, RV_REG_SP, ctx);
1021		emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
1022
1023		emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
1024	}
1025
1026	ret = ctx->ninsns;
1027out:
1028	kfree(branches_off);
1029	return ret;
1030}
1031
1032int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image,
1033				void *image_end, const struct btf_func_model *m,
1034				u32 flags, struct bpf_tramp_links *tlinks,
1035				void *func_addr)
1036{
1037	int ret;
1038	struct rv_jit_context ctx;
1039
1040	ctx.ninsns = 0;
1041	ctx.insns = NULL;
1042	ctx.ro_insns = NULL;
1043	ret = __arch_prepare_bpf_trampoline(im, m, tlinks, func_addr, flags, &ctx);
1044	if (ret < 0)
1045		return ret;
1046
1047	if (ninsns_rvoff(ret) > (long)image_end - (long)image)
1048		return -EFBIG;
1049
1050	ctx.ninsns = 0;
1051	/*
1052	 * The bpf_int_jit_compile() uses a RW buffer (ctx.insns) to write the
1053	 * JITed instructions and later copies it to a RX region (ctx.ro_insns).
1054	 * It also uses ctx.ro_insns to calculate offsets for jumps etc. As the
1055	 * trampoline image uses the same memory area for writing and execution,
1056	 * both ctx.insns and ctx.ro_insns can be set to image.
1057	 */
1058	ctx.insns = image;
1059	ctx.ro_insns = image;
1060	ret = __arch_prepare_bpf_trampoline(im, m, tlinks, func_addr, flags, &ctx);
1061	if (ret < 0)
1062		return ret;
1063
1064	bpf_flush_icache(ctx.insns, ctx.insns + ctx.ninsns);
1065
1066	return ninsns_rvoff(ret);
1067}
1068
1069int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
1070		      bool extra_pass)
1071{
1072	bool is64 = BPF_CLASS(insn->code) == BPF_ALU64 ||
1073		    BPF_CLASS(insn->code) == BPF_JMP;
1074	int s, e, rvoff, ret, i = insn - ctx->prog->insnsi;
1075	struct bpf_prog_aux *aux = ctx->prog->aux;
1076	u8 rd = -1, rs = -1, code = insn->code;
1077	s16 off = insn->off;
1078	s32 imm = insn->imm;
1079
1080	init_regs(&rd, &rs, insn, ctx);
1081
1082	switch (code) {
1083	/* dst = src */
1084	case BPF_ALU | BPF_MOV | BPF_X:
1085	case BPF_ALU64 | BPF_MOV | BPF_X:
1086		if (imm == 1) {
1087			/* Special mov32 for zext */
1088			emit_zext_32(rd, ctx);
1089			break;
1090		}
1091		switch (insn->off) {
1092		case 0:
1093			emit_mv(rd, rs, ctx);
1094			break;
1095		case 8:
1096		case 16:
1097			emit_slli(RV_REG_T1, rs, 64 - insn->off, ctx);
1098			emit_srai(rd, RV_REG_T1, 64 - insn->off, ctx);
1099			break;
1100		case 32:
1101			emit_addiw(rd, rs, 0, ctx);
1102			break;
1103		}
1104		if (!is64 && !aux->verifier_zext)
1105			emit_zext_32(rd, ctx);
1106		break;
1107
1108	/* dst = dst OP src */
1109	case BPF_ALU | BPF_ADD | BPF_X:
1110	case BPF_ALU64 | BPF_ADD | BPF_X:
1111		emit_add(rd, rd, rs, ctx);
1112		if (!is64 && !aux->verifier_zext)
1113			emit_zext_32(rd, ctx);
1114		break;
1115	case BPF_ALU | BPF_SUB | BPF_X:
1116	case BPF_ALU64 | BPF_SUB | BPF_X:
1117		if (is64)
1118			emit_sub(rd, rd, rs, ctx);
1119		else
1120			emit_subw(rd, rd, rs, ctx);
1121
1122		if (!is64 && !aux->verifier_zext)
1123			emit_zext_32(rd, ctx);
1124		break;
1125	case BPF_ALU | BPF_AND | BPF_X:
1126	case BPF_ALU64 | BPF_AND | BPF_X:
1127		emit_and(rd, rd, rs, ctx);
1128		if (!is64 && !aux->verifier_zext)
1129			emit_zext_32(rd, ctx);
1130		break;
1131	case BPF_ALU | BPF_OR | BPF_X:
1132	case BPF_ALU64 | BPF_OR | BPF_X:
1133		emit_or(rd, rd, rs, ctx);
1134		if (!is64 && !aux->verifier_zext)
1135			emit_zext_32(rd, ctx);
1136		break;
1137	case BPF_ALU | BPF_XOR | BPF_X:
1138	case BPF_ALU64 | BPF_XOR | BPF_X:
1139		emit_xor(rd, rd, rs, ctx);
1140		if (!is64 && !aux->verifier_zext)
1141			emit_zext_32(rd, ctx);
1142		break;
1143	case BPF_ALU | BPF_MUL | BPF_X:
1144	case BPF_ALU64 | BPF_MUL | BPF_X:
1145		emit(is64 ? rv_mul(rd, rd, rs) : rv_mulw(rd, rd, rs), ctx);
1146		if (!is64 && !aux->verifier_zext)
1147			emit_zext_32(rd, ctx);
1148		break;
1149	case BPF_ALU | BPF_DIV | BPF_X:
1150	case BPF_ALU64 | BPF_DIV | BPF_X:
1151		if (off)
1152			emit(is64 ? rv_div(rd, rd, rs) : rv_divw(rd, rd, rs), ctx);
1153		else
1154			emit(is64 ? rv_divu(rd, rd, rs) : rv_divuw(rd, rd, rs), ctx);
1155		if (!is64 && !aux->verifier_zext)
1156			emit_zext_32(rd, ctx);
1157		break;
1158	case BPF_ALU | BPF_MOD | BPF_X:
1159	case BPF_ALU64 | BPF_MOD | BPF_X:
1160		if (off)
1161			emit(is64 ? rv_rem(rd, rd, rs) : rv_remw(rd, rd, rs), ctx);
1162		else
1163			emit(is64 ? rv_remu(rd, rd, rs) : rv_remuw(rd, rd, rs), ctx);
1164		if (!is64 && !aux->verifier_zext)
1165			emit_zext_32(rd, ctx);
1166		break;
1167	case BPF_ALU | BPF_LSH | BPF_X:
1168	case BPF_ALU64 | BPF_LSH | BPF_X:
1169		emit(is64 ? rv_sll(rd, rd, rs) : rv_sllw(rd, rd, rs), ctx);
1170		if (!is64 && !aux->verifier_zext)
1171			emit_zext_32(rd, ctx);
1172		break;
1173	case BPF_ALU | BPF_RSH | BPF_X:
1174	case BPF_ALU64 | BPF_RSH | BPF_X:
1175		emit(is64 ? rv_srl(rd, rd, rs) : rv_srlw(rd, rd, rs), ctx);
1176		if (!is64 && !aux->verifier_zext)
1177			emit_zext_32(rd, ctx);
1178		break;
1179	case BPF_ALU | BPF_ARSH | BPF_X:
1180	case BPF_ALU64 | BPF_ARSH | BPF_X:
1181		emit(is64 ? rv_sra(rd, rd, rs) : rv_sraw(rd, rd, rs), ctx);
1182		if (!is64 && !aux->verifier_zext)
1183			emit_zext_32(rd, ctx);
1184		break;
1185
1186	/* dst = -dst */
1187	case BPF_ALU | BPF_NEG:
1188	case BPF_ALU64 | BPF_NEG:
1189		emit_sub(rd, RV_REG_ZERO, rd, ctx);
1190		if (!is64 && !aux->verifier_zext)
1191			emit_zext_32(rd, ctx);
1192		break;
1193
1194	/* dst = BSWAP##imm(dst) */
1195	case BPF_ALU | BPF_END | BPF_FROM_LE:
1196		switch (imm) {
1197		case 16:
1198			emit_slli(rd, rd, 48, ctx);
1199			emit_srli(rd, rd, 48, ctx);
1200			break;
1201		case 32:
1202			if (!aux->verifier_zext)
1203				emit_zext_32(rd, ctx);
1204			break;
1205		case 64:
1206			/* Do nothing */
1207			break;
1208		}
1209		break;
1210
1211	case BPF_ALU | BPF_END | BPF_FROM_BE:
1212	case BPF_ALU64 | BPF_END | BPF_FROM_LE:
1213		emit_li(RV_REG_T2, 0, ctx);
1214
1215		emit_andi(RV_REG_T1, rd, 0xff, ctx);
1216		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1217		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1218		emit_srli(rd, rd, 8, ctx);
1219		if (imm == 16)
1220			goto out_be;
1221
1222		emit_andi(RV_REG_T1, rd, 0xff, ctx);
1223		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1224		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1225		emit_srli(rd, rd, 8, ctx);
1226
1227		emit_andi(RV_REG_T1, rd, 0xff, ctx);
1228		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1229		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1230		emit_srli(rd, rd, 8, ctx);
1231		if (imm == 32)
1232			goto out_be;
1233
1234		emit_andi(RV_REG_T1, rd, 0xff, ctx);
1235		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1236		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1237		emit_srli(rd, rd, 8, ctx);
1238
1239		emit_andi(RV_REG_T1, rd, 0xff, ctx);
1240		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1241		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1242		emit_srli(rd, rd, 8, ctx);
1243
1244		emit_andi(RV_REG_T1, rd, 0xff, ctx);
1245		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1246		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1247		emit_srli(rd, rd, 8, ctx);
1248
1249		emit_andi(RV_REG_T1, rd, 0xff, ctx);
1250		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1251		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1252		emit_srli(rd, rd, 8, ctx);
1253out_be:
1254		emit_andi(RV_REG_T1, rd, 0xff, ctx);
1255		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1256
1257		emit_mv(rd, RV_REG_T2, ctx);
1258		break;
1259
1260	/* dst = imm */
1261	case BPF_ALU | BPF_MOV | BPF_K:
1262	case BPF_ALU64 | BPF_MOV | BPF_K:
1263		emit_imm(rd, imm, ctx);
1264		if (!is64 && !aux->verifier_zext)
1265			emit_zext_32(rd, ctx);
1266		break;
1267
1268	/* dst = dst OP imm */
1269	case BPF_ALU | BPF_ADD | BPF_K:
1270	case BPF_ALU64 | BPF_ADD | BPF_K:
1271		if (is_12b_int(imm)) {
1272			emit_addi(rd, rd, imm, ctx);
1273		} else {
1274			emit_imm(RV_REG_T1, imm, ctx);
1275			emit_add(rd, rd, RV_REG_T1, ctx);
1276		}
1277		if (!is64 && !aux->verifier_zext)
1278			emit_zext_32(rd, ctx);
1279		break;
1280	case BPF_ALU | BPF_SUB | BPF_K:
1281	case BPF_ALU64 | BPF_SUB | BPF_K:
1282		if (is_12b_int(-imm)) {
1283			emit_addi(rd, rd, -imm, ctx);
1284		} else {
1285			emit_imm(RV_REG_T1, imm, ctx);
1286			emit_sub(rd, rd, RV_REG_T1, ctx);
1287		}
1288		if (!is64 && !aux->verifier_zext)
1289			emit_zext_32(rd, ctx);
1290		break;
1291	case BPF_ALU | BPF_AND | BPF_K:
1292	case BPF_ALU64 | BPF_AND | BPF_K:
1293		if (is_12b_int(imm)) {
1294			emit_andi(rd, rd, imm, ctx);
1295		} else {
1296			emit_imm(RV_REG_T1, imm, ctx);
1297			emit_and(rd, rd, RV_REG_T1, ctx);
1298		}
1299		if (!is64 && !aux->verifier_zext)
1300			emit_zext_32(rd, ctx);
1301		break;
1302	case BPF_ALU | BPF_OR | BPF_K:
1303	case BPF_ALU64 | BPF_OR | BPF_K:
1304		if (is_12b_int(imm)) {
1305			emit(rv_ori(rd, rd, imm), ctx);
1306		} else {
1307			emit_imm(RV_REG_T1, imm, ctx);
1308			emit_or(rd, rd, RV_REG_T1, ctx);
1309		}
1310		if (!is64 && !aux->verifier_zext)
1311			emit_zext_32(rd, ctx);
1312		break;
1313	case BPF_ALU | BPF_XOR | BPF_K:
1314	case BPF_ALU64 | BPF_XOR | BPF_K:
1315		if (is_12b_int(imm)) {
1316			emit(rv_xori(rd, rd, imm), ctx);
1317		} else {
1318			emit_imm(RV_REG_T1, imm, ctx);
1319			emit_xor(rd, rd, RV_REG_T1, ctx);
1320		}
1321		if (!is64 && !aux->verifier_zext)
1322			emit_zext_32(rd, ctx);
1323		break;
1324	case BPF_ALU | BPF_MUL | BPF_K:
1325	case BPF_ALU64 | BPF_MUL | BPF_K:
1326		emit_imm(RV_REG_T1, imm, ctx);
1327		emit(is64 ? rv_mul(rd, rd, RV_REG_T1) :
1328		     rv_mulw(rd, rd, RV_REG_T1), ctx);
1329		if (!is64 && !aux->verifier_zext)
1330			emit_zext_32(rd, ctx);
1331		break;
1332	case BPF_ALU | BPF_DIV | BPF_K:
1333	case BPF_ALU64 | BPF_DIV | BPF_K:
1334		emit_imm(RV_REG_T1, imm, ctx);
1335		if (off)
1336			emit(is64 ? rv_div(rd, rd, RV_REG_T1) :
1337			     rv_divw(rd, rd, RV_REG_T1), ctx);
1338		else
1339			emit(is64 ? rv_divu(rd, rd, RV_REG_T1) :
1340			     rv_divuw(rd, rd, RV_REG_T1), ctx);
1341		if (!is64 && !aux->verifier_zext)
1342			emit_zext_32(rd, ctx);
1343		break;
1344	case BPF_ALU | BPF_MOD | BPF_K:
1345	case BPF_ALU64 | BPF_MOD | BPF_K:
1346		emit_imm(RV_REG_T1, imm, ctx);
1347		if (off)
1348			emit(is64 ? rv_rem(rd, rd, RV_REG_T1) :
1349			     rv_remw(rd, rd, RV_REG_T1), ctx);
1350		else
1351			emit(is64 ? rv_remu(rd, rd, RV_REG_T1) :
1352			     rv_remuw(rd, rd, RV_REG_T1), ctx);
1353		if (!is64 && !aux->verifier_zext)
1354			emit_zext_32(rd, ctx);
1355		break;
1356	case BPF_ALU | BPF_LSH | BPF_K:
1357	case BPF_ALU64 | BPF_LSH | BPF_K:
1358		emit_slli(rd, rd, imm, ctx);
1359
1360		if (!is64 && !aux->verifier_zext)
1361			emit_zext_32(rd, ctx);
1362		break;
1363	case BPF_ALU | BPF_RSH | BPF_K:
1364	case BPF_ALU64 | BPF_RSH | BPF_K:
1365		if (is64)
1366			emit_srli(rd, rd, imm, ctx);
1367		else
1368			emit(rv_srliw(rd, rd, imm), ctx);
1369
1370		if (!is64 && !aux->verifier_zext)
1371			emit_zext_32(rd, ctx);
1372		break;
1373	case BPF_ALU | BPF_ARSH | BPF_K:
1374	case BPF_ALU64 | BPF_ARSH | BPF_K:
1375		if (is64)
1376			emit_srai(rd, rd, imm, ctx);
1377		else
1378			emit(rv_sraiw(rd, rd, imm), ctx);
1379
1380		if (!is64 && !aux->verifier_zext)
1381			emit_zext_32(rd, ctx);
1382		break;
1383
1384	/* JUMP off */
1385	case BPF_JMP | BPF_JA:
1386	case BPF_JMP32 | BPF_JA:
1387		if (BPF_CLASS(code) == BPF_JMP)
1388			rvoff = rv_offset(i, off, ctx);
1389		else
1390			rvoff = rv_offset(i, imm, ctx);
1391		ret = emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
1392		if (ret)
1393			return ret;
1394		break;
1395
1396	/* IF (dst COND src) JUMP off */
1397	case BPF_JMP | BPF_JEQ | BPF_X:
1398	case BPF_JMP32 | BPF_JEQ | BPF_X:
1399	case BPF_JMP | BPF_JGT | BPF_X:
1400	case BPF_JMP32 | BPF_JGT | BPF_X:
1401	case BPF_JMP | BPF_JLT | BPF_X:
1402	case BPF_JMP32 | BPF_JLT | BPF_X:
1403	case BPF_JMP | BPF_JGE | BPF_X:
1404	case BPF_JMP32 | BPF_JGE | BPF_X:
1405	case BPF_JMP | BPF_JLE | BPF_X:
1406	case BPF_JMP32 | BPF_JLE | BPF_X:
1407	case BPF_JMP | BPF_JNE | BPF_X:
1408	case BPF_JMP32 | BPF_JNE | BPF_X:
1409	case BPF_JMP | BPF_JSGT | BPF_X:
1410	case BPF_JMP32 | BPF_JSGT | BPF_X:
1411	case BPF_JMP | BPF_JSLT | BPF_X:
1412	case BPF_JMP32 | BPF_JSLT | BPF_X:
1413	case BPF_JMP | BPF_JSGE | BPF_X:
1414	case BPF_JMP32 | BPF_JSGE | BPF_X:
1415	case BPF_JMP | BPF_JSLE | BPF_X:
1416	case BPF_JMP32 | BPF_JSLE | BPF_X:
1417	case BPF_JMP | BPF_JSET | BPF_X:
1418	case BPF_JMP32 | BPF_JSET | BPF_X:
1419		rvoff = rv_offset(i, off, ctx);
1420		if (!is64) {
1421			s = ctx->ninsns;
1422			if (is_signed_bpf_cond(BPF_OP(code)))
1423				emit_sext_32_rd_rs(&rd, &rs, ctx);
1424			else
1425				emit_zext_32_rd_rs(&rd, &rs, ctx);
1426			e = ctx->ninsns;
1427
1428			/* Adjust for extra insns */
1429			rvoff -= ninsns_rvoff(e - s);
1430		}
1431
1432		if (BPF_OP(code) == BPF_JSET) {
1433			/* Adjust for and */
1434			rvoff -= 4;
1435			emit_and(RV_REG_T1, rd, rs, ctx);
1436			emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff,
1437				    ctx);
1438		} else {
1439			emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
1440		}
1441		break;
1442
1443	/* IF (dst COND imm) JUMP off */
1444	case BPF_JMP | BPF_JEQ | BPF_K:
1445	case BPF_JMP32 | BPF_JEQ | BPF_K:
1446	case BPF_JMP | BPF_JGT | BPF_K:
1447	case BPF_JMP32 | BPF_JGT | BPF_K:
1448	case BPF_JMP | BPF_JLT | BPF_K:
1449	case BPF_JMP32 | BPF_JLT | BPF_K:
1450	case BPF_JMP | BPF_JGE | BPF_K:
1451	case BPF_JMP32 | BPF_JGE | BPF_K:
1452	case BPF_JMP | BPF_JLE | BPF_K:
1453	case BPF_JMP32 | BPF_JLE | BPF_K:
1454	case BPF_JMP | BPF_JNE | BPF_K:
1455	case BPF_JMP32 | BPF_JNE | BPF_K:
1456	case BPF_JMP | BPF_JSGT | BPF_K:
1457	case BPF_JMP32 | BPF_JSGT | BPF_K:
1458	case BPF_JMP | BPF_JSLT | BPF_K:
1459	case BPF_JMP32 | BPF_JSLT | BPF_K:
1460	case BPF_JMP | BPF_JSGE | BPF_K:
1461	case BPF_JMP32 | BPF_JSGE | BPF_K:
1462	case BPF_JMP | BPF_JSLE | BPF_K:
1463	case BPF_JMP32 | BPF_JSLE | BPF_K:
1464		rvoff = rv_offset(i, off, ctx);
1465		s = ctx->ninsns;
1466		if (imm) {
1467			emit_imm(RV_REG_T1, imm, ctx);
1468			rs = RV_REG_T1;
1469		} else {
1470			/* If imm is 0, simply use zero register. */
1471			rs = RV_REG_ZERO;
1472		}
1473		if (!is64) {
1474			if (is_signed_bpf_cond(BPF_OP(code)))
1475				emit_sext_32_rd(&rd, ctx);
1476			else
1477				emit_zext_32_rd_t1(&rd, ctx);
1478		}
1479		e = ctx->ninsns;
1480
1481		/* Adjust for extra insns */
1482		rvoff -= ninsns_rvoff(e - s);
1483		emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
1484		break;
1485
1486	case BPF_JMP | BPF_JSET | BPF_K:
1487	case BPF_JMP32 | BPF_JSET | BPF_K:
1488		rvoff = rv_offset(i, off, ctx);
1489		s = ctx->ninsns;
1490		if (is_12b_int(imm)) {
1491			emit_andi(RV_REG_T1, rd, imm, ctx);
1492		} else {
1493			emit_imm(RV_REG_T1, imm, ctx);
1494			emit_and(RV_REG_T1, rd, RV_REG_T1, ctx);
1495		}
1496		/* For jset32, we should clear the upper 32 bits of t1, but
1497		 * sign-extension is sufficient here and saves one instruction,
1498		 * as t1 is used only in comparison against zero.
1499		 */
1500		if (!is64 && imm < 0)
1501			emit_addiw(RV_REG_T1, RV_REG_T1, 0, ctx);
1502		e = ctx->ninsns;
1503		rvoff -= ninsns_rvoff(e - s);
1504		emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, ctx);
1505		break;
1506
1507	/* function call */
1508	case BPF_JMP | BPF_CALL:
1509	{
1510		bool fixed_addr;
1511		u64 addr;
1512
1513		mark_call(ctx);
1514		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
1515					    &addr, &fixed_addr);
1516		if (ret < 0)
1517			return ret;
1518
1519		ret = emit_call(addr, fixed_addr, ctx);
1520		if (ret)
1521			return ret;
1522
1523		if (insn->src_reg != BPF_PSEUDO_CALL)
1524			emit_mv(bpf_to_rv_reg(BPF_REG_0, ctx), RV_REG_A0, ctx);
1525		break;
1526	}
1527	/* tail call */
1528	case BPF_JMP | BPF_TAIL_CALL:
1529		if (emit_bpf_tail_call(i, ctx))
1530			return -1;
1531		break;
1532
1533	/* function return */
1534	case BPF_JMP | BPF_EXIT:
1535		if (i == ctx->prog->len - 1)
1536			break;
1537
1538		rvoff = epilogue_offset(ctx);
1539		ret = emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
1540		if (ret)
1541			return ret;
1542		break;
1543
1544	/* dst = imm64 */
1545	case BPF_LD | BPF_IMM | BPF_DW:
1546	{
1547		struct bpf_insn insn1 = insn[1];
1548		u64 imm64;
1549
1550		imm64 = (u64)insn1.imm << 32 | (u32)imm;
1551		if (bpf_pseudo_func(insn)) {
1552			/* fixed-length insns for extra jit pass */
1553			ret = emit_addr(rd, imm64, extra_pass, ctx);
1554			if (ret)
1555				return ret;
1556		} else {
1557			emit_imm(rd, imm64, ctx);
1558		}
1559
1560		return 1;
1561	}
1562
1563	/* LDX: dst = *(unsigned size *)(src + off) */
1564	case BPF_LDX | BPF_MEM | BPF_B:
1565	case BPF_LDX | BPF_MEM | BPF_H:
1566	case BPF_LDX | BPF_MEM | BPF_W:
1567	case BPF_LDX | BPF_MEM | BPF_DW:
1568	case BPF_LDX | BPF_PROBE_MEM | BPF_B:
1569	case BPF_LDX | BPF_PROBE_MEM | BPF_H:
1570	case BPF_LDX | BPF_PROBE_MEM | BPF_W:
1571	case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
1572	/* LDSX: dst = *(signed size *)(src + off) */
1573	case BPF_LDX | BPF_MEMSX | BPF_B:
1574	case BPF_LDX | BPF_MEMSX | BPF_H:
1575	case BPF_LDX | BPF_MEMSX | BPF_W:
1576	case BPF_LDX | BPF_PROBE_MEMSX | BPF_B:
1577	case BPF_LDX | BPF_PROBE_MEMSX | BPF_H:
1578	case BPF_LDX | BPF_PROBE_MEMSX | BPF_W:
1579	{
1580		int insn_len, insns_start;
1581		bool sign_ext;
1582
1583		sign_ext = BPF_MODE(insn->code) == BPF_MEMSX ||
1584			   BPF_MODE(insn->code) == BPF_PROBE_MEMSX;
1585
1586		switch (BPF_SIZE(code)) {
1587		case BPF_B:
1588			if (is_12b_int(off)) {
1589				insns_start = ctx->ninsns;
1590				if (sign_ext)
1591					emit(rv_lb(rd, off, rs), ctx);
1592				else
1593					emit(rv_lbu(rd, off, rs), ctx);
1594				insn_len = ctx->ninsns - insns_start;
1595				break;
1596			}
1597
1598			emit_imm(RV_REG_T1, off, ctx);
1599			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1600			insns_start = ctx->ninsns;
1601			if (sign_ext)
1602				emit(rv_lb(rd, 0, RV_REG_T1), ctx);
1603			else
1604				emit(rv_lbu(rd, 0, RV_REG_T1), ctx);
1605			insn_len = ctx->ninsns - insns_start;
1606			break;
1607		case BPF_H:
1608			if (is_12b_int(off)) {
1609				insns_start = ctx->ninsns;
1610				if (sign_ext)
1611					emit(rv_lh(rd, off, rs), ctx);
1612				else
1613					emit(rv_lhu(rd, off, rs), ctx);
1614				insn_len = ctx->ninsns - insns_start;
1615				break;
1616			}
1617
1618			emit_imm(RV_REG_T1, off, ctx);
1619			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1620			insns_start = ctx->ninsns;
1621			if (sign_ext)
1622				emit(rv_lh(rd, 0, RV_REG_T1), ctx);
1623			else
1624				emit(rv_lhu(rd, 0, RV_REG_T1), ctx);
1625			insn_len = ctx->ninsns - insns_start;
1626			break;
1627		case BPF_W:
1628			if (is_12b_int(off)) {
1629				insns_start = ctx->ninsns;
1630				if (sign_ext)
1631					emit(rv_lw(rd, off, rs), ctx);
1632				else
1633					emit(rv_lwu(rd, off, rs), ctx);
1634				insn_len = ctx->ninsns - insns_start;
1635				break;
1636			}
1637
1638			emit_imm(RV_REG_T1, off, ctx);
1639			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1640			insns_start = ctx->ninsns;
1641			if (sign_ext)
1642				emit(rv_lw(rd, 0, RV_REG_T1), ctx);
1643			else
1644				emit(rv_lwu(rd, 0, RV_REG_T1), ctx);
1645			insn_len = ctx->ninsns - insns_start;
1646			break;
1647		case BPF_DW:
1648			if (is_12b_int(off)) {
1649				insns_start = ctx->ninsns;
1650				emit_ld(rd, off, rs, ctx);
1651				insn_len = ctx->ninsns - insns_start;
1652				break;
1653			}
1654
1655			emit_imm(RV_REG_T1, off, ctx);
1656			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1657			insns_start = ctx->ninsns;
1658			emit_ld(rd, 0, RV_REG_T1, ctx);
1659			insn_len = ctx->ninsns - insns_start;
1660			break;
1661		}
1662
1663		ret = add_exception_handler(insn, ctx, rd, insn_len);
1664		if (ret)
1665			return ret;
1666
1667		if (BPF_SIZE(code) != BPF_DW && insn_is_zext(&insn[1]))
1668			return 1;
1669		break;
1670	}
1671	/* speculation barrier */
1672	case BPF_ST | BPF_NOSPEC:
1673		break;
1674
1675	/* ST: *(size *)(dst + off) = imm */
1676	case BPF_ST | BPF_MEM | BPF_B:
1677		emit_imm(RV_REG_T1, imm, ctx);
1678		if (is_12b_int(off)) {
1679			emit(rv_sb(rd, off, RV_REG_T1), ctx);
1680			break;
1681		}
1682
1683		emit_imm(RV_REG_T2, off, ctx);
1684		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1685		emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx);
1686		break;
1687
1688	case BPF_ST | BPF_MEM | BPF_H:
1689		emit_imm(RV_REG_T1, imm, ctx);
1690		if (is_12b_int(off)) {
1691			emit(rv_sh(rd, off, RV_REG_T1), ctx);
1692			break;
1693		}
1694
1695		emit_imm(RV_REG_T2, off, ctx);
1696		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1697		emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx);
1698		break;
1699	case BPF_ST | BPF_MEM | BPF_W:
1700		emit_imm(RV_REG_T1, imm, ctx);
1701		if (is_12b_int(off)) {
1702			emit_sw(rd, off, RV_REG_T1, ctx);
1703			break;
1704		}
1705
1706		emit_imm(RV_REG_T2, off, ctx);
1707		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1708		emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx);
1709		break;
1710	case BPF_ST | BPF_MEM | BPF_DW:
1711		emit_imm(RV_REG_T1, imm, ctx);
1712		if (is_12b_int(off)) {
1713			emit_sd(rd, off, RV_REG_T1, ctx);
1714			break;
1715		}
1716
1717		emit_imm(RV_REG_T2, off, ctx);
1718		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1719		emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
1720		break;
1721
1722	/* STX: *(size *)(dst + off) = src */
1723	case BPF_STX | BPF_MEM | BPF_B:
1724		if (is_12b_int(off)) {
1725			emit(rv_sb(rd, off, rs), ctx);
1726			break;
1727		}
1728
1729		emit_imm(RV_REG_T1, off, ctx);
1730		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1731		emit(rv_sb(RV_REG_T1, 0, rs), ctx);
1732		break;
1733	case BPF_STX | BPF_MEM | BPF_H:
1734		if (is_12b_int(off)) {
1735			emit(rv_sh(rd, off, rs), ctx);
1736			break;
1737		}
1738
1739		emit_imm(RV_REG_T1, off, ctx);
1740		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1741		emit(rv_sh(RV_REG_T1, 0, rs), ctx);
1742		break;
1743	case BPF_STX | BPF_MEM | BPF_W:
1744		if (is_12b_int(off)) {
1745			emit_sw(rd, off, rs, ctx);
1746			break;
1747		}
1748
1749		emit_imm(RV_REG_T1, off, ctx);
1750		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1751		emit_sw(RV_REG_T1, 0, rs, ctx);
1752		break;
1753	case BPF_STX | BPF_MEM | BPF_DW:
1754		if (is_12b_int(off)) {
1755			emit_sd(rd, off, rs, ctx);
1756			break;
1757		}
1758
1759		emit_imm(RV_REG_T1, off, ctx);
1760		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1761		emit_sd(RV_REG_T1, 0, rs, ctx);
1762		break;
1763	case BPF_STX | BPF_ATOMIC | BPF_W:
1764	case BPF_STX | BPF_ATOMIC | BPF_DW:
1765		emit_atomic(rd, rs, off, imm,
1766			    BPF_SIZE(code) == BPF_DW, ctx);
1767		break;
1768	default:
1769		pr_err("bpf-jit: unknown opcode %02x\n", code);
1770		return -EINVAL;
1771	}
1772
1773	return 0;
1774}
1775
1776void bpf_jit_build_prologue(struct rv_jit_context *ctx)
1777{
1778	int i, stack_adjust = 0, store_offset, bpf_stack_adjust;
1779
1780	bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
1781	if (bpf_stack_adjust)
1782		mark_fp(ctx);
1783
1784	if (seen_reg(RV_REG_RA, ctx))
1785		stack_adjust += 8;
1786	stack_adjust += 8; /* RV_REG_FP */
1787	if (seen_reg(RV_REG_S1, ctx))
1788		stack_adjust += 8;
1789	if (seen_reg(RV_REG_S2, ctx))
1790		stack_adjust += 8;
1791	if (seen_reg(RV_REG_S3, ctx))
1792		stack_adjust += 8;
1793	if (seen_reg(RV_REG_S4, ctx))
1794		stack_adjust += 8;
1795	if (seen_reg(RV_REG_S5, ctx))
1796		stack_adjust += 8;
1797	if (seen_reg(RV_REG_S6, ctx))
1798		stack_adjust += 8;
1799
1800	stack_adjust = round_up(stack_adjust, 16);
1801	stack_adjust += bpf_stack_adjust;
1802
1803	store_offset = stack_adjust - 8;
1804
1805	/* nops reserved for auipc+jalr pair */
1806	for (i = 0; i < RV_FENTRY_NINSNS; i++)
1807		emit(rv_nop(), ctx);
1808
1809	/* First instruction is always setting the tail-call-counter
1810	 * (TCC) register. This instruction is skipped for tail calls.
1811	 * Force using a 4-byte (non-compressed) instruction.
1812	 */
1813	emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
1814
1815	emit_addi(RV_REG_SP, RV_REG_SP, -stack_adjust, ctx);
1816
1817	if (seen_reg(RV_REG_RA, ctx)) {
1818		emit_sd(RV_REG_SP, store_offset, RV_REG_RA, ctx);
1819		store_offset -= 8;
1820	}
1821	emit_sd(RV_REG_SP, store_offset, RV_REG_FP, ctx);
1822	store_offset -= 8;
1823	if (seen_reg(RV_REG_S1, ctx)) {
1824		emit_sd(RV_REG_SP, store_offset, RV_REG_S1, ctx);
1825		store_offset -= 8;
1826	}
1827	if (seen_reg(RV_REG_S2, ctx)) {
1828		emit_sd(RV_REG_SP, store_offset, RV_REG_S2, ctx);
1829		store_offset -= 8;
1830	}
1831	if (seen_reg(RV_REG_S3, ctx)) {
1832		emit_sd(RV_REG_SP, store_offset, RV_REG_S3, ctx);
1833		store_offset -= 8;
1834	}
1835	if (seen_reg(RV_REG_S4, ctx)) {
1836		emit_sd(RV_REG_SP, store_offset, RV_REG_S4, ctx);
1837		store_offset -= 8;
1838	}
1839	if (seen_reg(RV_REG_S5, ctx)) {
1840		emit_sd(RV_REG_SP, store_offset, RV_REG_S5, ctx);
1841		store_offset -= 8;
1842	}
1843	if (seen_reg(RV_REG_S6, ctx)) {
1844		emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx);
1845		store_offset -= 8;
1846	}
1847
1848	emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx);
1849
1850	if (bpf_stack_adjust)
1851		emit_addi(RV_REG_S5, RV_REG_SP, bpf_stack_adjust, ctx);
1852
1853	/* Program contains calls and tail calls, so RV_REG_TCC need
1854	 * to be saved across calls.
1855	 */
1856	if (seen_tail_call(ctx) && seen_call(ctx))
1857		emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx);
1858
1859	ctx->stack_size = stack_adjust;
1860}
1861
1862void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
1863{
1864	__build_epilogue(false, ctx);
1865}
1866
/* Tell the BPF core that this JIT can emit calls to kernel functions (kfuncs). */
bool bpf_jit_supports_kfunc_call(void)
{
	const bool supported = true;

	return supported;
}
1871