1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * bpf_jit_comp.c: BPF JIT compiler
4 *
5 * Copyright (C) 2011-2013 Eric Dumazet (eric.dumazet@gmail.com)
6 * Internal BPF Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
7 */
8#include <linux/netdevice.h>
9#include <linux/filter.h>
10#include <linux/if_vlan.h>
11#include <linux/bpf.h>
12#include <linux/memory.h>
13#include <linux/sort.h>
14#include <asm/extable.h>
15#include <asm/set_memory.h>
16#include <asm/nospec-branch.h>
17#include <asm/text-patching.h>
18
19static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
20{
21	if (len == 1)
22		*ptr = bytes;
23	else if (len == 2)
24		*(u16 *)ptr = bytes;
25	else {
26		*(u32 *)ptr = bytes;
27		barrier();
28	}
29	return ptr + len;
30}
31
32#define EMIT(bytes, len) \
33	do { prog = emit_code(prog, bytes, len); cnt += len; } while (0)
34
35#define EMIT1(b1)		EMIT(b1, 1)
36#define EMIT2(b1, b2)		EMIT((b1) + ((b2) << 8), 2)
37#define EMIT3(b1, b2, b3)	EMIT((b1) + ((b2) << 8) + ((b3) << 16), 3)
38#define EMIT4(b1, b2, b3, b4)   EMIT((b1) + ((b2) << 8) + ((b3) << 16) + ((b4) << 24), 4)
39
40#define EMIT1_off32(b1, off) \
41	do { EMIT1(b1); EMIT(off, 4); } while (0)
42#define EMIT2_off32(b1, b2, off) \
43	do { EMIT2(b1, b2); EMIT(off, 4); } while (0)
44#define EMIT3_off32(b1, b2, b3, off) \
45	do { EMIT3(b1, b2, b3); EMIT(off, 4); } while (0)
46#define EMIT4_off32(b1, b2, b3, b4, off) \
47	do { EMIT4(b1, b2, b3, b4); EMIT(off, 4); } while (0)
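/*
 * The EMITx() macros pack up to four opcode bytes into a little-endian
 * u32, append them at the local 'prog' cursor via emit_code() and bump
 * the local 'cnt' byte counter that each emitter function declares.
 * Illustrative example: EMIT3(0x48, 0x89, 0xE5) appends the bytes
 * 48 89 E5, i.e. 'mov rbp, rsp', as emitted in the prologue below.
 */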
48
49static bool is_imm8(int value)
50{
51	return value <= 127 && value >= -128;
52}
53
54static bool is_simm32(s64 value)
55{
56	return value == (s64)(s32)value;
57}
58
59static bool is_uimm32(u64 value)
60{
61	return value == (u64)(u32)value;
62}
63
64/* mov dst, src */
65#define EMIT_mov(DST, SRC)								 \
66	do {										 \
67		if (DST != SRC)								 \
68			EMIT3(add_2mod(0x48, DST, SRC), 0x89, add_2reg(0xC0, DST, SRC)); \
69	} while (0)
70
71static int bpf_size_to_x86_bytes(int bpf_size)
72{
73	if (bpf_size == BPF_W)
74		return 4;
75	else if (bpf_size == BPF_H)
76		return 2;
77	else if (bpf_size == BPF_B)
78		return 1;
79	else if (bpf_size == BPF_DW)
		return 4; /* imm32: a 64-bit store still takes a sign-extended 32-bit immediate */
81	else
82		return 0;
83}
84
/*
 * List of x86 conditional jump opcodes (. + s8)
 * Add 0x10 (and an extra 0x0f) to generate far jumps (. + s32)
 */
89#define X86_JB  0x72
90#define X86_JAE 0x73
91#define X86_JE  0x74
92#define X86_JNE 0x75
93#define X86_JBE 0x76
94#define X86_JA  0x77
95#define X86_JL  0x7C
96#define X86_JGE 0x7D
97#define X86_JLE 0x7E
98#define X86_JG  0x7F
99
100/* Pick a register outside of BPF range for JIT internal work */
101#define AUX_REG (MAX_BPF_JIT_REG + 1)
102#define X86_REG_R9 (MAX_BPF_JIT_REG + 2)
103
/*
 * The following table maps BPF registers to x86-64 registers.
 *
 * x86-64 register R12 is unused, since, when used as a base address
 * register in load/store instructions, it always needs an extra byte
 * of encoding and is callee-saved.
 *
 * x86-64 register R9 is not used by BPF programs, but can be used by the
 * BPF trampoline. x86-64 register R10 is used for blinding (if enabled).
 */
114static const int reg2hex[] = {
115	[BPF_REG_0] = 0,  /* RAX */
116	[BPF_REG_1] = 7,  /* RDI */
117	[BPF_REG_2] = 6,  /* RSI */
118	[BPF_REG_3] = 2,  /* RDX */
119	[BPF_REG_4] = 1,  /* RCX */
120	[BPF_REG_5] = 0,  /* R8  */
121	[BPF_REG_6] = 3,  /* RBX callee saved */
122	[BPF_REG_7] = 5,  /* R13 callee saved */
123	[BPF_REG_8] = 6,  /* R14 callee saved */
124	[BPF_REG_9] = 7,  /* R15 callee saved */
125	[BPF_REG_FP] = 5, /* RBP readonly */
126	[BPF_REG_AX] = 2, /* R10 temp register */
127	[AUX_REG] = 3,    /* R11 temp register */
128	[X86_REG_R9] = 1, /* R9 register, 6th function argument */
129};
130
131static const int reg2pt_regs[] = {
132	[BPF_REG_0] = offsetof(struct pt_regs, ax),
133	[BPF_REG_1] = offsetof(struct pt_regs, di),
134	[BPF_REG_2] = offsetof(struct pt_regs, si),
135	[BPF_REG_3] = offsetof(struct pt_regs, dx),
136	[BPF_REG_4] = offsetof(struct pt_regs, cx),
137	[BPF_REG_5] = offsetof(struct pt_regs, r8),
138	[BPF_REG_6] = offsetof(struct pt_regs, bx),
139	[BPF_REG_7] = offsetof(struct pt_regs, r13),
140	[BPF_REG_8] = offsetof(struct pt_regs, r14),
141	[BPF_REG_9] = offsetof(struct pt_regs, r15),
142};
143
/*
 * is_ereg() == true if BPF register 'reg' maps to x86-64 r8..r15,
 * which need an extra byte of encoding (a REX prefix).
 * rax, rcx, ..., rbp have a simpler encoding.
 */
149static bool is_ereg(u32 reg)
150{
151	return (1 << reg) & (BIT(BPF_REG_5) |
152			     BIT(AUX_REG) |
153			     BIT(BPF_REG_7) |
154			     BIT(BPF_REG_8) |
155			     BIT(BPF_REG_9) |
156			     BIT(X86_REG_R9) |
157			     BIT(BPF_REG_AX));
158}
159
/*
 * is_ereg_8l() == true if BPF register 'reg' is mapped to access the
 * x86-64 lower 8-bit registers dil,sil,bpl,spl,r8b..r15b, which need an
 * extra byte of encoding. al, cl, dl, bl have a simpler encoding.
 */
165static bool is_ereg_8l(u32 reg)
166{
167	return is_ereg(reg) ||
168	    (1 << reg) & (BIT(BPF_REG_1) |
169			  BIT(BPF_REG_2) |
170			  BIT(BPF_REG_FP));
171}
172
173static bool is_axreg(u32 reg)
174{
175	return reg == BPF_REG_0;
176}
177
178/* Add modifiers if 'reg' maps to x86-64 registers R8..R15 */
179static u8 add_1mod(u8 byte, u32 reg)
180{
181	if (is_ereg(reg))
182		byte |= 1;
183	return byte;
184}
185
186static u8 add_2mod(u8 byte, u32 r1, u32 r2)
187{
188	if (is_ereg(r1))
189		byte |= 1;
190	if (is_ereg(r2))
191		byte |= 4;
192	return byte;
193}
194
195/* Encode 'dst_reg' register into x86-64 opcode 'byte' */
196static u8 add_1reg(u8 byte, u32 dst_reg)
197{
198	return byte + reg2hex[dst_reg];
199}
200
201/* Encode 'dst_reg' and 'src_reg' registers into x86-64 opcode 'byte' */
202static u8 add_2reg(u8 byte, u32 dst_reg, u32 src_reg)
203{
204	return byte + reg2hex[dst_reg] + (reg2hex[src_reg] << 3);
205}
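
/*
 * Worked example (for illustration): moving BPF_REG_2 (RSI) into
 * BPF_REG_1 (RDI) via EMIT_mov() expands to
 *   EMIT3(add_2mod(0x48, dst, src), 0x89, add_2reg(0xC0, dst, src))
 * Neither register is an ereg, so the REX prefix stays 0x48 (REX.W) and
 * the ModRM byte is 0xC0 + 7 + (6 << 3) = 0xF7, giving 48 89 F7,
 * i.e. 'mov rdi, rsi'.
 */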
206
207static void jit_fill_hole(void *area, unsigned int size)
208{
209	/* Fill whole space with INT3 instructions */
210	memset(area, 0xcc, size);
211}
212
213struct jit_context {
214	int cleanup_addr; /* Epilogue code offset */
215
216	/*
217	 * Program specific offsets of labels in the code; these rely on the
218	 * JIT doing at least 2 passes, recording the position on the first
219	 * pass, only to generate the correct offset on the second pass.
220	 */
221	int tail_call_direct_label;
222	int tail_call_indirect_label;
223};
224
225/* Maximum number of bytes emitted while JITing one eBPF insn */
226#define BPF_MAX_INSN_SIZE	128
227#define BPF_INSN_SAFETY		64
228
229/* Number of bytes emit_patch() needs to generate instructions */
230#define X86_PATCH_SIZE		5
231/* Number of bytes that will be skipped on tailcall */
232#define X86_TAIL_CALL_OFFSET	11
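/*
 * The 11 bytes skipped on a tail call are the start of the prologue
 * emitted by emit_prologue() below: the 5-byte patchable nop
 * (X86_PATCH_SIZE) + the 2-byte 'xor eax, eax' (or nop2) + the 1-byte
 * 'push rbp' + the 3-byte 'mov rbp, rsp'. A tail call therefore reuses
 * the current frame and does not re-zero the tail-call counter held
 * in rax.
 */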
233
234static void push_callee_regs(u8 **pprog, bool *callee_regs_used)
235{
236	u8 *prog = *pprog;
237	int cnt = 0;
238
239	if (callee_regs_used[0])
240		EMIT1(0x53);         /* push rbx */
241	if (callee_regs_used[1])
242		EMIT2(0x41, 0x55);   /* push r13 */
243	if (callee_regs_used[2])
244		EMIT2(0x41, 0x56);   /* push r14 */
245	if (callee_regs_used[3])
246		EMIT2(0x41, 0x57);   /* push r15 */
247	*pprog = prog;
248}
249
250static void pop_callee_regs(u8 **pprog, bool *callee_regs_used)
251{
252	u8 *prog = *pprog;
253	int cnt = 0;
254
255	if (callee_regs_used[3])
256		EMIT2(0x41, 0x5F);   /* pop r15 */
257	if (callee_regs_used[2])
258		EMIT2(0x41, 0x5E);   /* pop r14 */
259	if (callee_regs_used[1])
260		EMIT2(0x41, 0x5D);   /* pop r13 */
261	if (callee_regs_used[0])
262		EMIT1(0x5B);         /* pop rbx */
263	*pprog = prog;
264}
265
266/*
267 * Emit x86-64 prologue code for BPF program.
268 * bpf_tail_call helper will skip the first X86_TAIL_CALL_OFFSET bytes
269 * while jumping to another program
270 */
271static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
272			  bool tail_call_reachable, bool is_subprog)
273{
274	u8 *prog = *pprog;
275	int cnt = X86_PATCH_SIZE;
276
277	/* BPF trampoline can be made to work without these nops,
278	 * but let's waste 5 bytes for now and optimize later
279	 */
280	memcpy(prog, ideal_nops[NOP_ATOMIC5], cnt);
281	prog += cnt;
282	if (!ebpf_from_cbpf) {
283		if (tail_call_reachable && !is_subprog)
284			EMIT2(0x31, 0xC0); /* xor eax, eax */
285		else
286			EMIT2(0x66, 0x90); /* nop2 */
287	}
288	EMIT1(0x55);             /* push rbp */
289	EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
290	/* sub rsp, rounded_stack_depth */
291	if (stack_depth)
292		EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
293	if (tail_call_reachable)
294		EMIT1(0x50);         /* push rax */
295	*pprog = prog;
296}
297
298static int emit_patch(u8 **pprog, void *func, void *ip, u8 opcode)
299{
300	u8 *prog = *pprog;
301	int cnt = 0;
302	s64 offset;
303
304	offset = func - (ip + X86_PATCH_SIZE);
305	if (!is_simm32(offset)) {
306		pr_err("Target call %p is out of range\n", func);
307		return -ERANGE;
308	}
309	EMIT1_off32(opcode, offset);
310	*pprog = prog;
311	return 0;
312}
313
314static int emit_call(u8 **pprog, void *func, void *ip)
315{
316	return emit_patch(pprog, func, ip, 0xE8);
317}
318
319static int emit_jump(u8 **pprog, void *func, void *ip)
320{
321	return emit_patch(pprog, func, ip, 0xE9);
322}
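
/*
 * Both helpers emit a single opcode byte (0xE8 = call rel32,
 * 0xE9 = jmp rel32) followed by a 32-bit displacement relative to the
 * next instruction, which is where the 5-byte X86_PATCH_SIZE comes from.
 */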
323
324static int __bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
325				void *old_addr, void *new_addr,
326				const bool text_live)
327{
328	const u8 *nop_insn = ideal_nops[NOP_ATOMIC5];
329	u8 old_insn[X86_PATCH_SIZE];
330	u8 new_insn[X86_PATCH_SIZE];
331	u8 *prog;
332	int ret;
333
334	memcpy(old_insn, nop_insn, X86_PATCH_SIZE);
335	if (old_addr) {
336		prog = old_insn;
337		ret = t == BPF_MOD_CALL ?
338		      emit_call(&prog, old_addr, ip) :
339		      emit_jump(&prog, old_addr, ip);
340		if (ret)
341			return ret;
342	}
343
344	memcpy(new_insn, nop_insn, X86_PATCH_SIZE);
345	if (new_addr) {
346		prog = new_insn;
347		ret = t == BPF_MOD_CALL ?
348		      emit_call(&prog, new_addr, ip) :
349		      emit_jump(&prog, new_addr, ip);
350		if (ret)
351			return ret;
352	}
353
354	ret = -EBUSY;
355	mutex_lock(&text_mutex);
356	if (memcmp(ip, old_insn, X86_PATCH_SIZE))
357		goto out;
358	ret = 1;
359	if (memcmp(ip, new_insn, X86_PATCH_SIZE)) {
360		if (text_live)
361			text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL);
362		else
363			memcpy(ip, new_insn, X86_PATCH_SIZE);
364		ret = 0;
365	}
366out:
367	mutex_unlock(&text_mutex);
368	return ret;
369}
370
371int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
372		       void *old_addr, void *new_addr)
373{
374	if (!is_kernel_text((long)ip) &&
375	    !is_bpf_text_address((long)ip))
376		/* BPF poking in modules is not supported */
377		return -EINVAL;
378
379	return __bpf_arch_text_poke(ip, t, old_addr, new_addr, true);
380}
381
382#define EMIT_LFENCE()	EMIT3(0x0F, 0xAE, 0xE8)
383
384static void emit_indirect_jump(u8 **pprog, int reg, u8 *ip)
385{
386	u8 *prog = *pprog;
387	int cnt = 0;
388
389#ifdef CONFIG_RETPOLINE
390	if (cpu_feature_enabled(X86_FEATURE_RETPOLINE_LFENCE)) {
391		EMIT_LFENCE();
392		EMIT2(0xFF, 0xE0 + reg);
393	} else if (cpu_feature_enabled(X86_FEATURE_RETPOLINE)) {
394		emit_jump(&prog, &__x86_indirect_thunk_array[reg], ip);
395	} else
396#endif
397	EMIT2(0xFF, 0xE0 + reg);
398
399	*pprog = prog;
400}
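
/*
 * Note: 0xFF /4 with a ModRM byte of 0xE0 + reg encodes 'jmp reg' for
 * the raw x86 register number, e.g. reg == 1 yields FF E1 ('jmp rcx'),
 * as used by the tail-call paths below, unless a retpoline thunk is
 * emitted instead.
 */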
401
402static void emit_return(u8 **pprog, u8 *ip)
403{
404	u8 *prog = *pprog;
405	int cnt = 0;
406
407	if (cpu_feature_enabled(X86_FEATURE_RETHUNK)) {
408		emit_jump(&prog, &__x86_return_thunk, ip);
409	} else {
410		EMIT1(0xC3);		/* ret */
411		if (IS_ENABLED(CONFIG_SLS))
412			EMIT1(0xCC);	/* int3 */
413	}
414
415	*pprog = prog;
416}
417
418/*
419 * Generate the following code:
420 *
421 * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ...
422 *   if (index >= array->map.max_entries)
423 *     goto out;
 *   if (tail_call_cnt > MAX_TAIL_CALL_CNT)
 *     goto out;
 *   tail_call_cnt++;
426 *   prog = array->ptrs[index];
427 *   if (prog == NULL)
428 *     goto out;
429 *   goto *(prog->bpf_func + prologue_size);
430 * out:
431 */
432static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
433					u32 stack_depth, u8 *ip,
434					struct jit_context *ctx)
435{
436	int tcc_off = -4 - round_up(stack_depth, 8);
437	u8 *prog = *pprog, *start = *pprog;
438	int cnt = 0, offset;
439
440	/*
441	 * rdi - pointer to ctx
442	 * rsi - pointer to bpf_array
443	 * rdx - index in bpf_array
444	 */
445
446	/*
447	 * if (index >= array->map.max_entries)
448	 *	goto out;
449	 */
450	EMIT2(0x89, 0xD2);                        /* mov edx, edx */
451	EMIT3(0x39, 0x56,                         /* cmp dword ptr [rsi + 16], edx */
452	      offsetof(struct bpf_array, map.max_entries));
453
454	offset = ctx->tail_call_indirect_label - (prog + 2 - start);
455	EMIT2(X86_JBE, offset);                   /* jbe out */
456
457	/*
458	 * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
459	 *	goto out;
460	 */
461	EMIT2_off32(0x8B, 0x85, tcc_off);         /* mov eax, dword ptr [rbp - tcc_off] */
462	EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
463
464	offset = ctx->tail_call_indirect_label - (prog + 2 - start);
465	EMIT2(X86_JA, offset);                    /* ja out */
466	EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
467	EMIT2_off32(0x89, 0x85, tcc_off);         /* mov dword ptr [rbp - tcc_off], eax */
468
469	/* prog = array->ptrs[index]; */
470	EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6,       /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */
471		    offsetof(struct bpf_array, ptrs));
472
473	/*
474	 * if (prog == NULL)
475	 *	goto out;
476	 */
477	EMIT3(0x48, 0x85, 0xC9);                  /* test rcx,rcx */
478
479	offset = ctx->tail_call_indirect_label - (prog + 2 - start);
480	EMIT2(X86_JE, offset);                    /* je out */
481
482	pop_callee_regs(&prog, callee_regs_used);
483
484	EMIT1(0x58);                              /* pop rax */
485	if (stack_depth)
486		EMIT3_off32(0x48, 0x81, 0xC4,     /* add rsp, sd */
487			    round_up(stack_depth, 8));
488
489	/* goto *(prog->bpf_func + X86_TAIL_CALL_OFFSET); */
490	EMIT4(0x48, 0x8B, 0x49,                   /* mov rcx, qword ptr [rcx + 32] */
491	      offsetof(struct bpf_prog, bpf_func));
492	EMIT4(0x48, 0x83, 0xC1,                   /* add rcx, X86_TAIL_CALL_OFFSET */
493	      X86_TAIL_CALL_OFFSET);
494	/*
495	 * Now we're ready to jump into next BPF program
496	 * rdi == ctx (1st arg)
497	 * rcx == prog->bpf_func + X86_TAIL_CALL_OFFSET
498	 */
499	emit_indirect_jump(&prog, 1 /* rcx */, ip + (prog - start));
500
501	/* out: */
502	ctx->tail_call_indirect_label = prog - start;
503	*pprog = prog;
504}
505
506static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke,
507				      u8 **pprog, u8 *ip,
508				      bool *callee_regs_used, u32 stack_depth,
509				      struct jit_context *ctx)
510{
511	int tcc_off = -4 - round_up(stack_depth, 8);
512	u8 *prog = *pprog, *start = *pprog;
513	int cnt = 0, offset;
514
515	/*
516	 * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
517	 *	goto out;
518	 */
519	EMIT2_off32(0x8B, 0x85, tcc_off);             /* mov eax, dword ptr [rbp - tcc_off] */
520	EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);         /* cmp eax, MAX_TAIL_CALL_CNT */
521
522	offset = ctx->tail_call_direct_label - (prog + 2 - start);
523	EMIT2(X86_JA, offset);                        /* ja out */
524	EMIT3(0x83, 0xC0, 0x01);                      /* add eax, 1 */
525	EMIT2_off32(0x89, 0x85, tcc_off);             /* mov dword ptr [rbp - tcc_off], eax */
526
527	poke->tailcall_bypass = ip + (prog - start);
528	poke->adj_off = X86_TAIL_CALL_OFFSET;
529	poke->tailcall_target = ip + ctx->tail_call_direct_label - X86_PATCH_SIZE;
530	poke->bypass_addr = (u8 *)poke->tailcall_target + X86_PATCH_SIZE;
531
532	emit_jump(&prog, (u8 *)poke->tailcall_target + X86_PATCH_SIZE,
533		  poke->tailcall_bypass);
534
535	pop_callee_regs(&prog, callee_regs_used);
536	EMIT1(0x58);                                  /* pop rax */
537	if (stack_depth)
538		EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8));
539
540	memcpy(prog, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE);
541	prog += X86_PATCH_SIZE;
542
543	/* out: */
544	ctx->tail_call_direct_label = prog - start;
545
546	*pprog = prog;
547}
548
549static void bpf_tail_call_direct_fixup(struct bpf_prog *prog)
550{
551	struct bpf_jit_poke_descriptor *poke;
552	struct bpf_array *array;
553	struct bpf_prog *target;
554	int i, ret;
555
556	for (i = 0; i < prog->aux->size_poke_tab; i++) {
557		poke = &prog->aux->poke_tab[i];
558		if (poke->aux && poke->aux != prog->aux)
559			continue;
560
561		WARN_ON_ONCE(READ_ONCE(poke->tailcall_target_stable));
562
563		if (poke->reason != BPF_POKE_REASON_TAIL_CALL)
564			continue;
565
566		array = container_of(poke->tail_call.map, struct bpf_array, map);
567		mutex_lock(&array->aux->poke_mutex);
568		target = array->ptrs[poke->tail_call.key];
569		if (target) {
			/* A plain memcpy is used when the image is not live yet
			 * and not yet locked as read-only. Once the poke
			 * location is active (poke->tailcall_target_stable),
			 * any parallel bpf_arch_text_poke() might still occur
			 * on the read-write image until we finally lock it
			 * as read-only. Both modifications of the given image
			 * are done under text_mutex to avoid interference.
			 */
579			ret = __bpf_arch_text_poke(poke->tailcall_target,
580						   BPF_MOD_JUMP, NULL,
581						   (u8 *)target->bpf_func +
582						   poke->adj_off, false);
583			BUG_ON(ret < 0);
584			ret = __bpf_arch_text_poke(poke->tailcall_bypass,
585						   BPF_MOD_JUMP,
586						   (u8 *)poke->tailcall_target +
587						   X86_PATCH_SIZE, NULL, false);
588			BUG_ON(ret < 0);
589		}
590		WRITE_ONCE(poke->tailcall_target_stable, true);
591		mutex_unlock(&array->aux->poke_mutex);
592	}
593}
594
595static void emit_mov_imm32(u8 **pprog, bool sign_propagate,
596			   u32 dst_reg, const u32 imm32)
597{
598	u8 *prog = *pprog;
599	u8 b1, b2, b3;
600	int cnt = 0;
601
602	/*
603	 * Optimization: if imm32 is positive, use 'mov %eax, imm32'
604	 * (which zero-extends imm32) to save 2 bytes.
605	 */
606	if (sign_propagate && (s32)imm32 < 0) {
607		/* 'mov %rax, imm32' sign extends imm32 */
608		b1 = add_1mod(0x48, dst_reg);
609		b2 = 0xC7;
610		b3 = 0xC0;
611		EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
612		goto done;
613	}
614
615	/*
616	 * Optimization: if imm32 is zero, use 'xor %eax, %eax'
617	 * to save 3 bytes.
618	 */
619	if (imm32 == 0) {
620		if (is_ereg(dst_reg))
621			EMIT1(add_2mod(0x40, dst_reg, dst_reg));
622		b2 = 0x31; /* xor */
623		b3 = 0xC0;
624		EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
625		goto done;
626	}
627
628	/* mov %eax, imm32 */
629	if (is_ereg(dst_reg))
630		EMIT1(add_1mod(0x40, dst_reg));
631	EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
632done:
633	*pprog = prog;
634}
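
/*
 * Examples of the three encodings above (illustrative only):
 *   emit_mov_imm32(&prog, true,  BPF_REG_0, -1) -> 48 C7 C0 FF FF FF FF
 *                                                  (mov rax, -1, sign-extended)
 *   emit_mov_imm32(&prog, false, BPF_REG_0,  0) -> 31 C0 (xor eax, eax)
 *   emit_mov_imm32(&prog, false, BPF_REG_0,  5) -> B8 05 00 00 00 (mov eax, 5)
 */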
635
636static void emit_mov_imm64(u8 **pprog, u32 dst_reg,
637			   const u32 imm32_hi, const u32 imm32_lo)
638{
639	u8 *prog = *pprog;
640	int cnt = 0;
641
642	if (is_uimm32(((u64)imm32_hi << 32) | (u32)imm32_lo)) {
		/*
		 * For emitting a plain u32, where the sign bit must not be
		 * propagated, LLVM tends to load imm64 over mov32
		 * directly, so save a couple of bytes by just doing
		 * 'mov %eax, imm32' instead.
		 */
649		emit_mov_imm32(&prog, false, dst_reg, imm32_lo);
650	} else {
651		/* movabsq %rax, imm64 */
652		EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
653		EMIT(imm32_lo, 4);
654		EMIT(imm32_hi, 4);
655	}
656
657	*pprog = prog;
658}
659
660static void emit_mov_reg(u8 **pprog, bool is64, u32 dst_reg, u32 src_reg)
661{
662	u8 *prog = *pprog;
663	int cnt = 0;
664
665	if (is64) {
666		/* mov dst, src */
667		EMIT_mov(dst_reg, src_reg);
668	} else {
669		/* mov32 dst, src */
670		if (is_ereg(dst_reg) || is_ereg(src_reg))
671			EMIT1(add_2mod(0x40, dst_reg, src_reg));
672		EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
673	}
674
675	*pprog = prog;
676}
677
678/* LDX: dst_reg = *(u8*)(src_reg + off) */
679static void emit_ldx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
680{
681	u8 *prog = *pprog;
682	int cnt = 0;
683
684	switch (size) {
685	case BPF_B:
686		/* Emit 'movzx rax, byte ptr [rax + off]' */
687		EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB6);
688		break;
689	case BPF_H:
690		/* Emit 'movzx rax, word ptr [rax + off]' */
691		EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB7);
692		break;
693	case BPF_W:
694		/* Emit 'mov eax, dword ptr [rax+0x14]' */
695		if (is_ereg(dst_reg) || is_ereg(src_reg))
696			EMIT2(add_2mod(0x40, src_reg, dst_reg), 0x8B);
697		else
698			EMIT1(0x8B);
699		break;
700	case BPF_DW:
701		/* Emit 'mov rax, qword ptr [rax+0x14]' */
702		EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x8B);
703		break;
704	}
	/*
	 * If insn->off == 0 we could save one extra byte, but the
	 * special case of x86 R13, which always needs an offset,
	 * is not worth the hassle.
	 */
710	if (is_imm8(off))
711		EMIT2(add_2reg(0x40, src_reg, dst_reg), off);
712	else
713		EMIT1_off32(add_2reg(0x80, src_reg, dst_reg), off);
714	*pprog = prog;
715}
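
/*
 * Example (illustrative): emit_ldx(&prog, BPF_W, BPF_REG_0, BPF_REG_1, 20)
 * produces 8B 47 14, i.e. 'mov eax, dword ptr [rdi + 0x14]': opcode 0x8B
 * with no REX prefix (neither register is an ereg), then a ModRM byte of
 * 0x40 + 7 + (0 << 3) = 0x47 with the 8-bit displacement 0x14.
 */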
716
717/* STX: *(u8*)(dst_reg + off) = src_reg */
718static void emit_stx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
719{
720	u8 *prog = *pprog;
721	int cnt = 0;
722
723	switch (size) {
724	case BPF_B:
725		/* Emit 'mov byte ptr [rax + off], al' */
726		if (is_ereg(dst_reg) || is_ereg_8l(src_reg))
727			/* Add extra byte for eregs or SIL,DIL,BPL in src_reg */
728			EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x88);
729		else
730			EMIT1(0x88);
731		break;
732	case BPF_H:
733		if (is_ereg(dst_reg) || is_ereg(src_reg))
734			EMIT3(0x66, add_2mod(0x40, dst_reg, src_reg), 0x89);
735		else
736			EMIT2(0x66, 0x89);
737		break;
738	case BPF_W:
739		if (is_ereg(dst_reg) || is_ereg(src_reg))
740			EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x89);
741		else
742			EMIT1(0x89);
743		break;
744	case BPF_DW:
745		EMIT2(add_2mod(0x48, dst_reg, src_reg), 0x89);
746		break;
747	}
748	if (is_imm8(off))
749		EMIT2(add_2reg(0x40, dst_reg, src_reg), off);
750	else
751		EMIT1_off32(add_2reg(0x80, dst_reg, src_reg), off);
752	*pprog = prog;
753}
754
755static bool ex_handler_bpf(const struct exception_table_entry *x,
756			   struct pt_regs *regs, int trapnr,
757			   unsigned long error_code, unsigned long fault_addr)
758{
759	u32 reg = x->fixup >> 8;
760
761	/* jump over faulting load and clear dest register */
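	/*
	 * The fixup word is packed by the JIT (see the BPF_PROBE_MEM handling
	 * in do_jit()): the low 8 bits hold the length of the faulting load
	 * instruction, the upper bits hold the pt_regs offset of the
	 * destination register taken from reg2pt_regs[].
	 */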
762	*(unsigned long *)((void *)regs + reg) = 0;
763	regs->ip += x->fixup & 0xff;
764	return true;
765}
766
767static void detect_reg_usage(struct bpf_insn *insn, int insn_cnt,
768			     bool *regs_used, bool *tail_call_seen)
769{
770	int i;
771
772	for (i = 1; i <= insn_cnt; i++, insn++) {
773		if (insn->code == (BPF_JMP | BPF_TAIL_CALL))
774			*tail_call_seen = true;
775		if (insn->dst_reg == BPF_REG_6 || insn->src_reg == BPF_REG_6)
776			regs_used[0] = true;
777		if (insn->dst_reg == BPF_REG_7 || insn->src_reg == BPF_REG_7)
778			regs_used[1] = true;
779		if (insn->dst_reg == BPF_REG_8 || insn->src_reg == BPF_REG_8)
780			regs_used[2] = true;
781		if (insn->dst_reg == BPF_REG_9 || insn->src_reg == BPF_REG_9)
782			regs_used[3] = true;
783	}
784}
785
786static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
787		  int oldproglen, struct jit_context *ctx)
788{
789	bool tail_call_reachable = bpf_prog->aux->tail_call_reachable;
790	struct bpf_insn *insn = bpf_prog->insnsi;
791	bool callee_regs_used[4] = {};
792	int insn_cnt = bpf_prog->len;
793	bool tail_call_seen = false;
794	bool seen_exit = false;
795	u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
796	int i, cnt = 0, excnt = 0;
797	int proglen = 0;
798	u8 *prog = temp;
799
800	detect_reg_usage(insn, insn_cnt, callee_regs_used,
801			 &tail_call_seen);
802
803	/* tail call's presence in current prog implies it is reachable */
804	tail_call_reachable |= tail_call_seen;
805
806	emit_prologue(&prog, bpf_prog->aux->stack_depth,
807		      bpf_prog_was_classic(bpf_prog), tail_call_reachable,
808		      bpf_prog->aux->func_idx != 0);
809	push_callee_regs(&prog, callee_regs_used);
810	addrs[0] = prog - temp;
811
812	for (i = 1; i <= insn_cnt; i++, insn++) {
813		const s32 imm32 = insn->imm;
814		u32 dst_reg = insn->dst_reg;
815		u32 src_reg = insn->src_reg;
816		u8 b2 = 0, b3 = 0;
817		s64 jmp_offset;
818		u8 jmp_cond;
819		int ilen;
820		u8 *func;
821
822		switch (insn->code) {
823			/* ALU */
824		case BPF_ALU | BPF_ADD | BPF_X:
825		case BPF_ALU | BPF_SUB | BPF_X:
826		case BPF_ALU | BPF_AND | BPF_X:
827		case BPF_ALU | BPF_OR | BPF_X:
828		case BPF_ALU | BPF_XOR | BPF_X:
829		case BPF_ALU64 | BPF_ADD | BPF_X:
830		case BPF_ALU64 | BPF_SUB | BPF_X:
831		case BPF_ALU64 | BPF_AND | BPF_X:
832		case BPF_ALU64 | BPF_OR | BPF_X:
833		case BPF_ALU64 | BPF_XOR | BPF_X:
834			switch (BPF_OP(insn->code)) {
835			case BPF_ADD: b2 = 0x01; break;
836			case BPF_SUB: b2 = 0x29; break;
837			case BPF_AND: b2 = 0x21; break;
838			case BPF_OR: b2 = 0x09; break;
839			case BPF_XOR: b2 = 0x31; break;
840			}
841			if (BPF_CLASS(insn->code) == BPF_ALU64)
842				EMIT1(add_2mod(0x48, dst_reg, src_reg));
843			else if (is_ereg(dst_reg) || is_ereg(src_reg))
844				EMIT1(add_2mod(0x40, dst_reg, src_reg));
845			EMIT2(b2, add_2reg(0xC0, dst_reg, src_reg));
846			break;
847
848		case BPF_ALU64 | BPF_MOV | BPF_X:
849		case BPF_ALU | BPF_MOV | BPF_X:
850			emit_mov_reg(&prog,
851				     BPF_CLASS(insn->code) == BPF_ALU64,
852				     dst_reg, src_reg);
853			break;
854
855			/* neg dst */
856		case BPF_ALU | BPF_NEG:
857		case BPF_ALU64 | BPF_NEG:
858			if (BPF_CLASS(insn->code) == BPF_ALU64)
859				EMIT1(add_1mod(0x48, dst_reg));
860			else if (is_ereg(dst_reg))
861				EMIT1(add_1mod(0x40, dst_reg));
862			EMIT2(0xF7, add_1reg(0xD8, dst_reg));
863			break;
864
865		case BPF_ALU | BPF_ADD | BPF_K:
866		case BPF_ALU | BPF_SUB | BPF_K:
867		case BPF_ALU | BPF_AND | BPF_K:
868		case BPF_ALU | BPF_OR | BPF_K:
869		case BPF_ALU | BPF_XOR | BPF_K:
870		case BPF_ALU64 | BPF_ADD | BPF_K:
871		case BPF_ALU64 | BPF_SUB | BPF_K:
872		case BPF_ALU64 | BPF_AND | BPF_K:
873		case BPF_ALU64 | BPF_OR | BPF_K:
874		case BPF_ALU64 | BPF_XOR | BPF_K:
875			if (BPF_CLASS(insn->code) == BPF_ALU64)
876				EMIT1(add_1mod(0x48, dst_reg));
877			else if (is_ereg(dst_reg))
878				EMIT1(add_1mod(0x40, dst_reg));
879
			/*
			 * b3 holds the 'normal' opcode; b2 is the short form,
			 * only valid when dst is eax/rax.
			 */
884			switch (BPF_OP(insn->code)) {
885			case BPF_ADD:
886				b3 = 0xC0;
887				b2 = 0x05;
888				break;
889			case BPF_SUB:
890				b3 = 0xE8;
891				b2 = 0x2D;
892				break;
893			case BPF_AND:
894				b3 = 0xE0;
895				b2 = 0x25;
896				break;
897			case BPF_OR:
898				b3 = 0xC8;
899				b2 = 0x0D;
900				break;
901			case BPF_XOR:
902				b3 = 0xF0;
903				b2 = 0x35;
904				break;
905			}
906
907			if (is_imm8(imm32))
908				EMIT3(0x83, add_1reg(b3, dst_reg), imm32);
909			else if (is_axreg(dst_reg))
910				EMIT1_off32(b2, imm32);
911			else
912				EMIT2_off32(0x81, add_1reg(b3, dst_reg), imm32);
913			break;
914
915		case BPF_ALU64 | BPF_MOV | BPF_K:
916		case BPF_ALU | BPF_MOV | BPF_K:
917			emit_mov_imm32(&prog, BPF_CLASS(insn->code) == BPF_ALU64,
918				       dst_reg, imm32);
919			break;
920
921		case BPF_LD | BPF_IMM | BPF_DW:
922			emit_mov_imm64(&prog, dst_reg, insn[1].imm, insn[0].imm);
923			insn++;
924			i++;
925			break;
926
927			/* dst %= src, dst /= src, dst %= imm32, dst /= imm32 */
928		case BPF_ALU | BPF_MOD | BPF_X:
929		case BPF_ALU | BPF_DIV | BPF_X:
930		case BPF_ALU | BPF_MOD | BPF_K:
931		case BPF_ALU | BPF_DIV | BPF_K:
932		case BPF_ALU64 | BPF_MOD | BPF_X:
933		case BPF_ALU64 | BPF_DIV | BPF_X:
934		case BPF_ALU64 | BPF_MOD | BPF_K:
935		case BPF_ALU64 | BPF_DIV | BPF_K:
936			EMIT1(0x50); /* push rax */
937			EMIT1(0x52); /* push rdx */
938
939			if (BPF_SRC(insn->code) == BPF_X)
940				/* mov r11, src_reg */
941				EMIT_mov(AUX_REG, src_reg);
942			else
943				/* mov r11, imm32 */
944				EMIT3_off32(0x49, 0xC7, 0xC3, imm32);
945
946			/* mov rax, dst_reg */
947			EMIT_mov(BPF_REG_0, dst_reg);
948
949			/*
950			 * xor edx, edx
951			 * equivalent to 'xor rdx, rdx', but one byte less
952			 */
953			EMIT2(0x31, 0xd2);
954
955			if (BPF_CLASS(insn->code) == BPF_ALU64)
956				/* div r11 */
957				EMIT3(0x49, 0xF7, 0xF3);
958			else
959				/* div r11d */
960				EMIT3(0x41, 0xF7, 0xF3);
961
962			if (BPF_OP(insn->code) == BPF_MOD)
963				/* mov r11, rdx */
964				EMIT3(0x49, 0x89, 0xD3);
965			else
966				/* mov r11, rax */
967				EMIT3(0x49, 0x89, 0xC3);
968
969			EMIT1(0x5A); /* pop rdx */
970			EMIT1(0x58); /* pop rax */
971
972			/* mov dst_reg, r11 */
973			EMIT_mov(dst_reg, AUX_REG);
974			break;
975
976		case BPF_ALU | BPF_MUL | BPF_K:
977		case BPF_ALU | BPF_MUL | BPF_X:
978		case BPF_ALU64 | BPF_MUL | BPF_K:
979		case BPF_ALU64 | BPF_MUL | BPF_X:
980		{
981			bool is64 = BPF_CLASS(insn->code) == BPF_ALU64;
982
983			if (dst_reg != BPF_REG_0)
984				EMIT1(0x50); /* push rax */
985			if (dst_reg != BPF_REG_3)
986				EMIT1(0x52); /* push rdx */
987
988			/* mov r11, dst_reg */
989			EMIT_mov(AUX_REG, dst_reg);
990
991			if (BPF_SRC(insn->code) == BPF_X)
992				emit_mov_reg(&prog, is64, BPF_REG_0, src_reg);
993			else
994				emit_mov_imm32(&prog, is64, BPF_REG_0, imm32);
995
996			if (is64)
997				EMIT1(add_1mod(0x48, AUX_REG));
998			else if (is_ereg(AUX_REG))
999				EMIT1(add_1mod(0x40, AUX_REG));
1000			/* mul(q) r11 */
1001			EMIT2(0xF7, add_1reg(0xE0, AUX_REG));
1002
1003			if (dst_reg != BPF_REG_3)
1004				EMIT1(0x5A); /* pop rdx */
1005			if (dst_reg != BPF_REG_0) {
1006				/* mov dst_reg, rax */
1007				EMIT_mov(dst_reg, BPF_REG_0);
1008				EMIT1(0x58); /* pop rax */
1009			}
1010			break;
1011		}
1012			/* Shifts */
1013		case BPF_ALU | BPF_LSH | BPF_K:
1014		case BPF_ALU | BPF_RSH | BPF_K:
1015		case BPF_ALU | BPF_ARSH | BPF_K:
1016		case BPF_ALU64 | BPF_LSH | BPF_K:
1017		case BPF_ALU64 | BPF_RSH | BPF_K:
1018		case BPF_ALU64 | BPF_ARSH | BPF_K:
1019			if (BPF_CLASS(insn->code) == BPF_ALU64)
1020				EMIT1(add_1mod(0x48, dst_reg));
1021			else if (is_ereg(dst_reg))
1022				EMIT1(add_1mod(0x40, dst_reg));
1023
1024			switch (BPF_OP(insn->code)) {
1025			case BPF_LSH: b3 = 0xE0; break;
1026			case BPF_RSH: b3 = 0xE8; break;
1027			case BPF_ARSH: b3 = 0xF8; break;
1028			}
1029
1030			if (imm32 == 1)
1031				EMIT2(0xD1, add_1reg(b3, dst_reg));
1032			else
1033				EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
1034			break;
1035
1036		case BPF_ALU | BPF_LSH | BPF_X:
1037		case BPF_ALU | BPF_RSH | BPF_X:
1038		case BPF_ALU | BPF_ARSH | BPF_X:
1039		case BPF_ALU64 | BPF_LSH | BPF_X:
1040		case BPF_ALU64 | BPF_RSH | BPF_X:
1041		case BPF_ALU64 | BPF_ARSH | BPF_X:
1042
1043			/* Check for bad case when dst_reg == rcx */
1044			if (dst_reg == BPF_REG_4) {
1045				/* mov r11, dst_reg */
1046				EMIT_mov(AUX_REG, dst_reg);
1047				dst_reg = AUX_REG;
1048			}
1049
1050			if (src_reg != BPF_REG_4) { /* common case */
1051				EMIT1(0x51); /* push rcx */
1052
1053				/* mov rcx, src_reg */
1054				EMIT_mov(BPF_REG_4, src_reg);
1055			}
1056
1057			/* shl %rax, %cl | shr %rax, %cl | sar %rax, %cl */
1058			if (BPF_CLASS(insn->code) == BPF_ALU64)
1059				EMIT1(add_1mod(0x48, dst_reg));
1060			else if (is_ereg(dst_reg))
1061				EMIT1(add_1mod(0x40, dst_reg));
1062
1063			switch (BPF_OP(insn->code)) {
1064			case BPF_LSH: b3 = 0xE0; break;
1065			case BPF_RSH: b3 = 0xE8; break;
1066			case BPF_ARSH: b3 = 0xF8; break;
1067			}
1068			EMIT2(0xD3, add_1reg(b3, dst_reg));
1069
1070			if (src_reg != BPF_REG_4)
1071				EMIT1(0x59); /* pop rcx */
1072
1073			if (insn->dst_reg == BPF_REG_4)
1074				/* mov dst_reg, r11 */
1075				EMIT_mov(insn->dst_reg, AUX_REG);
1076			break;
1077
1078		case BPF_ALU | BPF_END | BPF_FROM_BE:
1079			switch (imm32) {
1080			case 16:
1081				/* Emit 'ror %ax, 8' to swap lower 2 bytes */
1082				EMIT1(0x66);
1083				if (is_ereg(dst_reg))
1084					EMIT1(0x41);
1085				EMIT3(0xC1, add_1reg(0xC8, dst_reg), 8);
1086
1087				/* Emit 'movzwl eax, ax' */
1088				if (is_ereg(dst_reg))
1089					EMIT3(0x45, 0x0F, 0xB7);
1090				else
1091					EMIT2(0x0F, 0xB7);
1092				EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
1093				break;
1094			case 32:
1095				/* Emit 'bswap eax' to swap lower 4 bytes */
1096				if (is_ereg(dst_reg))
1097					EMIT2(0x41, 0x0F);
1098				else
1099					EMIT1(0x0F);
1100				EMIT1(add_1reg(0xC8, dst_reg));
1101				break;
1102			case 64:
1103				/* Emit 'bswap rax' to swap 8 bytes */
1104				EMIT3(add_1mod(0x48, dst_reg), 0x0F,
1105				      add_1reg(0xC8, dst_reg));
1106				break;
1107			}
1108			break;
1109
1110		case BPF_ALU | BPF_END | BPF_FROM_LE:
1111			switch (imm32) {
1112			case 16:
1113				/*
1114				 * Emit 'movzwl eax, ax' to zero extend 16-bit
1115				 * into 64 bit
1116				 */
1117				if (is_ereg(dst_reg))
1118					EMIT3(0x45, 0x0F, 0xB7);
1119				else
1120					EMIT2(0x0F, 0xB7);
1121				EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
1122				break;
1123			case 32:
1124				/* Emit 'mov eax, eax' to clear upper 32-bits */
1125				if (is_ereg(dst_reg))
1126					EMIT1(0x45);
1127				EMIT2(0x89, add_2reg(0xC0, dst_reg, dst_reg));
1128				break;
1129			case 64:
1130				/* nop */
1131				break;
1132			}
1133			break;
1134
1135			/* speculation barrier */
1136		case BPF_ST | BPF_NOSPEC:
1137			if (boot_cpu_has(X86_FEATURE_XMM2))
1138				EMIT_LFENCE();
1139			break;
1140
1141			/* ST: *(u8*)(dst_reg + off) = imm */
1142		case BPF_ST | BPF_MEM | BPF_B:
1143			if (is_ereg(dst_reg))
1144				EMIT2(0x41, 0xC6);
1145			else
1146				EMIT1(0xC6);
1147			goto st;
1148		case BPF_ST | BPF_MEM | BPF_H:
1149			if (is_ereg(dst_reg))
1150				EMIT3(0x66, 0x41, 0xC7);
1151			else
1152				EMIT2(0x66, 0xC7);
1153			goto st;
1154		case BPF_ST | BPF_MEM | BPF_W:
1155			if (is_ereg(dst_reg))
1156				EMIT2(0x41, 0xC7);
1157			else
1158				EMIT1(0xC7);
1159			goto st;
1160		case BPF_ST | BPF_MEM | BPF_DW:
1161			EMIT2(add_1mod(0x48, dst_reg), 0xC7);
1162
1163st:			if (is_imm8(insn->off))
1164				EMIT2(add_1reg(0x40, dst_reg), insn->off);
1165			else
1166				EMIT1_off32(add_1reg(0x80, dst_reg), insn->off);
1167
1168			EMIT(imm32, bpf_size_to_x86_bytes(BPF_SIZE(insn->code)));
1169			break;
1170
1171			/* STX: *(u8*)(dst_reg + off) = src_reg */
1172		case BPF_STX | BPF_MEM | BPF_B:
1173		case BPF_STX | BPF_MEM | BPF_H:
1174		case BPF_STX | BPF_MEM | BPF_W:
1175		case BPF_STX | BPF_MEM | BPF_DW:
1176			emit_stx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn->off);
1177			break;
1178
1179			/* LDX: dst_reg = *(u8*)(src_reg + off) */
1180		case BPF_LDX | BPF_MEM | BPF_B:
1181		case BPF_LDX | BPF_PROBE_MEM | BPF_B:
1182		case BPF_LDX | BPF_MEM | BPF_H:
1183		case BPF_LDX | BPF_PROBE_MEM | BPF_H:
1184		case BPF_LDX | BPF_MEM | BPF_W:
1185		case BPF_LDX | BPF_PROBE_MEM | BPF_W:
1186		case BPF_LDX | BPF_MEM | BPF_DW:
1187		case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
1188			emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn->off);
1189			if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
1190				struct exception_table_entry *ex;
1191				u8 *_insn = image + proglen;
1192				s64 delta;
1193
1194				if (!bpf_prog->aux->extable)
1195					break;
1196
1197				if (excnt >= bpf_prog->aux->num_exentries) {
1198					pr_err("ex gen bug\n");
1199					return -EFAULT;
1200				}
1201				ex = &bpf_prog->aux->extable[excnt++];
1202
1203				delta = _insn - (u8 *)&ex->insn;
1204				if (!is_simm32(delta)) {
1205					pr_err("extable->insn doesn't fit into 32-bit\n");
1206					return -EFAULT;
1207				}
1208				ex->insn = delta;
1209
1210				delta = (u8 *)ex_handler_bpf - (u8 *)&ex->handler;
1211				if (!is_simm32(delta)) {
1212					pr_err("extable->handler doesn't fit into 32-bit\n");
1213					return -EFAULT;
1214				}
1215				ex->handler = delta;
1216
1217				if (dst_reg > BPF_REG_9) {
1218					pr_err("verifier error\n");
1219					return -EFAULT;
1220				}
				/*
				 * Compute the size of the x86 insn and its destination
				 * x86 register. ex_handler_bpf() will use the lower 8 bits
				 * to adjust pt_regs->ip to jump over this x86 instruction
				 * and the upper bits to figure out which pt_regs field to
				 * zero out. End result: the 4-byte x86 insn
				 * "mov rbx, qword ptr [rax+0x14]" will be skipped and rbx
				 * will be zero-initialized.
				 */
1229				ex->fixup = (prog - temp) | (reg2pt_regs[dst_reg] << 8);
1230			}
1231			break;
1232
1233			/* STX XADD: lock *(u32*)(dst_reg + off) += src_reg */
1234		case BPF_STX | BPF_XADD | BPF_W:
1235			/* Emit 'lock add dword ptr [rax + off], eax' */
1236			if (is_ereg(dst_reg) || is_ereg(src_reg))
1237				EMIT3(0xF0, add_2mod(0x40, dst_reg, src_reg), 0x01);
1238			else
1239				EMIT2(0xF0, 0x01);
1240			goto xadd;
1241		case BPF_STX | BPF_XADD | BPF_DW:
1242			EMIT3(0xF0, add_2mod(0x48, dst_reg, src_reg), 0x01);
1243xadd:			if (is_imm8(insn->off))
1244				EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
1245			else
1246				EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
1247					    insn->off);
1248			break;
1249
1250			/* call */
1251		case BPF_JMP | BPF_CALL:
1252			func = (u8 *) __bpf_call_base + imm32;
1253			if (tail_call_reachable) {
1254				/* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
1255				EMIT3_off32(0x48, 0x8B, 0x85,
1256					    -round_up(bpf_prog->aux->stack_depth, 8) - 8);
1257				if (!imm32 || emit_call(&prog, func, image + addrs[i - 1] + 7))
1258					return -EINVAL;
1259			} else {
1260				if (!imm32 || emit_call(&prog, func, image + addrs[i - 1]))
1261					return -EINVAL;
1262			}
1263			break;
1264
1265		case BPF_JMP | BPF_TAIL_CALL:
1266			if (imm32)
1267				emit_bpf_tail_call_direct(&bpf_prog->aux->poke_tab[imm32 - 1],
1268							  &prog, image + addrs[i - 1],
1269							  callee_regs_used,
1270							  bpf_prog->aux->stack_depth,
1271							  ctx);
1272			else
1273				emit_bpf_tail_call_indirect(&prog,
1274							    callee_regs_used,
1275							    bpf_prog->aux->stack_depth,
1276							    image + addrs[i - 1],
1277							    ctx);
1278			break;
1279
1280			/* cond jump */
1281		case BPF_JMP | BPF_JEQ | BPF_X:
1282		case BPF_JMP | BPF_JNE | BPF_X:
1283		case BPF_JMP | BPF_JGT | BPF_X:
1284		case BPF_JMP | BPF_JLT | BPF_X:
1285		case BPF_JMP | BPF_JGE | BPF_X:
1286		case BPF_JMP | BPF_JLE | BPF_X:
1287		case BPF_JMP | BPF_JSGT | BPF_X:
1288		case BPF_JMP | BPF_JSLT | BPF_X:
1289		case BPF_JMP | BPF_JSGE | BPF_X:
1290		case BPF_JMP | BPF_JSLE | BPF_X:
1291		case BPF_JMP32 | BPF_JEQ | BPF_X:
1292		case BPF_JMP32 | BPF_JNE | BPF_X:
1293		case BPF_JMP32 | BPF_JGT | BPF_X:
1294		case BPF_JMP32 | BPF_JLT | BPF_X:
1295		case BPF_JMP32 | BPF_JGE | BPF_X:
1296		case BPF_JMP32 | BPF_JLE | BPF_X:
1297		case BPF_JMP32 | BPF_JSGT | BPF_X:
1298		case BPF_JMP32 | BPF_JSLT | BPF_X:
1299		case BPF_JMP32 | BPF_JSGE | BPF_X:
1300		case BPF_JMP32 | BPF_JSLE | BPF_X:
1301			/* cmp dst_reg, src_reg */
1302			if (BPF_CLASS(insn->code) == BPF_JMP)
1303				EMIT1(add_2mod(0x48, dst_reg, src_reg));
1304			else if (is_ereg(dst_reg) || is_ereg(src_reg))
1305				EMIT1(add_2mod(0x40, dst_reg, src_reg));
1306			EMIT2(0x39, add_2reg(0xC0, dst_reg, src_reg));
1307			goto emit_cond_jmp;
1308
1309		case BPF_JMP | BPF_JSET | BPF_X:
1310		case BPF_JMP32 | BPF_JSET | BPF_X:
1311			/* test dst_reg, src_reg */
1312			if (BPF_CLASS(insn->code) == BPF_JMP)
1313				EMIT1(add_2mod(0x48, dst_reg, src_reg));
1314			else if (is_ereg(dst_reg) || is_ereg(src_reg))
1315				EMIT1(add_2mod(0x40, dst_reg, src_reg));
1316			EMIT2(0x85, add_2reg(0xC0, dst_reg, src_reg));
1317			goto emit_cond_jmp;
1318
1319		case BPF_JMP | BPF_JSET | BPF_K:
1320		case BPF_JMP32 | BPF_JSET | BPF_K:
1321			/* test dst_reg, imm32 */
1322			if (BPF_CLASS(insn->code) == BPF_JMP)
1323				EMIT1(add_1mod(0x48, dst_reg));
1324			else if (is_ereg(dst_reg))
1325				EMIT1(add_1mod(0x40, dst_reg));
1326			EMIT2_off32(0xF7, add_1reg(0xC0, dst_reg), imm32);
1327			goto emit_cond_jmp;
1328
1329		case BPF_JMP | BPF_JEQ | BPF_K:
1330		case BPF_JMP | BPF_JNE | BPF_K:
1331		case BPF_JMP | BPF_JGT | BPF_K:
1332		case BPF_JMP | BPF_JLT | BPF_K:
1333		case BPF_JMP | BPF_JGE | BPF_K:
1334		case BPF_JMP | BPF_JLE | BPF_K:
1335		case BPF_JMP | BPF_JSGT | BPF_K:
1336		case BPF_JMP | BPF_JSLT | BPF_K:
1337		case BPF_JMP | BPF_JSGE | BPF_K:
1338		case BPF_JMP | BPF_JSLE | BPF_K:
1339		case BPF_JMP32 | BPF_JEQ | BPF_K:
1340		case BPF_JMP32 | BPF_JNE | BPF_K:
1341		case BPF_JMP32 | BPF_JGT | BPF_K:
1342		case BPF_JMP32 | BPF_JLT | BPF_K:
1343		case BPF_JMP32 | BPF_JGE | BPF_K:
1344		case BPF_JMP32 | BPF_JLE | BPF_K:
1345		case BPF_JMP32 | BPF_JSGT | BPF_K:
1346		case BPF_JMP32 | BPF_JSLT | BPF_K:
1347		case BPF_JMP32 | BPF_JSGE | BPF_K:
1348		case BPF_JMP32 | BPF_JSLE | BPF_K:
1349			/* test dst_reg, dst_reg to save one extra byte */
1350			if (imm32 == 0) {
1351				if (BPF_CLASS(insn->code) == BPF_JMP)
1352					EMIT1(add_2mod(0x48, dst_reg, dst_reg));
1353				else if (is_ereg(dst_reg))
1354					EMIT1(add_2mod(0x40, dst_reg, dst_reg));
1355				EMIT2(0x85, add_2reg(0xC0, dst_reg, dst_reg));
1356				goto emit_cond_jmp;
1357			}
1358
1359			/* cmp dst_reg, imm8/32 */
1360			if (BPF_CLASS(insn->code) == BPF_JMP)
1361				EMIT1(add_1mod(0x48, dst_reg));
1362			else if (is_ereg(dst_reg))
1363				EMIT1(add_1mod(0x40, dst_reg));
1364
1365			if (is_imm8(imm32))
1366				EMIT3(0x83, add_1reg(0xF8, dst_reg), imm32);
1367			else
1368				EMIT2_off32(0x81, add_1reg(0xF8, dst_reg), imm32);
1369
1370emit_cond_jmp:		/* Convert BPF opcode to x86 */
1371			switch (BPF_OP(insn->code)) {
1372			case BPF_JEQ:
1373				jmp_cond = X86_JE;
1374				break;
1375			case BPF_JSET:
1376			case BPF_JNE:
1377				jmp_cond = X86_JNE;
1378				break;
1379			case BPF_JGT:
1380				/* GT is unsigned '>', JA in x86 */
1381				jmp_cond = X86_JA;
1382				break;
1383			case BPF_JLT:
1384				/* LT is unsigned '<', JB in x86 */
1385				jmp_cond = X86_JB;
1386				break;
1387			case BPF_JGE:
1388				/* GE is unsigned '>=', JAE in x86 */
1389				jmp_cond = X86_JAE;
1390				break;
1391			case BPF_JLE:
1392				/* LE is unsigned '<=', JBE in x86 */
1393				jmp_cond = X86_JBE;
1394				break;
1395			case BPF_JSGT:
1396				/* Signed '>', GT in x86 */
1397				jmp_cond = X86_JG;
1398				break;
1399			case BPF_JSLT:
1400				/* Signed '<', LT in x86 */
1401				jmp_cond = X86_JL;
1402				break;
1403			case BPF_JSGE:
1404				/* Signed '>=', GE in x86 */
1405				jmp_cond = X86_JGE;
1406				break;
1407			case BPF_JSLE:
1408				/* Signed '<=', LE in x86 */
1409				jmp_cond = X86_JLE;
1410				break;
1411			default: /* to silence GCC warning */
1412				return -EFAULT;
1413			}
1414			jmp_offset = addrs[i + insn->off] - addrs[i];
1415			if (is_imm8(jmp_offset)) {
1416				EMIT2(jmp_cond, jmp_offset);
1417			} else if (is_simm32(jmp_offset)) {
1418				EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
1419			} else {
1420				pr_err("cond_jmp gen bug %llx\n", jmp_offset);
1421				return -EFAULT;
1422			}
1423
1424			break;
1425
1426		case BPF_JMP | BPF_JA:
1427			if (insn->off == -1)
1428				/* -1 jmp instructions will always jump
1429				 * backwards two bytes. Explicitly handling
1430				 * this case avoids wasting too many passes
1431				 * when there are long sequences of replaced
1432				 * dead code.
1433				 */
1434				jmp_offset = -2;
1435			else
1436				jmp_offset = addrs[i + insn->off] - addrs[i];
1437
1438			if (!jmp_offset)
1439				/* Optimize out nop jumps */
1440				break;
1441emit_jmp:
1442			if (is_imm8(jmp_offset)) {
1443				EMIT2(0xEB, jmp_offset);
1444			} else if (is_simm32(jmp_offset)) {
1445				EMIT1_off32(0xE9, jmp_offset);
1446			} else {
1447				pr_err("jmp gen bug %llx\n", jmp_offset);
1448				return -EFAULT;
1449			}
1450			break;
1451
1452		case BPF_JMP | BPF_EXIT:
1453			if (seen_exit) {
1454				jmp_offset = ctx->cleanup_addr - addrs[i];
1455				goto emit_jmp;
1456			}
1457			seen_exit = true;
1458			/* Update cleanup_addr */
1459			ctx->cleanup_addr = proglen;
1460			pop_callee_regs(&prog, callee_regs_used);
1461			EMIT1(0xC9);         /* leave */
1462			emit_return(&prog, image + addrs[i - 1] + (prog - temp));
1463			break;
1464
1465		default:
			/*
			 * By design the x86-64 JIT should support all BPF instructions.
			 * This error will be seen if a new instruction was added
			 * to the interpreter, but not to the JIT, or if there is
			 * junk in bpf_prog.
			 */
1472			pr_err("bpf_jit: unknown opcode %02x\n", insn->code);
1473			return -EINVAL;
1474		}
1475
1476		ilen = prog - temp;
1477		if (ilen > BPF_MAX_INSN_SIZE) {
1478			pr_err("bpf_jit: fatal insn size error\n");
1479			return -EFAULT;
1480		}
1481
1482		if (image) {
1483			/*
1484			 * When populating the image, assert that:
1485			 *
1486			 *  i) We do not write beyond the allocated space, and
1487			 * ii) addrs[i] did not change from the prior run, in order
1488			 *     to validate assumptions made for computing branch
1489			 *     displacements.
1490			 */
1491			if (unlikely(proglen + ilen > oldproglen ||
1492				     proglen + ilen != addrs[i])) {
1493				pr_err("bpf_jit: fatal error\n");
1494				return -EFAULT;
1495			}
1496			memcpy(image + proglen, temp, ilen);
1497		}
1498		proglen += ilen;
1499		addrs[i] = proglen;
1500		prog = temp;
1501	}
1502
1503	if (image && excnt != bpf_prog->aux->num_exentries) {
1504		pr_err("extable is not populated\n");
1505		return -EFAULT;
1506	}
1507	return proglen;
1508}
1509
1510static void save_regs(const struct btf_func_model *m, u8 **prog, int nr_args,
1511		      int stack_size)
1512{
1513	int i;
1514	/* Store function arguments to stack.
1515	 * For a function that accepts two pointers the sequence will be:
1516	 * mov QWORD PTR [rbp-0x10],rdi
1517	 * mov QWORD PTR [rbp-0x8],rsi
1518	 */
1519	for (i = 0; i < min(nr_args, 6); i++)
1520		emit_stx(prog, bytes_to_bpf_size(m->arg_size[i]),
1521			 BPF_REG_FP,
1522			 i == 5 ? X86_REG_R9 : BPF_REG_1 + i,
1523			 -(stack_size - i * 8));
1524}
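
/*
 * Note: argument i ends up at [rbp - stack_size + i * 8]. The first five
 * arguments come from rdi/rsi/rdx/rcx/r8 (BPF_REG_1 + i), the 6th from r9
 * (X86_REG_R9), matching the x86-64 calling convention.
 */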
1525
1526static void restore_regs(const struct btf_func_model *m, u8 **prog, int nr_args,
1527			 int stack_size)
1528{
1529	int i;
1530
1531	/* Restore function arguments from stack.
1532	 * For a function that accepts two pointers the sequence will be:
1533	 * EMIT4(0x48, 0x8B, 0x7D, 0xF0); mov rdi,QWORD PTR [rbp-0x10]
1534	 * EMIT4(0x48, 0x8B, 0x75, 0xF8); mov rsi,QWORD PTR [rbp-0x8]
1535	 */
1536	for (i = 0; i < min(nr_args, 6); i++)
1537		emit_ldx(prog, bytes_to_bpf_size(m->arg_size[i]),
1538			 i == 5 ? X86_REG_R9 : BPF_REG_1 + i,
1539			 BPF_REG_FP,
1540			 -(stack_size - i * 8));
1541}
1542
1543static int invoke_bpf_prog(const struct btf_func_model *m, u8 **pprog,
1544			   struct bpf_prog *p, int stack_size, bool save_ret)
1545{
1546	u8 *prog = *pprog;
1547	int cnt = 0;
1548
1549	if (p->aux->sleepable) {
1550		if (emit_call(&prog, __bpf_prog_enter_sleepable, prog))
1551			return -EINVAL;
1552	} else {
1553		if (emit_call(&prog, __bpf_prog_enter, prog))
1554			return -EINVAL;
1555		/* remember prog start time returned by __bpf_prog_enter */
1556		emit_mov_reg(&prog, true, BPF_REG_6, BPF_REG_0);
1557	}
1558
1559	/* arg1: lea rdi, [rbp - stack_size] */
1560	EMIT4(0x48, 0x8D, 0x7D, -stack_size);
1561	/* arg2: progs[i]->insnsi for interpreter */
1562	if (!p->jited)
1563		emit_mov_imm64(&prog, BPF_REG_2,
1564			       (long) p->insnsi >> 32,
1565			       (u32) (long) p->insnsi);
1566	/* call JITed bpf program or interpreter */
1567	if (emit_call(&prog, p->bpf_func, prog))
1568		return -EINVAL;
1569
	/*
	 * BPF_TRAMP_MODIFY_RETURN trampolines can modify the return
	 * value of the previous call, which is then passed on the stack
	 * to the next BPF program.
	 *
	 * A BPF_TRAMP_FENTRY trampoline may need to return the return
	 * value of a BPF_PROG_TYPE_STRUCT_OPS prog.
	 */
1578	if (save_ret)
1579		emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
1580
1581	if (p->aux->sleepable) {
1582		if (emit_call(&prog, __bpf_prog_exit_sleepable, prog))
1583			return -EINVAL;
1584	} else {
1585		/* arg1: mov rdi, progs[i] */
1586		emit_mov_imm64(&prog, BPF_REG_1, (long) p >> 32,
1587			       (u32) (long) p);
1588		/* arg2: mov rsi, rbx <- start time in nsec */
1589		emit_mov_reg(&prog, true, BPF_REG_2, BPF_REG_6);
1590		if (emit_call(&prog, __bpf_prog_exit, prog))
1591			return -EINVAL;
1592	}
1593
1594	*pprog = prog;
1595	return 0;
1596}
1597
1598static void emit_nops(u8 **pprog, unsigned int len)
1599{
1600	unsigned int i, noplen;
1601	u8 *prog = *pprog;
1602	int cnt = 0;
1603
1604	while (len > 0) {
1605		noplen = len;
1606
1607		if (noplen > ASM_NOP_MAX)
1608			noplen = ASM_NOP_MAX;
1609
1610		for (i = 0; i < noplen; i++)
1611			EMIT1(ideal_nops[noplen][i]);
1612		len -= noplen;
1613	}
1614
1615	*pprog = prog;
1616}
1617
1618static void emit_align(u8 **pprog, u32 align)
1619{
1620	u8 *target, *prog = *pprog;
1621
1622	target = PTR_ALIGN(prog, align);
1623	if (target != prog)
1624		emit_nops(&prog, target - prog);
1625
1626	*pprog = prog;
1627}
1628
1629static int emit_cond_near_jump(u8 **pprog, void *func, void *ip, u8 jmp_cond)
1630{
1631	u8 *prog = *pprog;
1632	int cnt = 0;
1633	s64 offset;
1634
1635	offset = func - (ip + 2 + 4);
1636	if (!is_simm32(offset)) {
1637		pr_err("Target %p is out of range\n", func);
1638		return -EINVAL;
1639	}
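	/* 2 opcode bytes (0x0F, jmp_cond + 0x10) + 4-byte rel32, hence ip + 2 + 4 */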
1640	EMIT2_off32(0x0F, jmp_cond + 0x10, offset);
1641	*pprog = prog;
1642	return 0;
1643}
1644
1645static int invoke_bpf(const struct btf_func_model *m, u8 **pprog,
1646		      struct bpf_tramp_progs *tp, int stack_size,
1647		      bool save_ret)
1648{
1649	int i;
1650	u8 *prog = *pprog;
1651
1652	for (i = 0; i < tp->nr_progs; i++) {
1653		if (invoke_bpf_prog(m, &prog, tp->progs[i], stack_size,
1654				    save_ret))
1655			return -EINVAL;
1656	}
1657	*pprog = prog;
1658	return 0;
1659}
1660
1661static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog,
1662			      struct bpf_tramp_progs *tp, int stack_size,
1663			      u8 **branches)
1664{
1665	u8 *prog = *pprog;
1666	int i, cnt = 0;
1667
1668	/* The first fmod_ret program will receive a garbage return value.
1669	 * Set this to 0 to avoid confusing the program.
1670	 */
1671	emit_mov_imm32(&prog, false, BPF_REG_0, 0);
1672	emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
1673	for (i = 0; i < tp->nr_progs; i++) {
1674		if (invoke_bpf_prog(m, &prog, tp->progs[i], stack_size, true))
1675			return -EINVAL;
1676
1677		/* mod_ret prog stored return value into [rbp - 8]. Emit:
1678		 * if (*(u64 *)(rbp - 8) !=  0)
1679		 *	goto do_fexit;
1680		 */
1681		/* cmp QWORD PTR [rbp - 0x8], 0x0 */
1682		EMIT4(0x48, 0x83, 0x7d, 0xf8); EMIT1(0x00);
1683
		/* Save the location of the branch and generate 6 nops
		 * (4 bytes for an offset and 2 bytes for the jump). These nops
		 * are replaced with a conditional jump once do_fexit (i.e. the
		 * start of the fexit invocation) is finalized.
		 */
1689		branches[i] = prog;
1690		emit_nops(&prog, 4 + 2);
1691	}
1692
1693	*pprog = prog;
1694	return 0;
1695}
1696
1697static bool is_valid_bpf_tramp_flags(unsigned int flags)
1698{
1699	if ((flags & BPF_TRAMP_F_RESTORE_REGS) &&
1700	    (flags & BPF_TRAMP_F_SKIP_FRAME))
1701		return false;
1702
1703	/*
1704	 * BPF_TRAMP_F_RET_FENTRY_RET is only used by bpf_struct_ops,
1705	 * and it must be used alone.
1706	 */
1707	if ((flags & BPF_TRAMP_F_RET_FENTRY_RET) &&
1708	    (flags & ~BPF_TRAMP_F_RET_FENTRY_RET))
1709		return false;
1710
1711	return true;
1712}
1713
1714/* Example:
1715 * __be16 eth_type_trans(struct sk_buff *skb, struct net_device *dev);
1716 * its 'struct btf_func_model' will be nr_args=2
1717 * The assembly code when eth_type_trans is executing after trampoline:
1718 *
1719 * push rbp
1720 * mov rbp, rsp
1721 * sub rsp, 16                     // space for skb and dev
1722 * push rbx                        // temp regs to pass start time
1723 * mov qword ptr [rbp - 16], rdi   // save skb pointer to stack
1724 * mov qword ptr [rbp - 8], rsi    // save dev pointer to stack
1725 * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
 * mov rbx, rax                    // remember start time if bpf stats are enabled
1727 * lea rdi, [rbp - 16]             // R1==ctx of bpf prog
1728 * call addr_of_jited_FENTRY_prog
1729 * movabsq rdi, 64bit_addr_of_struct_bpf_prog  // unused if bpf stats are off
1730 * mov rsi, rbx                    // prog start time
1731 * call __bpf_prog_exit            // rcu_read_unlock, preempt_enable and stats math
1732 * mov rdi, qword ptr [rbp - 16]   // restore skb pointer from stack
1733 * mov rsi, qword ptr [rbp - 8]    // restore dev pointer from stack
1734 * pop rbx
1735 * leave
1736 * ret
1737 *
1738 * eth_type_trans has 5 byte nop at the beginning. These 5 bytes will be
1739 * replaced with 'call generated_bpf_trampoline'. When it returns
1740 * eth_type_trans will continue executing with original skb and dev pointers.
1741 *
1742 * The assembly code when eth_type_trans is called from trampoline:
1743 *
1744 * push rbp
1745 * mov rbp, rsp
1746 * sub rsp, 24                     // space for skb, dev, return value
1747 * push rbx                        // temp regs to pass start time
1748 * mov qword ptr [rbp - 24], rdi   // save skb pointer to stack
1749 * mov qword ptr [rbp - 16], rsi   // save dev pointer to stack
1750 * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
1751 * mov rbx, rax                    // remember start time if bpf stats are enabled
1752 * lea rdi, [rbp - 24]             // R1==ctx of bpf prog
1753 * call addr_of_jited_FENTRY_prog  // bpf prog can access skb and dev
1754 * movabsq rdi, 64bit_addr_of_struct_bpf_prog  // unused if bpf stats are off
1755 * mov rsi, rbx                    // prog start time
1756 * call __bpf_prog_exit            // rcu_read_unlock, preempt_enable and stats math
1757 * mov rdi, qword ptr [rbp - 24]   // restore skb pointer from stack
1758 * mov rsi, qword ptr [rbp - 16]   // restore dev pointer from stack
1759 * call eth_type_trans+5           // execute body of eth_type_trans
1760 * mov qword ptr [rbp - 8], rax    // save return value
1761 * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
 * mov rbx, rax                    // remember start time if bpf stats are enabled
1763 * lea rdi, [rbp - 24]             // R1==ctx of bpf prog
1764 * call addr_of_jited_FEXIT_prog   // bpf prog can access skb, dev, return value
1765 * movabsq rdi, 64bit_addr_of_struct_bpf_prog  // unused if bpf stats are off
1766 * mov rsi, rbx                    // prog start time
1767 * call __bpf_prog_exit            // rcu_read_unlock, preempt_enable and stats math
1768 * mov rax, qword ptr [rbp - 8]    // restore eth_type_trans's return value
1769 * pop rbx
1770 * leave
1771 * add rsp, 8                      // skip eth_type_trans's frame
1772 * ret                             // return to its caller
1773 */
int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *image_end,
				const struct btf_func_model *m, u32 flags,
				struct bpf_tramp_progs *tprogs,
				void *orig_call)
{
	int ret, i, cnt = 0, nr_args = m->nr_args;
	int stack_size = nr_args * 8;
	struct bpf_tramp_progs *fentry = &tprogs[BPF_TRAMP_FENTRY];
	struct bpf_tramp_progs *fexit = &tprogs[BPF_TRAMP_FEXIT];
	struct bpf_tramp_progs *fmod_ret = &tprogs[BPF_TRAMP_MODIFY_RETURN];
	u8 **branches = NULL;
	u8 *prog;
	bool save_ret;

	/* x86-64 supports up to 6 arguments. 7+ can be added in the future */
	if (nr_args > 6)
		return -ENOTSUPP;

	if (!is_valid_bpf_tramp_flags(flags))
		return -EINVAL;

	/* room for return value of orig_call or fentry prog */
	save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
	if (save_ret)
		stack_size += 8;

	if (flags & BPF_TRAMP_F_SKIP_FRAME)
		/* skip patched call instruction and point orig_call to actual
		 * body of the kernel function.
		 */
		orig_call += X86_PATCH_SIZE;

	prog = image;

	EMIT1(0x55);		 /* push rbp */
	EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
	EMIT4(0x48, 0x83, 0xEC, stack_size); /* sub rsp, stack_size */
	EMIT1(0x53);		 /* push rbx */

	save_regs(m, &prog, nr_args, stack_size);

	if (flags & BPF_TRAMP_F_CALL_ORIG) {
		/* arg1: mov rdi, im */
		emit_mov_imm64(&prog, BPF_REG_1, (long) im >> 32, (u32) (long) im);
		if (emit_call(&prog, __bpf_tramp_enter, prog)) {
			ret = -EINVAL;
			goto cleanup;
		}
	}

	if (fentry->nr_progs)
		if (invoke_bpf(m, &prog, fentry, stack_size,
			       flags & BPF_TRAMP_F_RET_FENTRY_RET))
			return -EINVAL;

	if (fmod_ret->nr_progs) {
		branches = kcalloc(fmod_ret->nr_progs, sizeof(u8 *),
				   GFP_KERNEL);
		if (!branches)
			return -ENOMEM;

		if (invoke_bpf_mod_ret(m, &prog, fmod_ret, stack_size,
				       branches)) {
			ret = -EINVAL;
			goto cleanup;
		}
	}

	if (flags & BPF_TRAMP_F_CALL_ORIG) {
		restore_regs(m, &prog, nr_args, stack_size);

		/* call original function */
		if (emit_call(&prog, orig_call, prog)) {
			ret = -EINVAL;
			goto cleanup;
		}
		/* remember return value on the stack for bpf prog to access */
		emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
		im->ip_after_call = prog;
		memcpy(prog, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE);
		prog += X86_PATCH_SIZE;
	}

	if (fmod_ret->nr_progs) {
		/* From Intel 64 and IA-32 Architectures Optimization
		 * Reference Manual, 3.4.1.4 Code Alignment, Assembly/Compiler
		 * Coding Rule 11: All branch targets should be 16-byte
		 * aligned.
		 */
		emit_align(&prog, 16);
		/* Update the branches saved in invoke_bpf_mod_ret with the
		 * aligned address of do_fexit.
		 */
		for (i = 0; i < fmod_ret->nr_progs; i++)
			emit_cond_near_jump(&branches[i], prog, branches[i],
					    X86_JNE);
	}

	if (fexit->nr_progs)
		if (invoke_bpf(m, &prog, fexit, stack_size, false)) {
			ret = -EINVAL;
			goto cleanup;
		}

	if (flags & BPF_TRAMP_F_RESTORE_REGS)
		restore_regs(m, &prog, nr_args, stack_size);

	/* This needs to be done regardless. If there were fmod_ret programs,
	 * the return value is only updated on the stack and still needs to be
	 * restored to R0.
	 */
	if (flags & BPF_TRAMP_F_CALL_ORIG) {
		im->ip_epilogue = prog;
		/* arg1: mov rdi, im */
		emit_mov_imm64(&prog, BPF_REG_1, (long) im >> 32, (u32) (long) im);
		if (emit_call(&prog, __bpf_tramp_exit, prog)) {
			ret = -EINVAL;
			goto cleanup;
		}
	}
	/* restore return value of orig_call or fentry prog back into RAX */
	if (save_ret)
		emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, -8);

	EMIT1(0x5B); /* pop rbx */
	EMIT1(0xC9); /* leave */
	if (flags & BPF_TRAMP_F_SKIP_FRAME)
		/* skip our return address and return to parent */
		EMIT4(0x48, 0x83, 0xC4, 8); /* add rsp, 8 */
	emit_return(&prog, prog);
	/* Make sure the trampoline generation logic doesn't overflow */
	if (WARN_ON_ONCE(prog > (u8 *)image_end - BPF_INSN_SAFETY)) {
		ret = -EFAULT;
		goto cleanup;
	}
	ret = prog - (u8 *)image;

cleanup:
	kfree(branches);
	return ret;
}

static int emit_bpf_dispatcher(u8 **pprog, int a, int b, s64 *progs)
{
	u8 *jg_reloc, *prog = *pprog;
	int pivot, err, jg_bytes = 1, cnt = 0;
	s64 jg_offset;

	if (a == b) {
		/* Leaf node of recursion, i.e. not a range of indices
		 * anymore.
		 */
		EMIT1(add_1mod(0x48, BPF_REG_3));	/* cmp rdx,func */
		if (!is_simm32(progs[a]))
			return -1;
		EMIT2_off32(0x81, add_1reg(0xF8, BPF_REG_3),
			    progs[a]);
		err = emit_cond_near_jump(&prog,	/* je func */
					  (void *)progs[a], prog,
					  X86_JE);
		if (err)
			return err;

		emit_indirect_jump(&prog, 2 /* rdx */, prog);

		*pprog = prog;
		return 0;
	}

	/* Not a leaf node, so we pivot, and recursively descend into
	 * the lower and upper ranges.
	 */
	pivot = (b - a) / 2;
	EMIT1(add_1mod(0x48, BPF_REG_3));		/* cmp rdx,func */
	if (!is_simm32(progs[a + pivot]))
		return -1;
	EMIT2_off32(0x81, add_1reg(0xF8, BPF_REG_3), progs[a + pivot]);

	if (pivot > 2) {				/* jg upper_part */
		/* Require near jump. */
		jg_bytes = 4;
		EMIT2_off32(0x0F, X86_JG + 0x10, 0);
	} else {
		EMIT2(X86_JG, 0);
	}
	jg_reloc = prog;

	err = emit_bpf_dispatcher(&prog, a, a + pivot,	/* emit lower_part */
				  progs);
	if (err)
		return err;

	/* From Intel 64 and IA-32 Architectures Optimization
	 * Reference Manual, 3.4.1.4 Code Alignment, Assembly/Compiler
	 * Coding Rule 11: All branch targets should be 16-byte
	 * aligned.
	 */
	emit_align(&prog, 16);
	jg_offset = prog - jg_reloc;
	emit_code(jg_reloc - jg_bytes, jg_offset, jg_bytes);

	err = emit_bpf_dispatcher(&prog, a + pivot + 1,	/* emit upper_part */
				  b, progs);
	if (err)
		return err;

	*pprog = prog;
	return 0;
}
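
/*
 * Sketch of the code emit_bpf_dispatcher() generates for four sorted
 * target addresses f0 < f1 < f2 < f3 (illustrative only; the actual
 * immediates, jump widths and alignment padding depend on the
 * addresses). At run time rdx carries the destination address:
 *
 *	cmp rdx, f1
 *	jg  2f			// rdx > f1: search upper half {f2, f3}
 *	cmp rdx, f0
 *	jg  1f
 *	cmp rdx, f0
 *	je  f0
 *	jmp rdx			// miss: indirect jump via emit_indirect_jump()
 * 1:	cmp rdx, f1		// 16-byte aligned branch target
 *	je  f1
 *	jmp rdx
 * 2:	cmp rdx, f2		// same pattern for the upper half
 *	...
 */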

static int cmp_ips(const void *a, const void *b)
{
	const s64 *ipa = a;
	const s64 *ipb = b;

	if (*ipa > *ipb)
		return 1;
	if (*ipa < *ipb)
		return -1;
	return 0;
}

int arch_prepare_bpf_dispatcher(void *image, s64 *funcs, int num_funcs)
{
	u8 *prog = image;

	sort(funcs, num_funcs, sizeof(funcs[0]), cmp_ips, NULL);
	return emit_bpf_dispatcher(&prog, 0, num_funcs - 1, funcs);
}

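/*
 * JIT state stashed in prog->aux->jit_data so that the extra pass for
 * programs with subprograms (see the !prog->is_func || extra_pass
 * handling in bpf_int_jit_compile() below) can resume with the image,
 * addrs[] and context produced by the earlier passes.
 */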
struct x64_jit_data {
	struct bpf_binary_header *header;
	int *addrs;
	u8 *image;
	int proglen;
	struct jit_context ctx;
};

struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
{
	struct bpf_binary_header *header = NULL;
	struct bpf_prog *tmp, *orig_prog = prog;
	struct x64_jit_data *jit_data;
	int proglen, oldproglen = 0;
	struct jit_context ctx = {};
	bool tmp_blinded = false;
	bool extra_pass = false;
	u8 *image = NULL;
	int *addrs;
	int pass;
	int i;

	if (!prog->jit_requested)
		return orig_prog;

	tmp = bpf_jit_blind_constants(prog);
	/*
	 * If blinding was requested and we failed during blinding,
	 * we must fall back to the interpreter.
	 */
	if (IS_ERR(tmp))
		return orig_prog;
	if (tmp != prog) {
		tmp_blinded = true;
		prog = tmp;
	}

	jit_data = prog->aux->jit_data;
	if (!jit_data) {
		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
		if (!jit_data) {
			prog = orig_prog;
			goto out;
		}
		prog->aux->jit_data = jit_data;
	}
	addrs = jit_data->addrs;
	if (addrs) {
		ctx = jit_data->ctx;
		oldproglen = jit_data->proglen;
		image = jit_data->image;
		header = jit_data->header;
		extra_pass = true;
		goto skip_init_addrs;
	}
	addrs = kvmalloc_array(prog->len + 1, sizeof(*addrs), GFP_KERNEL);
	if (!addrs) {
		prog = orig_prog;
		goto out_addrs;
	}

	/*
	 * Before the first pass, make a rough estimate for addrs[]:
	 * each BPF instruction is translated to less than 64 bytes.
	 */
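	/*
	 * For example, a two-insn program starts out with addrs[0] = 64,
	 * addrs[1] = 128, addrs[2] = 192 and ctx.cleanup_addr = 192; later
	 * passes replace these over-estimates with the real offsets.
	 */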
	for (proglen = 0, i = 0; i <= prog->len; i++) {
		proglen += 64;
		addrs[i] = proglen;
	}
	ctx.cleanup_addr = proglen;
skip_init_addrs:

	/*
	 * The JITed image shrinks with every pass and the loop iterates
	 * until the image stops shrinking. Very large BPF programs may
	 * converge only on the last pass. In such a case, do one more
	 * pass to emit the final image.
	 */
	for (pass = 0; pass < 20 || image; pass++) {
		proglen = do_jit(prog, addrs, image, oldproglen, &ctx);
		if (proglen <= 0) {
out_image:
			image = NULL;
			if (header)
				bpf_jit_binary_free(header);
			prog = orig_prog;
			goto out_addrs;
		}
		if (image) {
			if (proglen != oldproglen) {
				pr_err("bpf_jit: proglen=%d != oldproglen=%d\n",
				       proglen, oldproglen);
				goto out_image;
			}
			break;
		}
		if (proglen == oldproglen) {
			/*
			 * The number of entries in extable is the number of BPF_LDX
			 * insns that access kernel memory via "pointer to BTF type".
			 * The verifier changed their opcode from LDX|MEM|size
			 * to LDX|PROBE_MEM|size to make JITing easier.
			 */
			u32 align = __alignof__(struct exception_table_entry);
			u32 extable_size = prog->aux->num_exentries *
				sizeof(struct exception_table_entry);

			/* allocate module memory for x86 insns and extable */
			header = bpf_jit_binary_alloc(roundup(proglen, align) + extable_size,
						      &image, align, jit_fill_hole);
			if (!header) {
				prog = orig_prog;
				goto out_addrs;
			}
			prog->aux->extable = (void *) image + roundup(proglen, align);
		}
		oldproglen = proglen;
		cond_resched();
	}

	if (bpf_jit_enable > 1)
		bpf_jit_dump(prog->len, proglen, pass + 1, image);

	if (image) {
		if (!prog->is_func || extra_pass) {
			bpf_tail_call_direct_fixup(prog);
			bpf_jit_binary_lock_ro(header);
		} else {
			jit_data->addrs = addrs;
			jit_data->ctx = ctx;
			jit_data->proglen = proglen;
			jit_data->image = image;
			jit_data->header = header;
		}
		prog->bpf_func = (void *)image;
		prog->jited = 1;
		prog->jited_len = proglen;
	} else {
		prog = orig_prog;
	}

	if (!image || !prog->is_func || extra_pass) {
		if (image)
			bpf_prog_fill_jited_linfo(prog, addrs + 1);
out_addrs:
		kvfree(addrs);
		kfree(jit_data);
		prog->aux->jit_data = NULL;
	}
out:
	if (tmp_blinded)
		bpf_jit_prog_release_other(prog, prog == orig_prog ?
					   tmp : orig_prog);
	return prog;
}
