OSDN Git Service

perf/x86/uncore: Correct the number of CHAs on EMR
[tomoyo/tomoyo-test1.git] / arch / riscv / net / bpf_jit_comp64.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* BPF JIT compiler for RV64G
3  *
4  * Copyright(c) 2019 Björn Töpel <bjorn.topel@gmail.com>
5  *
6  */
7
8 #include <linux/bitfield.h>
9 #include <linux/bpf.h>
10 #include <linux/filter.h>
11 #include <linux/memory.h>
12 #include <linux/stop_machine.h>
13 #include <asm/patch.h>
14 #include "bpf_jit.h"
15
16 #define RV_FENTRY_NINSNS 2
17
18 #define RV_REG_TCC RV_REG_A6
19 #define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */
20
21 static const int regmap[] = {
22         [BPF_REG_0] =   RV_REG_A5,
23         [BPF_REG_1] =   RV_REG_A0,
24         [BPF_REG_2] =   RV_REG_A1,
25         [BPF_REG_3] =   RV_REG_A2,
26         [BPF_REG_4] =   RV_REG_A3,
27         [BPF_REG_5] =   RV_REG_A4,
28         [BPF_REG_6] =   RV_REG_S1,
29         [BPF_REG_7] =   RV_REG_S2,
30         [BPF_REG_8] =   RV_REG_S3,
31         [BPF_REG_9] =   RV_REG_S4,
32         [BPF_REG_FP] =  RV_REG_S5,
33         [BPF_REG_AX] =  RV_REG_T0,
34 };
35
36 static const int pt_regmap[] = {
37         [RV_REG_A0] = offsetof(struct pt_regs, a0),
38         [RV_REG_A1] = offsetof(struct pt_regs, a1),
39         [RV_REG_A2] = offsetof(struct pt_regs, a2),
40         [RV_REG_A3] = offsetof(struct pt_regs, a3),
41         [RV_REG_A4] = offsetof(struct pt_regs, a4),
42         [RV_REG_A5] = offsetof(struct pt_regs, a5),
43         [RV_REG_S1] = offsetof(struct pt_regs, s1),
44         [RV_REG_S2] = offsetof(struct pt_regs, s2),
45         [RV_REG_S3] = offsetof(struct pt_regs, s3),
46         [RV_REG_S4] = offsetof(struct pt_regs, s4),
47         [RV_REG_S5] = offsetof(struct pt_regs, s5),
48         [RV_REG_T0] = offsetof(struct pt_regs, t0),
49 };
50
51 enum {
52         RV_CTX_F_SEEN_TAIL_CALL =       0,
53         RV_CTX_F_SEEN_CALL =            RV_REG_RA,
54         RV_CTX_F_SEEN_S1 =              RV_REG_S1,
55         RV_CTX_F_SEEN_S2 =              RV_REG_S2,
56         RV_CTX_F_SEEN_S3 =              RV_REG_S3,
57         RV_CTX_F_SEEN_S4 =              RV_REG_S4,
58         RV_CTX_F_SEEN_S5 =              RV_REG_S5,
59         RV_CTX_F_SEEN_S6 =              RV_REG_S6,
60 };
61
62 static u8 bpf_to_rv_reg(int bpf_reg, struct rv_jit_context *ctx)
63 {
64         u8 reg = regmap[bpf_reg];
65
66         switch (reg) {
67         case RV_CTX_F_SEEN_S1:
68         case RV_CTX_F_SEEN_S2:
69         case RV_CTX_F_SEEN_S3:
70         case RV_CTX_F_SEEN_S4:
71         case RV_CTX_F_SEEN_S5:
72         case RV_CTX_F_SEEN_S6:
73                 __set_bit(reg, &ctx->flags);
74         }
75         return reg;
76 };
77
78 static bool seen_reg(int reg, struct rv_jit_context *ctx)
79 {
80         switch (reg) {
81         case RV_CTX_F_SEEN_CALL:
82         case RV_CTX_F_SEEN_S1:
83         case RV_CTX_F_SEEN_S2:
84         case RV_CTX_F_SEEN_S3:
85         case RV_CTX_F_SEEN_S4:
86         case RV_CTX_F_SEEN_S5:
87         case RV_CTX_F_SEEN_S6:
88                 return test_bit(reg, &ctx->flags);
89         }
90         return false;
91 }
92
93 static void mark_fp(struct rv_jit_context *ctx)
94 {
95         __set_bit(RV_CTX_F_SEEN_S5, &ctx->flags);
96 }
97
98 static void mark_call(struct rv_jit_context *ctx)
99 {
100         __set_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
101 }
102
103 static bool seen_call(struct rv_jit_context *ctx)
104 {
105         return test_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
106 }
107
108 static void mark_tail_call(struct rv_jit_context *ctx)
109 {
110         __set_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
111 }
112
113 static bool seen_tail_call(struct rv_jit_context *ctx)
114 {
115         return test_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
116 }
117
118 static u8 rv_tail_call_reg(struct rv_jit_context *ctx)
119 {
120         mark_tail_call(ctx);
121
122         if (seen_call(ctx)) {
123                 __set_bit(RV_CTX_F_SEEN_S6, &ctx->flags);
124                 return RV_REG_S6;
125         }
126         return RV_REG_A6;
127 }
128
129 static bool is_32b_int(s64 val)
130 {
131         return -(1L << 31) <= val && val < (1L << 31);
132 }
133
134 static bool in_auipc_jalr_range(s64 val)
135 {
136         /*
137          * auipc+jalr can reach any signed PC-relative offset in the range
138          * [-2^31 - 2^11, 2^31 - 2^11).
139          */
140         return (-(1L << 31) - (1L << 11)) <= val &&
141                 val < ((1L << 31) - (1L << 11));
142 }
143
144 /* Emit fixed-length instructions for address */
145 static int emit_addr(u8 rd, u64 addr, bool extra_pass, struct rv_jit_context *ctx)
146 {
147         u64 ip = (u64)(ctx->insns + ctx->ninsns);
148         s64 off = addr - ip;
149         s64 upper = (off + (1 << 11)) >> 12;
150         s64 lower = off & 0xfff;
151
152         if (extra_pass && !in_auipc_jalr_range(off)) {
153                 pr_err("bpf-jit: target offset 0x%llx is out of range\n", off);
154                 return -ERANGE;
155         }
156
157         emit(rv_auipc(rd, upper), ctx);
158         emit(rv_addi(rd, rd, lower), ctx);
159         return 0;
160 }
161
162 /* Emit variable-length instructions for 32-bit and 64-bit imm */
163 static void emit_imm(u8 rd, s64 val, struct rv_jit_context *ctx)
164 {
165         /* Note that the immediate from the add is sign-extended,
166          * which means that we need to compensate this by adding 2^12,
167          * when the 12th bit is set. A simpler way of doing this, and
168          * getting rid of the check, is to just add 2**11 before the
169          * shift. The "Loading a 32-Bit constant" example from the
170          * "Computer Organization and Design, RISC-V edition" book by
171          * Patterson/Hennessy highlights this fact.
172          *
173          * This also means that we need to process LSB to MSB.
174          */
175         s64 upper = (val + (1 << 11)) >> 12;
176         /* Sign-extend lower 12 bits to 64 bits since immediates for li, addiw,
177          * and addi are signed and RVC checks will perform signed comparisons.
178          */
179         s64 lower = ((val & 0xfff) << 52) >> 52;
180         int shift;
181
182         if (is_32b_int(val)) {
183                 if (upper)
184                         emit_lui(rd, upper, ctx);
185
186                 if (!upper) {
187                         emit_li(rd, lower, ctx);
188                         return;
189                 }
190
191                 emit_addiw(rd, rd, lower, ctx);
192                 return;
193         }
194
195         shift = __ffs(upper);
196         upper >>= shift;
197         shift += 12;
198
199         emit_imm(rd, upper, ctx);
200
201         emit_slli(rd, rd, shift, ctx);
202         if (lower)
203                 emit_addi(rd, rd, lower, ctx);
204 }
205
206 static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
207 {
208         int stack_adjust = ctx->stack_size, store_offset = stack_adjust - 8;
209
210         if (seen_reg(RV_REG_RA, ctx)) {
211                 emit_ld(RV_REG_RA, store_offset, RV_REG_SP, ctx);
212                 store_offset -= 8;
213         }
214         emit_ld(RV_REG_FP, store_offset, RV_REG_SP, ctx);
215         store_offset -= 8;
216         if (seen_reg(RV_REG_S1, ctx)) {
217                 emit_ld(RV_REG_S1, store_offset, RV_REG_SP, ctx);
218                 store_offset -= 8;
219         }
220         if (seen_reg(RV_REG_S2, ctx)) {
221                 emit_ld(RV_REG_S2, store_offset, RV_REG_SP, ctx);
222                 store_offset -= 8;
223         }
224         if (seen_reg(RV_REG_S3, ctx)) {
225                 emit_ld(RV_REG_S3, store_offset, RV_REG_SP, ctx);
226                 store_offset -= 8;
227         }
228         if (seen_reg(RV_REG_S4, ctx)) {
229                 emit_ld(RV_REG_S4, store_offset, RV_REG_SP, ctx);
230                 store_offset -= 8;
231         }
232         if (seen_reg(RV_REG_S5, ctx)) {
233                 emit_ld(RV_REG_S5, store_offset, RV_REG_SP, ctx);
234                 store_offset -= 8;
235         }
236         if (seen_reg(RV_REG_S6, ctx)) {
237                 emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
238                 store_offset -= 8;
239         }
240
241         emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
242         /* Set return value. */
243         if (!is_tail_call)
244                 emit_mv(RV_REG_A0, RV_REG_A5, ctx);
245         emit_jalr(RV_REG_ZERO, is_tail_call ? RV_REG_T3 : RV_REG_RA,
246                   is_tail_call ? (RV_FENTRY_NINSNS + 1) * 4 : 0, /* skip reserved nops and TCC init */
247                   ctx);
248 }
249
250 static void emit_bcc(u8 cond, u8 rd, u8 rs, int rvoff,
251                      struct rv_jit_context *ctx)
252 {
253         switch (cond) {
254         case BPF_JEQ:
255                 emit(rv_beq(rd, rs, rvoff >> 1), ctx);
256                 return;
257         case BPF_JGT:
258                 emit(rv_bltu(rs, rd, rvoff >> 1), ctx);
259                 return;
260         case BPF_JLT:
261                 emit(rv_bltu(rd, rs, rvoff >> 1), ctx);
262                 return;
263         case BPF_JGE:
264                 emit(rv_bgeu(rd, rs, rvoff >> 1), ctx);
265                 return;
266         case BPF_JLE:
267                 emit(rv_bgeu(rs, rd, rvoff >> 1), ctx);
268                 return;
269         case BPF_JNE:
270                 emit(rv_bne(rd, rs, rvoff >> 1), ctx);
271                 return;
272         case BPF_JSGT:
273                 emit(rv_blt(rs, rd, rvoff >> 1), ctx);
274                 return;
275         case BPF_JSLT:
276                 emit(rv_blt(rd, rs, rvoff >> 1), ctx);
277                 return;
278         case BPF_JSGE:
279                 emit(rv_bge(rd, rs, rvoff >> 1), ctx);
280                 return;
281         case BPF_JSLE:
282                 emit(rv_bge(rs, rd, rvoff >> 1), ctx);
283         }
284 }
285
286 static void emit_branch(u8 cond, u8 rd, u8 rs, int rvoff,
287                         struct rv_jit_context *ctx)
288 {
289         s64 upper, lower;
290
291         if (is_13b_int(rvoff)) {
292                 emit_bcc(cond, rd, rs, rvoff, ctx);
293                 return;
294         }
295
296         /* Adjust for jal */
297         rvoff -= 4;
298
299         /* Transform, e.g.:
300          *   bne rd,rs,foo
301          * to
302          *   beq rd,rs,<.L1>
303          *   (auipc foo)
304          *   jal(r) foo
305          * .L1
306          */
307         cond = invert_bpf_cond(cond);
308         if (is_21b_int(rvoff)) {
309                 emit_bcc(cond, rd, rs, 8, ctx);
310                 emit(rv_jal(RV_REG_ZERO, rvoff >> 1), ctx);
311                 return;
312         }
313
314         /* 32b No need for an additional rvoff adjustment, since we
315          * get that from the auipc at PC', where PC = PC' + 4.
316          */
317         upper = (rvoff + (1 << 11)) >> 12;
318         lower = rvoff & 0xfff;
319
320         emit_bcc(cond, rd, rs, 12, ctx);
321         emit(rv_auipc(RV_REG_T1, upper), ctx);
322         emit(rv_jalr(RV_REG_ZERO, RV_REG_T1, lower), ctx);
323 }
324
325 static void emit_zext_32(u8 reg, struct rv_jit_context *ctx)
326 {
327         emit_slli(reg, reg, 32, ctx);
328         emit_srli(reg, reg, 32, ctx);
329 }
330
331 static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
332 {
333         int tc_ninsn, off, start_insn = ctx->ninsns;
334         u8 tcc = rv_tail_call_reg(ctx);
335
336         /* a0: &ctx
337          * a1: &array
338          * a2: index
339          *
340          * if (index >= array->map.max_entries)
341          *      goto out;
342          */
343         tc_ninsn = insn ? ctx->offset[insn] - ctx->offset[insn - 1] :
344                    ctx->offset[0];
345         emit_zext_32(RV_REG_A2, ctx);
346
347         off = offsetof(struct bpf_array, map.max_entries);
348         if (is_12b_check(off, insn))
349                 return -1;
350         emit(rv_lwu(RV_REG_T1, off, RV_REG_A1), ctx);
351         off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
352         emit_branch(BPF_JGE, RV_REG_A2, RV_REG_T1, off, ctx);
353
354         /* if (--TCC < 0)
355          *     goto out;
356          */
357         emit_addi(RV_REG_TCC, tcc, -1, ctx);
358         off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
359         emit_branch(BPF_JSLT, RV_REG_TCC, RV_REG_ZERO, off, ctx);
360
361         /* prog = array->ptrs[index];
362          * if (!prog)
363          *     goto out;
364          */
365         emit_slli(RV_REG_T2, RV_REG_A2, 3, ctx);
366         emit_add(RV_REG_T2, RV_REG_T2, RV_REG_A1, ctx);
367         off = offsetof(struct bpf_array, ptrs);
368         if (is_12b_check(off, insn))
369                 return -1;
370         emit_ld(RV_REG_T2, off, RV_REG_T2, ctx);
371         off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
372         emit_branch(BPF_JEQ, RV_REG_T2, RV_REG_ZERO, off, ctx);
373
374         /* goto *(prog->bpf_func + 4); */
375         off = offsetof(struct bpf_prog, bpf_func);
376         if (is_12b_check(off, insn))
377                 return -1;
378         emit_ld(RV_REG_T3, off, RV_REG_T2, ctx);
379         __build_epilogue(true, ctx);
380         return 0;
381 }
382
383 static void init_regs(u8 *rd, u8 *rs, const struct bpf_insn *insn,
384                       struct rv_jit_context *ctx)
385 {
386         u8 code = insn->code;
387
388         switch (code) {
389         case BPF_JMP | BPF_JA:
390         case BPF_JMP | BPF_CALL:
391         case BPF_JMP | BPF_EXIT:
392         case BPF_JMP | BPF_TAIL_CALL:
393                 break;
394         default:
395                 *rd = bpf_to_rv_reg(insn->dst_reg, ctx);
396         }
397
398         if (code & (BPF_ALU | BPF_X) || code & (BPF_ALU64 | BPF_X) ||
399             code & (BPF_JMP | BPF_X) || code & (BPF_JMP32 | BPF_X) ||
400             code & BPF_LDX || code & BPF_STX)
401                 *rs = bpf_to_rv_reg(insn->src_reg, ctx);
402 }
403
404 static void emit_zext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
405 {
406         emit_mv(RV_REG_T2, *rd, ctx);
407         emit_zext_32(RV_REG_T2, ctx);
408         emit_mv(RV_REG_T1, *rs, ctx);
409         emit_zext_32(RV_REG_T1, ctx);
410         *rd = RV_REG_T2;
411         *rs = RV_REG_T1;
412 }
413
414 static void emit_sext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
415 {
416         emit_addiw(RV_REG_T2, *rd, 0, ctx);
417         emit_addiw(RV_REG_T1, *rs, 0, ctx);
418         *rd = RV_REG_T2;
419         *rs = RV_REG_T1;
420 }
421
422 static void emit_zext_32_rd_t1(u8 *rd, struct rv_jit_context *ctx)
423 {
424         emit_mv(RV_REG_T2, *rd, ctx);
425         emit_zext_32(RV_REG_T2, ctx);
426         emit_zext_32(RV_REG_T1, ctx);
427         *rd = RV_REG_T2;
428 }
429
430 static void emit_sext_32_rd(u8 *rd, struct rv_jit_context *ctx)
431 {
432         emit_addiw(RV_REG_T2, *rd, 0, ctx);
433         *rd = RV_REG_T2;
434 }
435
436 static int emit_jump_and_link(u8 rd, s64 rvoff, bool fixed_addr,
437                               struct rv_jit_context *ctx)
438 {
439         s64 upper, lower;
440
441         if (rvoff && fixed_addr && is_21b_int(rvoff)) {
442                 emit(rv_jal(rd, rvoff >> 1), ctx);
443                 return 0;
444         } else if (in_auipc_jalr_range(rvoff)) {
445                 upper = (rvoff + (1 << 11)) >> 12;
446                 lower = rvoff & 0xfff;
447                 emit(rv_auipc(RV_REG_T1, upper), ctx);
448                 emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
449                 return 0;
450         }
451
452         pr_err("bpf-jit: target offset 0x%llx is out of range\n", rvoff);
453         return -ERANGE;
454 }
455
456 static bool is_signed_bpf_cond(u8 cond)
457 {
458         return cond == BPF_JSGT || cond == BPF_JSLT ||
459                 cond == BPF_JSGE || cond == BPF_JSLE;
460 }
461
462 static int emit_call(u64 addr, bool fixed_addr, struct rv_jit_context *ctx)
463 {
464         s64 off = 0;
465         u64 ip;
466
467         if (addr && ctx->insns) {
468                 ip = (u64)(long)(ctx->insns + ctx->ninsns);
469                 off = addr - ip;
470         }
471
472         return emit_jump_and_link(RV_REG_RA, off, fixed_addr, ctx);
473 }
474
475 static void emit_atomic(u8 rd, u8 rs, s16 off, s32 imm, bool is64,
476                         struct rv_jit_context *ctx)
477 {
478         u8 r0;
479         int jmp_offset;
480
481         if (off) {
482                 if (is_12b_int(off)) {
483                         emit_addi(RV_REG_T1, rd, off, ctx);
484                 } else {
485                         emit_imm(RV_REG_T1, off, ctx);
486                         emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
487                 }
488                 rd = RV_REG_T1;
489         }
490
491         switch (imm) {
492         /* lock *(u32/u64 *)(dst_reg + off16) <op>= src_reg */
493         case BPF_ADD:
494                 emit(is64 ? rv_amoadd_d(RV_REG_ZERO, rs, rd, 0, 0) :
495                      rv_amoadd_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
496                 break;
497         case BPF_AND:
498                 emit(is64 ? rv_amoand_d(RV_REG_ZERO, rs, rd, 0, 0) :
499                      rv_amoand_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
500                 break;
501         case BPF_OR:
502                 emit(is64 ? rv_amoor_d(RV_REG_ZERO, rs, rd, 0, 0) :
503                      rv_amoor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
504                 break;
505         case BPF_XOR:
506                 emit(is64 ? rv_amoxor_d(RV_REG_ZERO, rs, rd, 0, 0) :
507                      rv_amoxor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
508                 break;
509         /* src_reg = atomic_fetch_<op>(dst_reg + off16, src_reg) */
510         case BPF_ADD | BPF_FETCH:
511                 emit(is64 ? rv_amoadd_d(rs, rs, rd, 0, 0) :
512                      rv_amoadd_w(rs, rs, rd, 0, 0), ctx);
513                 if (!is64)
514                         emit_zext_32(rs, ctx);
515                 break;
516         case BPF_AND | BPF_FETCH:
517                 emit(is64 ? rv_amoand_d(rs, rs, rd, 0, 0) :
518                      rv_amoand_w(rs, rs, rd, 0, 0), ctx);
519                 if (!is64)
520                         emit_zext_32(rs, ctx);
521                 break;
522         case BPF_OR | BPF_FETCH:
523                 emit(is64 ? rv_amoor_d(rs, rs, rd, 0, 0) :
524                      rv_amoor_w(rs, rs, rd, 0, 0), ctx);
525                 if (!is64)
526                         emit_zext_32(rs, ctx);
527                 break;
528         case BPF_XOR | BPF_FETCH:
529                 emit(is64 ? rv_amoxor_d(rs, rs, rd, 0, 0) :
530                      rv_amoxor_w(rs, rs, rd, 0, 0), ctx);
531                 if (!is64)
532                         emit_zext_32(rs, ctx);
533                 break;
534         /* src_reg = atomic_xchg(dst_reg + off16, src_reg); */
535         case BPF_XCHG:
536                 emit(is64 ? rv_amoswap_d(rs, rs, rd, 0, 0) :
537                      rv_amoswap_w(rs, rs, rd, 0, 0), ctx);
538                 if (!is64)
539                         emit_zext_32(rs, ctx);
540                 break;
541         /* r0 = atomic_cmpxchg(dst_reg + off16, r0, src_reg); */
542         case BPF_CMPXCHG:
543                 r0 = bpf_to_rv_reg(BPF_REG_0, ctx);
544                 emit(is64 ? rv_addi(RV_REG_T2, r0, 0) :
545                      rv_addiw(RV_REG_T2, r0, 0), ctx);
546                 emit(is64 ? rv_lr_d(r0, 0, rd, 0, 0) :
547                      rv_lr_w(r0, 0, rd, 0, 0), ctx);
548                 jmp_offset = ninsns_rvoff(8);
549                 emit(rv_bne(RV_REG_T2, r0, jmp_offset >> 1), ctx);
550                 emit(is64 ? rv_sc_d(RV_REG_T3, rs, rd, 0, 0) :
551                      rv_sc_w(RV_REG_T3, rs, rd, 0, 0), ctx);
552                 jmp_offset = ninsns_rvoff(-6);
553                 emit(rv_bne(RV_REG_T3, 0, jmp_offset >> 1), ctx);
554                 emit(rv_fence(0x3, 0x3), ctx);
555                 break;
556         }
557 }
558
559 #define BPF_FIXUP_OFFSET_MASK   GENMASK(26, 0)
560 #define BPF_FIXUP_REG_MASK      GENMASK(31, 27)
561
562 bool ex_handler_bpf(const struct exception_table_entry *ex,
563                     struct pt_regs *regs)
564 {
565         off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
566         int regs_offset = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
567
568         *(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0;
569         regs->epc = (unsigned long)&ex->fixup - offset;
570
571         return true;
572 }
573
574 /* For accesses to BTF pointers, add an entry to the exception table */
575 static int add_exception_handler(const struct bpf_insn *insn,
576                                  struct rv_jit_context *ctx,
577                                  int dst_reg, int insn_len)
578 {
579         struct exception_table_entry *ex;
580         unsigned long pc;
581         off_t offset;
582
583         if (!ctx->insns || !ctx->prog->aux->extable ||
584             (BPF_MODE(insn->code) != BPF_PROBE_MEM && BPF_MODE(insn->code) != BPF_PROBE_MEMSX))
585                 return 0;
586
587         if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries))
588                 return -EINVAL;
589
590         if (WARN_ON_ONCE(insn_len > ctx->ninsns))
591                 return -EINVAL;
592
593         if (WARN_ON_ONCE(!rvc_enabled() && insn_len == 1))
594                 return -EINVAL;
595
596         ex = &ctx->prog->aux->extable[ctx->nexentries];
597         pc = (unsigned long)&ctx->insns[ctx->ninsns - insn_len];
598
599         offset = pc - (long)&ex->insn;
600         if (WARN_ON_ONCE(offset >= 0 || offset < INT_MIN))
601                 return -ERANGE;
602         ex->insn = offset;
603
604         /*
605          * Since the extable follows the program, the fixup offset is always
606          * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
607          * to keep things simple, and put the destination register in the upper
608          * bits. We don't need to worry about buildtime or runtime sort
609          * modifying the upper bits because the table is already sorted, and
610          * isn't part of the main exception table.
611          */
612         offset = (long)&ex->fixup - (pc + insn_len * sizeof(u16));
613         if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, offset))
614                 return -ERANGE;
615
616         ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, offset) |
617                 FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
618         ex->type = EX_TYPE_BPF;
619
620         ctx->nexentries++;
621         return 0;
622 }
623
624 static int gen_jump_or_nops(void *target, void *ip, u32 *insns, bool is_call)
625 {
626         s64 rvoff;
627         struct rv_jit_context ctx;
628
629         ctx.ninsns = 0;
630         ctx.insns = (u16 *)insns;
631
632         if (!target) {
633                 emit(rv_nop(), &ctx);
634                 emit(rv_nop(), &ctx);
635                 return 0;
636         }
637
638         rvoff = (s64)(target - ip);
639         return emit_jump_and_link(is_call ? RV_REG_T0 : RV_REG_ZERO, rvoff, false, &ctx);
640 }
641
642 int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type,
643                        void *old_addr, void *new_addr)
644 {
645         u32 old_insns[RV_FENTRY_NINSNS], new_insns[RV_FENTRY_NINSNS];
646         bool is_call = poke_type == BPF_MOD_CALL;
647         int ret;
648
649         if (!is_kernel_text((unsigned long)ip) &&
650             !is_bpf_text_address((unsigned long)ip))
651                 return -ENOTSUPP;
652
653         ret = gen_jump_or_nops(old_addr, ip, old_insns, is_call);
654         if (ret)
655                 return ret;
656
657         if (memcmp(ip, old_insns, RV_FENTRY_NINSNS * 4))
658                 return -EFAULT;
659
660         ret = gen_jump_or_nops(new_addr, ip, new_insns, is_call);
661         if (ret)
662                 return ret;
663
664         cpus_read_lock();
665         mutex_lock(&text_mutex);
666         if (memcmp(ip, new_insns, RV_FENTRY_NINSNS * 4))
667                 ret = patch_text(ip, new_insns, RV_FENTRY_NINSNS);
668         mutex_unlock(&text_mutex);
669         cpus_read_unlock();
670
671         return ret;
672 }
673
674 static void store_args(int nregs, int args_off, struct rv_jit_context *ctx)
675 {
676         int i;
677
678         for (i = 0; i < nregs; i++) {
679                 emit_sd(RV_REG_FP, -args_off, RV_REG_A0 + i, ctx);
680                 args_off -= 8;
681         }
682 }
683
684 static void restore_args(int nregs, int args_off, struct rv_jit_context *ctx)
685 {
686         int i;
687
688         for (i = 0; i < nregs; i++) {
689                 emit_ld(RV_REG_A0 + i, -args_off, RV_REG_FP, ctx);
690                 args_off -= 8;
691         }
692 }
693
694 static int invoke_bpf_prog(struct bpf_tramp_link *l, int args_off, int retval_off,
695                            int run_ctx_off, bool save_ret, struct rv_jit_context *ctx)
696 {
697         int ret, branch_off;
698         struct bpf_prog *p = l->link.prog;
699         int cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie);
700
701         if (l->cookie) {
702                 emit_imm(RV_REG_T1, l->cookie, ctx);
703                 emit_sd(RV_REG_FP, -run_ctx_off + cookie_off, RV_REG_T1, ctx);
704         } else {
705                 emit_sd(RV_REG_FP, -run_ctx_off + cookie_off, RV_REG_ZERO, ctx);
706         }
707
708         /* arg1: prog */
709         emit_imm(RV_REG_A0, (const s64)p, ctx);
710         /* arg2: &run_ctx */
711         emit_addi(RV_REG_A1, RV_REG_FP, -run_ctx_off, ctx);
712         ret = emit_call((const u64)bpf_trampoline_enter(p), true, ctx);
713         if (ret)
714                 return ret;
715
716         /* if (__bpf_prog_enter(prog) == 0)
717          *      goto skip_exec_of_prog;
718          */
719         branch_off = ctx->ninsns;
720         /* nop reserved for conditional jump */
721         emit(rv_nop(), ctx);
722
723         /* store prog start time */
724         emit_mv(RV_REG_S1, RV_REG_A0, ctx);
725
726         /* arg1: &args_off */
727         emit_addi(RV_REG_A0, RV_REG_FP, -args_off, ctx);
728         if (!p->jited)
729                 /* arg2: progs[i]->insnsi for interpreter */
730                 emit_imm(RV_REG_A1, (const s64)p->insnsi, ctx);
731         ret = emit_call((const u64)p->bpf_func, true, ctx);
732         if (ret)
733                 return ret;
734
735         if (save_ret)
736                 emit_sd(RV_REG_FP, -retval_off, regmap[BPF_REG_0], ctx);
737
738         /* update branch with beqz */
739         if (ctx->insns) {
740                 int offset = ninsns_rvoff(ctx->ninsns - branch_off);
741                 u32 insn = rv_beq(RV_REG_A0, RV_REG_ZERO, offset >> 1);
742                 *(u32 *)(ctx->insns + branch_off) = insn;
743         }
744
745         /* arg1: prog */
746         emit_imm(RV_REG_A0, (const s64)p, ctx);
747         /* arg2: prog start time */
748         emit_mv(RV_REG_A1, RV_REG_S1, ctx);
749         /* arg3: &run_ctx */
750         emit_addi(RV_REG_A2, RV_REG_FP, -run_ctx_off, ctx);
751         ret = emit_call((const u64)bpf_trampoline_exit(p), true, ctx);
752
753         return ret;
754 }
755
756 static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
757                                          const struct btf_func_model *m,
758                                          struct bpf_tramp_links *tlinks,
759                                          void *func_addr, u32 flags,
760                                          struct rv_jit_context *ctx)
761 {
762         int i, ret, offset;
763         int *branches_off = NULL;
764         int stack_size = 0, nregs = m->nr_args;
765         int retval_off, args_off, nregs_off, ip_off, run_ctx_off, sreg_off;
766         struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
767         struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
768         struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
769         void *orig_call = func_addr;
770         bool save_ret;
771         u32 insn;
772
773         /* Two types of generated trampoline stack layout:
774          *
775          * 1. trampoline called from function entry
776          * --------------------------------------
777          * FP + 8           [ RA to parent func ] return address to parent
778          *                                        function
779          * FP + 0           [ FP of parent func ] frame pointer of parent
780          *                                        function
781          * FP - 8           [ T0 to traced func ] return address of traced
782          *                                        function
783          * FP - 16          [ FP of traced func ] frame pointer of traced
784          *                                        function
785          * --------------------------------------
786          *
787          * 2. trampoline called directly
788          * --------------------------------------
789          * FP - 8           [ RA to caller func ] return address to caller
790          *                                        function
791          * FP - 16          [ FP of caller func ] frame pointer of caller
792          *                                        function
793          * --------------------------------------
794          *
795          * FP - retval_off  [ return value      ] BPF_TRAMP_F_CALL_ORIG or
796          *                                        BPF_TRAMP_F_RET_FENTRY_RET
797          *                  [ argN              ]
798          *                  [ ...               ]
799          * FP - args_off    [ arg1              ]
800          *
801          * FP - nregs_off   [ regs count        ]
802          *
803          * FP - ip_off      [ traced func       ] BPF_TRAMP_F_IP_ARG
804          *
805          * FP - run_ctx_off [ bpf_tramp_run_ctx ]
806          *
807          * FP - sreg_off    [ callee saved reg  ]
808          *
809          *                  [ pads              ] pads for 16 bytes alignment
810          */
811
812         if (flags & (BPF_TRAMP_F_ORIG_STACK | BPF_TRAMP_F_SHARE_IPMODIFY))
813                 return -ENOTSUPP;
814
815         /* extra regiters for struct arguments */
816         for (i = 0; i < m->nr_args; i++)
817                 if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG)
818                         nregs += round_up(m->arg_size[i], 8) / 8 - 1;
819
820         /* 8 arguments passed by registers */
821         if (nregs > 8)
822                 return -ENOTSUPP;
823
824         /* room of trampoline frame to store return address and frame pointer */
825         stack_size += 16;
826
827         save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
828         if (save_ret) {
829                 stack_size += 8;
830                 retval_off = stack_size;
831         }
832
833         stack_size += nregs * 8;
834         args_off = stack_size;
835
836         stack_size += 8;
837         nregs_off = stack_size;
838
839         if (flags & BPF_TRAMP_F_IP_ARG) {
840                 stack_size += 8;
841                 ip_off = stack_size;
842         }
843
844         stack_size += round_up(sizeof(struct bpf_tramp_run_ctx), 8);
845         run_ctx_off = stack_size;
846
847         stack_size += 8;
848         sreg_off = stack_size;
849
850         stack_size = round_up(stack_size, 16);
851
852         if (func_addr) {
853                 /* For the trampoline called from function entry,
854                  * the frame of traced function and the frame of
855                  * trampoline need to be considered.
856                  */
857                 emit_addi(RV_REG_SP, RV_REG_SP, -16, ctx);
858                 emit_sd(RV_REG_SP, 8, RV_REG_RA, ctx);
859                 emit_sd(RV_REG_SP, 0, RV_REG_FP, ctx);
860                 emit_addi(RV_REG_FP, RV_REG_SP, 16, ctx);
861
862                 emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
863                 emit_sd(RV_REG_SP, stack_size - 8, RV_REG_T0, ctx);
864                 emit_sd(RV_REG_SP, stack_size - 16, RV_REG_FP, ctx);
865                 emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
866         } else {
867                 /* For the trampoline called directly, just handle
868                  * the frame of trampoline.
869                  */
870                 emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
871                 emit_sd(RV_REG_SP, stack_size - 8, RV_REG_RA, ctx);
872                 emit_sd(RV_REG_SP, stack_size - 16, RV_REG_FP, ctx);
873                 emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
874         }
875
876         /* callee saved register S1 to pass start time */
877         emit_sd(RV_REG_FP, -sreg_off, RV_REG_S1, ctx);
878
879         /* store ip address of the traced function */
880         if (flags & BPF_TRAMP_F_IP_ARG) {
881                 emit_imm(RV_REG_T1, (const s64)func_addr, ctx);
882                 emit_sd(RV_REG_FP, -ip_off, RV_REG_T1, ctx);
883         }
884
885         emit_li(RV_REG_T1, nregs, ctx);
886         emit_sd(RV_REG_FP, -nregs_off, RV_REG_T1, ctx);
887
888         store_args(nregs, args_off, ctx);
889
890         /* skip to actual body of traced function */
891         if (flags & BPF_TRAMP_F_SKIP_FRAME)
892                 orig_call += RV_FENTRY_NINSNS * 4;
893
894         if (flags & BPF_TRAMP_F_CALL_ORIG) {
895                 emit_imm(RV_REG_A0, (const s64)im, ctx);
896                 ret = emit_call((const u64)__bpf_tramp_enter, true, ctx);
897                 if (ret)
898                         return ret;
899         }
900
901         for (i = 0; i < fentry->nr_links; i++) {
902                 ret = invoke_bpf_prog(fentry->links[i], args_off, retval_off, run_ctx_off,
903                                       flags & BPF_TRAMP_F_RET_FENTRY_RET, ctx);
904                 if (ret)
905                         return ret;
906         }
907
908         if (fmod_ret->nr_links) {
909                 branches_off = kcalloc(fmod_ret->nr_links, sizeof(int), GFP_KERNEL);
910                 if (!branches_off)
911                         return -ENOMEM;
912
913                 /* cleanup to avoid garbage return value confusion */
914                 emit_sd(RV_REG_FP, -retval_off, RV_REG_ZERO, ctx);
915                 for (i = 0; i < fmod_ret->nr_links; i++) {
916                         ret = invoke_bpf_prog(fmod_ret->links[i], args_off, retval_off,
917                                               run_ctx_off, true, ctx);
918                         if (ret)
919                                 goto out;
920                         emit_ld(RV_REG_T1, -retval_off, RV_REG_FP, ctx);
921                         branches_off[i] = ctx->ninsns;
922                         /* nop reserved for conditional jump */
923                         emit(rv_nop(), ctx);
924                 }
925         }
926
927         if (flags & BPF_TRAMP_F_CALL_ORIG) {
928                 restore_args(nregs, args_off, ctx);
929                 ret = emit_call((const u64)orig_call, true, ctx);
930                 if (ret)
931                         goto out;
932                 emit_sd(RV_REG_FP, -retval_off, RV_REG_A0, ctx);
933                 im->ip_after_call = ctx->insns + ctx->ninsns;
934                 /* 2 nops reserved for auipc+jalr pair */
935                 emit(rv_nop(), ctx);
936                 emit(rv_nop(), ctx);
937         }
938
939         /* update branches saved in invoke_bpf_mod_ret with bnez */
940         for (i = 0; ctx->insns && i < fmod_ret->nr_links; i++) {
941                 offset = ninsns_rvoff(ctx->ninsns - branches_off[i]);
942                 insn = rv_bne(RV_REG_T1, RV_REG_ZERO, offset >> 1);
943                 *(u32 *)(ctx->insns + branches_off[i]) = insn;
944         }
945
946         for (i = 0; i < fexit->nr_links; i++) {
947                 ret = invoke_bpf_prog(fexit->links[i], args_off, retval_off,
948                                       run_ctx_off, false, ctx);
949                 if (ret)
950                         goto out;
951         }
952
953         if (flags & BPF_TRAMP_F_CALL_ORIG) {
954                 im->ip_epilogue = ctx->insns + ctx->ninsns;
955                 emit_imm(RV_REG_A0, (const s64)im, ctx);
956                 ret = emit_call((const u64)__bpf_tramp_exit, true, ctx);
957                 if (ret)
958                         goto out;
959         }
960
961         if (flags & BPF_TRAMP_F_RESTORE_REGS)
962                 restore_args(nregs, args_off, ctx);
963
964         if (save_ret)
965                 emit_ld(RV_REG_A0, -retval_off, RV_REG_FP, ctx);
966
967         emit_ld(RV_REG_S1, -sreg_off, RV_REG_FP, ctx);
968
969         if (func_addr) {
970                 /* trampoline called from function entry */
971                 emit_ld(RV_REG_T0, stack_size - 8, RV_REG_SP, ctx);
972                 emit_ld(RV_REG_FP, stack_size - 16, RV_REG_SP, ctx);
973                 emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
974
975                 emit_ld(RV_REG_RA, 8, RV_REG_SP, ctx);
976                 emit_ld(RV_REG_FP, 0, RV_REG_SP, ctx);
977                 emit_addi(RV_REG_SP, RV_REG_SP, 16, ctx);
978
979                 if (flags & BPF_TRAMP_F_SKIP_FRAME)
980                         /* return to parent function */
981                         emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
982                 else
983                         /* return to traced function */
984                         emit_jalr(RV_REG_ZERO, RV_REG_T0, 0, ctx);
985         } else {
986                 /* trampoline called directly */
987                 emit_ld(RV_REG_RA, stack_size - 8, RV_REG_SP, ctx);
988                 emit_ld(RV_REG_FP, stack_size - 16, RV_REG_SP, ctx);
989                 emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
990
991                 emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
992         }
993
994         ret = ctx->ninsns;
995 out:
996         kfree(branches_off);
997         return ret;
998 }
999
1000 int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image,
1001                                 void *image_end, const struct btf_func_model *m,
1002                                 u32 flags, struct bpf_tramp_links *tlinks,
1003                                 void *func_addr)
1004 {
1005         int ret;
1006         struct rv_jit_context ctx;
1007
1008         ctx.ninsns = 0;
1009         ctx.insns = NULL;
1010         ret = __arch_prepare_bpf_trampoline(im, m, tlinks, func_addr, flags, &ctx);
1011         if (ret < 0)
1012                 return ret;
1013
1014         if (ninsns_rvoff(ret) > (long)image_end - (long)image)
1015                 return -EFBIG;
1016
1017         ctx.ninsns = 0;
1018         ctx.insns = image;
1019         ret = __arch_prepare_bpf_trampoline(im, m, tlinks, func_addr, flags, &ctx);
1020         if (ret < 0)
1021                 return ret;
1022
1023         bpf_flush_icache(ctx.insns, ctx.insns + ctx.ninsns);
1024
1025         return ninsns_rvoff(ret);
1026 }
1027
1028 int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
1029                       bool extra_pass)
1030 {
1031         bool is64 = BPF_CLASS(insn->code) == BPF_ALU64 ||
1032                     BPF_CLASS(insn->code) == BPF_JMP;
1033         int s, e, rvoff, ret, i = insn - ctx->prog->insnsi;
1034         struct bpf_prog_aux *aux = ctx->prog->aux;
1035         u8 rd = -1, rs = -1, code = insn->code;
1036         s16 off = insn->off;
1037         s32 imm = insn->imm;
1038
1039         init_regs(&rd, &rs, insn, ctx);
1040
1041         switch (code) {
1042         /* dst = src */
1043         case BPF_ALU | BPF_MOV | BPF_X:
1044         case BPF_ALU64 | BPF_MOV | BPF_X:
1045                 if (imm == 1) {
1046                         /* Special mov32 for zext */
1047                         emit_zext_32(rd, ctx);
1048                         break;
1049                 }
1050                 switch (insn->off) {
1051                 case 0:
1052                         emit_mv(rd, rs, ctx);
1053                         break;
1054                 case 8:
1055                 case 16:
1056                         emit_slli(RV_REG_T1, rs, 64 - insn->off, ctx);
1057                         emit_srai(rd, RV_REG_T1, 64 - insn->off, ctx);
1058                         break;
1059                 case 32:
1060                         emit_addiw(rd, rs, 0, ctx);
1061                         break;
1062                 }
1063                 if (!is64 && !aux->verifier_zext)
1064                         emit_zext_32(rd, ctx);
1065                 break;
1066
1067         /* dst = dst OP src */
1068         case BPF_ALU | BPF_ADD | BPF_X:
1069         case BPF_ALU64 | BPF_ADD | BPF_X:
1070                 emit_add(rd, rd, rs, ctx);
1071                 if (!is64 && !aux->verifier_zext)
1072                         emit_zext_32(rd, ctx);
1073                 break;
1074         case BPF_ALU | BPF_SUB | BPF_X:
1075         case BPF_ALU64 | BPF_SUB | BPF_X:
1076                 if (is64)
1077                         emit_sub(rd, rd, rs, ctx);
1078                 else
1079                         emit_subw(rd, rd, rs, ctx);
1080
1081                 if (!is64 && !aux->verifier_zext)
1082                         emit_zext_32(rd, ctx);
1083                 break;
1084         case BPF_ALU | BPF_AND | BPF_X:
1085         case BPF_ALU64 | BPF_AND | BPF_X:
1086                 emit_and(rd, rd, rs, ctx);
1087                 if (!is64 && !aux->verifier_zext)
1088                         emit_zext_32(rd, ctx);
1089                 break;
1090         case BPF_ALU | BPF_OR | BPF_X:
1091         case BPF_ALU64 | BPF_OR | BPF_X:
1092                 emit_or(rd, rd, rs, ctx);
1093                 if (!is64 && !aux->verifier_zext)
1094                         emit_zext_32(rd, ctx);
1095                 break;
1096         case BPF_ALU | BPF_XOR | BPF_X:
1097         case BPF_ALU64 | BPF_XOR | BPF_X:
1098                 emit_xor(rd, rd, rs, ctx);
1099                 if (!is64 && !aux->verifier_zext)
1100                         emit_zext_32(rd, ctx);
1101                 break;
1102         case BPF_ALU | BPF_MUL | BPF_X:
1103         case BPF_ALU64 | BPF_MUL | BPF_X:
1104                 emit(is64 ? rv_mul(rd, rd, rs) : rv_mulw(rd, rd, rs), ctx);
1105                 if (!is64 && !aux->verifier_zext)
1106                         emit_zext_32(rd, ctx);
1107                 break;
1108         case BPF_ALU | BPF_DIV | BPF_X:
1109         case BPF_ALU64 | BPF_DIV | BPF_X:
1110                 if (off)
1111                         emit(is64 ? rv_div(rd, rd, rs) : rv_divw(rd, rd, rs), ctx);
1112                 else
1113                         emit(is64 ? rv_divu(rd, rd, rs) : rv_divuw(rd, rd, rs), ctx);
1114                 if (!is64 && !aux->verifier_zext)
1115                         emit_zext_32(rd, ctx);
1116                 break;
1117         case BPF_ALU | BPF_MOD | BPF_X:
1118         case BPF_ALU64 | BPF_MOD | BPF_X:
1119                 if (off)
1120                         emit(is64 ? rv_rem(rd, rd, rs) : rv_remw(rd, rd, rs), ctx);
1121                 else
1122                         emit(is64 ? rv_remu(rd, rd, rs) : rv_remuw(rd, rd, rs), ctx);
1123                 if (!is64 && !aux->verifier_zext)
1124                         emit_zext_32(rd, ctx);
1125                 break;
1126         case BPF_ALU | BPF_LSH | BPF_X:
1127         case BPF_ALU64 | BPF_LSH | BPF_X:
1128                 emit(is64 ? rv_sll(rd, rd, rs) : rv_sllw(rd, rd, rs), ctx);
1129                 if (!is64 && !aux->verifier_zext)
1130                         emit_zext_32(rd, ctx);
1131                 break;
1132         case BPF_ALU | BPF_RSH | BPF_X:
1133         case BPF_ALU64 | BPF_RSH | BPF_X:
1134                 emit(is64 ? rv_srl(rd, rd, rs) : rv_srlw(rd, rd, rs), ctx);
1135                 if (!is64 && !aux->verifier_zext)
1136                         emit_zext_32(rd, ctx);
1137                 break;
1138         case BPF_ALU | BPF_ARSH | BPF_X:
1139         case BPF_ALU64 | BPF_ARSH | BPF_X:
1140                 emit(is64 ? rv_sra(rd, rd, rs) : rv_sraw(rd, rd, rs), ctx);
1141                 if (!is64 && !aux->verifier_zext)
1142                         emit_zext_32(rd, ctx);
1143                 break;
1144
1145         /* dst = -dst */
1146         case BPF_ALU | BPF_NEG:
1147         case BPF_ALU64 | BPF_NEG:
1148                 emit_sub(rd, RV_REG_ZERO, rd, ctx);
1149                 if (!is64 && !aux->verifier_zext)
1150                         emit_zext_32(rd, ctx);
1151                 break;
1152
1153         /* dst = BSWAP##imm(dst) */
1154         case BPF_ALU | BPF_END | BPF_FROM_LE:
1155                 switch (imm) {
1156                 case 16:
1157                         emit_slli(rd, rd, 48, ctx);
1158                         emit_srli(rd, rd, 48, ctx);
1159                         break;
1160                 case 32:
1161                         if (!aux->verifier_zext)
1162                                 emit_zext_32(rd, ctx);
1163                         break;
1164                 case 64:
1165                         /* Do nothing */
1166                         break;
1167                 }
1168                 break;
1169
1170         case BPF_ALU | BPF_END | BPF_FROM_BE:
1171         case BPF_ALU64 | BPF_END | BPF_FROM_LE:
1172                 emit_li(RV_REG_T2, 0, ctx);
1173
1174                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1175                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1176                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1177                 emit_srli(rd, rd, 8, ctx);
1178                 if (imm == 16)
1179                         goto out_be;
1180
1181                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1182                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1183                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1184                 emit_srli(rd, rd, 8, ctx);
1185
1186                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1187                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1188                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1189                 emit_srli(rd, rd, 8, ctx);
1190                 if (imm == 32)
1191                         goto out_be;
1192
1193                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1194                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1195                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1196                 emit_srli(rd, rd, 8, ctx);
1197
1198                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1199                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1200                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1201                 emit_srli(rd, rd, 8, ctx);
1202
1203                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1204                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1205                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1206                 emit_srli(rd, rd, 8, ctx);
1207
1208                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1209                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1210                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1211                 emit_srli(rd, rd, 8, ctx);
1212 out_be:
1213                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1214                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1215
1216                 emit_mv(rd, RV_REG_T2, ctx);
1217                 break;
1218
1219         /* dst = imm */
1220         case BPF_ALU | BPF_MOV | BPF_K:
1221         case BPF_ALU64 | BPF_MOV | BPF_K:
1222                 emit_imm(rd, imm, ctx);
1223                 if (!is64 && !aux->verifier_zext)
1224                         emit_zext_32(rd, ctx);
1225                 break;
1226
1227         /* dst = dst OP imm */
1228         case BPF_ALU | BPF_ADD | BPF_K:
1229         case BPF_ALU64 | BPF_ADD | BPF_K:
1230                 if (is_12b_int(imm)) {
1231                         emit_addi(rd, rd, imm, ctx);
1232                 } else {
1233                         emit_imm(RV_REG_T1, imm, ctx);
1234                         emit_add(rd, rd, RV_REG_T1, ctx);
1235                 }
1236                 if (!is64 && !aux->verifier_zext)
1237                         emit_zext_32(rd, ctx);
1238                 break;
1239         case BPF_ALU | BPF_SUB | BPF_K:
1240         case BPF_ALU64 | BPF_SUB | BPF_K:
1241                 if (is_12b_int(-imm)) {
1242                         emit_addi(rd, rd, -imm, ctx);
1243                 } else {
1244                         emit_imm(RV_REG_T1, imm, ctx);
1245                         emit_sub(rd, rd, RV_REG_T1, ctx);
1246                 }
1247                 if (!is64 && !aux->verifier_zext)
1248                         emit_zext_32(rd, ctx);
1249                 break;
1250         case BPF_ALU | BPF_AND | BPF_K:
1251         case BPF_ALU64 | BPF_AND | BPF_K:
1252                 if (is_12b_int(imm)) {
1253                         emit_andi(rd, rd, imm, ctx);
1254                 } else {
1255                         emit_imm(RV_REG_T1, imm, ctx);
1256                         emit_and(rd, rd, RV_REG_T1, ctx);
1257                 }
1258                 if (!is64 && !aux->verifier_zext)
1259                         emit_zext_32(rd, ctx);
1260                 break;
1261         case BPF_ALU | BPF_OR | BPF_K:
1262         case BPF_ALU64 | BPF_OR | BPF_K:
1263                 if (is_12b_int(imm)) {
1264                         emit(rv_ori(rd, rd, imm), ctx);
1265                 } else {
1266                         emit_imm(RV_REG_T1, imm, ctx);
1267                         emit_or(rd, rd, RV_REG_T1, ctx);
1268                 }
1269                 if (!is64 && !aux->verifier_zext)
1270                         emit_zext_32(rd, ctx);
1271                 break;
1272         case BPF_ALU | BPF_XOR | BPF_K:
1273         case BPF_ALU64 | BPF_XOR | BPF_K:
1274                 if (is_12b_int(imm)) {
1275                         emit(rv_xori(rd, rd, imm), ctx);
1276                 } else {
1277                         emit_imm(RV_REG_T1, imm, ctx);
1278                         emit_xor(rd, rd, RV_REG_T1, ctx);
1279                 }
1280                 if (!is64 && !aux->verifier_zext)
1281                         emit_zext_32(rd, ctx);
1282                 break;
1283         case BPF_ALU | BPF_MUL | BPF_K:
1284         case BPF_ALU64 | BPF_MUL | BPF_K:
1285                 emit_imm(RV_REG_T1, imm, ctx);
1286                 emit(is64 ? rv_mul(rd, rd, RV_REG_T1) :
1287                      rv_mulw(rd, rd, RV_REG_T1), ctx);
1288                 if (!is64 && !aux->verifier_zext)
1289                         emit_zext_32(rd, ctx);
1290                 break;
1291         case BPF_ALU | BPF_DIV | BPF_K:
1292         case BPF_ALU64 | BPF_DIV | BPF_K:
1293                 emit_imm(RV_REG_T1, imm, ctx);
1294                 if (off)
1295                         emit(is64 ? rv_div(rd, rd, RV_REG_T1) :
1296                              rv_divw(rd, rd, RV_REG_T1), ctx);
1297                 else
1298                         emit(is64 ? rv_divu(rd, rd, RV_REG_T1) :
1299                              rv_divuw(rd, rd, RV_REG_T1), ctx);
1300                 if (!is64 && !aux->verifier_zext)
1301                         emit_zext_32(rd, ctx);
1302                 break;
1303         case BPF_ALU | BPF_MOD | BPF_K:
1304         case BPF_ALU64 | BPF_MOD | BPF_K:
1305                 emit_imm(RV_REG_T1, imm, ctx);
1306                 if (off)
1307                         emit(is64 ? rv_rem(rd, rd, RV_REG_T1) :
1308                              rv_remw(rd, rd, RV_REG_T1), ctx);
1309                 else
1310                         emit(is64 ? rv_remu(rd, rd, RV_REG_T1) :
1311                              rv_remuw(rd, rd, RV_REG_T1), ctx);
1312                 if (!is64 && !aux->verifier_zext)
1313                         emit_zext_32(rd, ctx);
1314                 break;
1315         case BPF_ALU | BPF_LSH | BPF_K:
1316         case BPF_ALU64 | BPF_LSH | BPF_K:
1317                 emit_slli(rd, rd, imm, ctx);
1318
1319                 if (!is64 && !aux->verifier_zext)
1320                         emit_zext_32(rd, ctx);
1321                 break;
1322         case BPF_ALU | BPF_RSH | BPF_K:
1323         case BPF_ALU64 | BPF_RSH | BPF_K:
1324                 if (is64)
1325                         emit_srli(rd, rd, imm, ctx);
1326                 else
1327                         emit(rv_srliw(rd, rd, imm), ctx);
1328
1329                 if (!is64 && !aux->verifier_zext)
1330                         emit_zext_32(rd, ctx);
1331                 break;
1332         case BPF_ALU | BPF_ARSH | BPF_K:
1333         case BPF_ALU64 | BPF_ARSH | BPF_K:
1334                 if (is64)
1335                         emit_srai(rd, rd, imm, ctx);
1336                 else
1337                         emit(rv_sraiw(rd, rd, imm), ctx);
1338
1339                 if (!is64 && !aux->verifier_zext)
1340                         emit_zext_32(rd, ctx);
1341                 break;
1342
1343         /* JUMP off */
1344         case BPF_JMP | BPF_JA:
1345         case BPF_JMP32 | BPF_JA:
1346                 if (BPF_CLASS(code) == BPF_JMP)
1347                         rvoff = rv_offset(i, off, ctx);
1348                 else
1349                         rvoff = rv_offset(i, imm, ctx);
1350                 ret = emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
1351                 if (ret)
1352                         return ret;
1353                 break;
1354
1355         /* IF (dst COND src) JUMP off */
1356         case BPF_JMP | BPF_JEQ | BPF_X:
1357         case BPF_JMP32 | BPF_JEQ | BPF_X:
1358         case BPF_JMP | BPF_JGT | BPF_X:
1359         case BPF_JMP32 | BPF_JGT | BPF_X:
1360         case BPF_JMP | BPF_JLT | BPF_X:
1361         case BPF_JMP32 | BPF_JLT | BPF_X:
1362         case BPF_JMP | BPF_JGE | BPF_X:
1363         case BPF_JMP32 | BPF_JGE | BPF_X:
1364         case BPF_JMP | BPF_JLE | BPF_X:
1365         case BPF_JMP32 | BPF_JLE | BPF_X:
1366         case BPF_JMP | BPF_JNE | BPF_X:
1367         case BPF_JMP32 | BPF_JNE | BPF_X:
1368         case BPF_JMP | BPF_JSGT | BPF_X:
1369         case BPF_JMP32 | BPF_JSGT | BPF_X:
1370         case BPF_JMP | BPF_JSLT | BPF_X:
1371         case BPF_JMP32 | BPF_JSLT | BPF_X:
1372         case BPF_JMP | BPF_JSGE | BPF_X:
1373         case BPF_JMP32 | BPF_JSGE | BPF_X:
1374         case BPF_JMP | BPF_JSLE | BPF_X:
1375         case BPF_JMP32 | BPF_JSLE | BPF_X:
1376         case BPF_JMP | BPF_JSET | BPF_X:
1377         case BPF_JMP32 | BPF_JSET | BPF_X:
1378                 rvoff = rv_offset(i, off, ctx);
1379                 if (!is64) {
1380                         s = ctx->ninsns;
1381                         if (is_signed_bpf_cond(BPF_OP(code)))
1382                                 emit_sext_32_rd_rs(&rd, &rs, ctx);
1383                         else
1384                                 emit_zext_32_rd_rs(&rd, &rs, ctx);
1385                         e = ctx->ninsns;
1386
1387                         /* Adjust for extra insns */
1388                         rvoff -= ninsns_rvoff(e - s);
1389                 }
1390
1391                 if (BPF_OP(code) == BPF_JSET) {
1392                         /* Adjust for and */
1393                         rvoff -= 4;
1394                         emit_and(RV_REG_T1, rd, rs, ctx);
1395                         emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff,
1396                                     ctx);
1397                 } else {
1398                         emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
1399                 }
1400                 break;
1401
1402         /* IF (dst COND imm) JUMP off */
1403         case BPF_JMP | BPF_JEQ | BPF_K:
1404         case BPF_JMP32 | BPF_JEQ | BPF_K:
1405         case BPF_JMP | BPF_JGT | BPF_K:
1406         case BPF_JMP32 | BPF_JGT | BPF_K:
1407         case BPF_JMP | BPF_JLT | BPF_K:
1408         case BPF_JMP32 | BPF_JLT | BPF_K:
1409         case BPF_JMP | BPF_JGE | BPF_K:
1410         case BPF_JMP32 | BPF_JGE | BPF_K:
1411         case BPF_JMP | BPF_JLE | BPF_K:
1412         case BPF_JMP32 | BPF_JLE | BPF_K:
1413         case BPF_JMP | BPF_JNE | BPF_K:
1414         case BPF_JMP32 | BPF_JNE | BPF_K:
1415         case BPF_JMP | BPF_JSGT | BPF_K:
1416         case BPF_JMP32 | BPF_JSGT | BPF_K:
1417         case BPF_JMP | BPF_JSLT | BPF_K:
1418         case BPF_JMP32 | BPF_JSLT | BPF_K:
1419         case BPF_JMP | BPF_JSGE | BPF_K:
1420         case BPF_JMP32 | BPF_JSGE | BPF_K:
1421         case BPF_JMP | BPF_JSLE | BPF_K:
1422         case BPF_JMP32 | BPF_JSLE | BPF_K:
1423                 rvoff = rv_offset(i, off, ctx);
1424                 s = ctx->ninsns;
1425                 if (imm) {
1426                         emit_imm(RV_REG_T1, imm, ctx);
1427                         rs = RV_REG_T1;
1428                 } else {
1429                         /* If imm is 0, simply use zero register. */
1430                         rs = RV_REG_ZERO;
1431                 }
1432                 if (!is64) {
1433                         if (is_signed_bpf_cond(BPF_OP(code)))
1434                                 emit_sext_32_rd(&rd, ctx);
1435                         else
1436                                 emit_zext_32_rd_t1(&rd, ctx);
1437                 }
1438                 e = ctx->ninsns;
1439
1440                 /* Adjust for extra insns */
1441                 rvoff -= ninsns_rvoff(e - s);
1442                 emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
1443                 break;
1444
1445         case BPF_JMP | BPF_JSET | BPF_K:
1446         case BPF_JMP32 | BPF_JSET | BPF_K:
1447                 rvoff = rv_offset(i, off, ctx);
1448                 s = ctx->ninsns;
1449                 if (is_12b_int(imm)) {
1450                         emit_andi(RV_REG_T1, rd, imm, ctx);
1451                 } else {
1452                         emit_imm(RV_REG_T1, imm, ctx);
1453                         emit_and(RV_REG_T1, rd, RV_REG_T1, ctx);
1454                 }
1455                 /* For jset32, we should clear the upper 32 bits of t1, but
1456                  * sign-extension is sufficient here and saves one instruction,
1457                  * as t1 is used only in comparison against zero.
1458                  */
1459                 if (!is64 && imm < 0)
1460                         emit_addiw(RV_REG_T1, RV_REG_T1, 0, ctx);
1461                 e = ctx->ninsns;
1462                 rvoff -= ninsns_rvoff(e - s);
1463                 emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, ctx);
1464                 break;
1465
1466         /* function call */
1467         case BPF_JMP | BPF_CALL:
1468         {
1469                 bool fixed_addr;
1470                 u64 addr;
1471
1472                 mark_call(ctx);
1473                 ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
1474                                             &addr, &fixed_addr);
1475                 if (ret < 0)
1476                         return ret;
1477
1478                 ret = emit_call(addr, fixed_addr, ctx);
1479                 if (ret)
1480                         return ret;
1481
1482                 emit_mv(bpf_to_rv_reg(BPF_REG_0, ctx), RV_REG_A0, ctx);
1483                 break;
1484         }
1485         /* tail call */
1486         case BPF_JMP | BPF_TAIL_CALL:
1487                 if (emit_bpf_tail_call(i, ctx))
1488                         return -1;
1489                 break;
1490
1491         /* function return */
1492         case BPF_JMP | BPF_EXIT:
1493                 if (i == ctx->prog->len - 1)
1494                         break;
1495
1496                 rvoff = epilogue_offset(ctx);
1497                 ret = emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
1498                 if (ret)
1499                         return ret;
1500                 break;
1501
1502         /* dst = imm64 */
1503         case BPF_LD | BPF_IMM | BPF_DW:
1504         {
1505                 struct bpf_insn insn1 = insn[1];
1506                 u64 imm64;
1507
1508                 imm64 = (u64)insn1.imm << 32 | (u32)imm;
1509                 if (bpf_pseudo_func(insn)) {
1510                         /* fixed-length insns for extra jit pass */
1511                         ret = emit_addr(rd, imm64, extra_pass, ctx);
1512                         if (ret)
1513                                 return ret;
1514                 } else {
1515                         emit_imm(rd, imm64, ctx);
1516                 }
1517
1518                 return 1;
1519         }
1520
1521         /* LDX: dst = *(unsigned size *)(src + off) */
1522         case BPF_LDX | BPF_MEM | BPF_B:
1523         case BPF_LDX | BPF_MEM | BPF_H:
1524         case BPF_LDX | BPF_MEM | BPF_W:
1525         case BPF_LDX | BPF_MEM | BPF_DW:
1526         case BPF_LDX | BPF_PROBE_MEM | BPF_B:
1527         case BPF_LDX | BPF_PROBE_MEM | BPF_H:
1528         case BPF_LDX | BPF_PROBE_MEM | BPF_W:
1529         case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
1530         /* LDSX: dst = *(signed size *)(src + off) */
1531         case BPF_LDX | BPF_MEMSX | BPF_B:
1532         case BPF_LDX | BPF_MEMSX | BPF_H:
1533         case BPF_LDX | BPF_MEMSX | BPF_W:
1534         case BPF_LDX | BPF_PROBE_MEMSX | BPF_B:
1535         case BPF_LDX | BPF_PROBE_MEMSX | BPF_H:
1536         case BPF_LDX | BPF_PROBE_MEMSX | BPF_W:
1537         {
1538                 int insn_len, insns_start;
1539                 bool sign_ext;
1540
1541                 sign_ext = BPF_MODE(insn->code) == BPF_MEMSX ||
1542                            BPF_MODE(insn->code) == BPF_PROBE_MEMSX;
1543
1544                 switch (BPF_SIZE(code)) {
1545                 case BPF_B:
1546                         if (is_12b_int(off)) {
1547                                 insns_start = ctx->ninsns;
1548                                 if (sign_ext)
1549                                         emit(rv_lb(rd, off, rs), ctx);
1550                                 else
1551                                         emit(rv_lbu(rd, off, rs), ctx);
1552                                 insn_len = ctx->ninsns - insns_start;
1553                                 break;
1554                         }
1555
1556                         emit_imm(RV_REG_T1, off, ctx);
1557                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1558                         insns_start = ctx->ninsns;
1559                         if (sign_ext)
1560                                 emit(rv_lb(rd, 0, RV_REG_T1), ctx);
1561                         else
1562                                 emit(rv_lbu(rd, 0, RV_REG_T1), ctx);
1563                         insn_len = ctx->ninsns - insns_start;
1564                         break;
1565                 case BPF_H:
1566                         if (is_12b_int(off)) {
1567                                 insns_start = ctx->ninsns;
1568                                 if (sign_ext)
1569                                         emit(rv_lh(rd, off, rs), ctx);
1570                                 else
1571                                         emit(rv_lhu(rd, off, rs), ctx);
1572                                 insn_len = ctx->ninsns - insns_start;
1573                                 break;
1574                         }
1575
1576                         emit_imm(RV_REG_T1, off, ctx);
1577                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1578                         insns_start = ctx->ninsns;
1579                         if (sign_ext)
1580                                 emit(rv_lh(rd, 0, RV_REG_T1), ctx);
1581                         else
1582                                 emit(rv_lhu(rd, 0, RV_REG_T1), ctx);
1583                         insn_len = ctx->ninsns - insns_start;
1584                         break;
1585                 case BPF_W:
1586                         if (is_12b_int(off)) {
1587                                 insns_start = ctx->ninsns;
1588                                 if (sign_ext)
1589                                         emit(rv_lw(rd, off, rs), ctx);
1590                                 else
1591                                         emit(rv_lwu(rd, off, rs), ctx);
1592                                 insn_len = ctx->ninsns - insns_start;
1593                                 break;
1594                         }
1595
1596                         emit_imm(RV_REG_T1, off, ctx);
1597                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1598                         insns_start = ctx->ninsns;
1599                         if (sign_ext)
1600                                 emit(rv_lw(rd, 0, RV_REG_T1), ctx);
1601                         else
1602                                 emit(rv_lwu(rd, 0, RV_REG_T1), ctx);
1603                         insn_len = ctx->ninsns - insns_start;
1604                         break;
1605                 case BPF_DW:
1606                         if (is_12b_int(off)) {
1607                                 insns_start = ctx->ninsns;
1608                                 emit_ld(rd, off, rs, ctx);
1609                                 insn_len = ctx->ninsns - insns_start;
1610                                 break;
1611                         }
1612
1613                         emit_imm(RV_REG_T1, off, ctx);
1614                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1615                         insns_start = ctx->ninsns;
1616                         emit_ld(rd, 0, RV_REG_T1, ctx);
1617                         insn_len = ctx->ninsns - insns_start;
1618                         break;
1619                 }
1620
1621                 ret = add_exception_handler(insn, ctx, rd, insn_len);
1622                 if (ret)
1623                         return ret;
1624
1625                 if (BPF_SIZE(code) != BPF_DW && insn_is_zext(&insn[1]))
1626                         return 1;
1627                 break;
1628         }
1629         /* speculation barrier */
1630         case BPF_ST | BPF_NOSPEC:
1631                 break;
1632
1633         /* ST: *(size *)(dst + off) = imm */
1634         case BPF_ST | BPF_MEM | BPF_B:
1635                 emit_imm(RV_REG_T1, imm, ctx);
1636                 if (is_12b_int(off)) {
1637                         emit(rv_sb(rd, off, RV_REG_T1), ctx);
1638                         break;
1639                 }
1640
1641                 emit_imm(RV_REG_T2, off, ctx);
1642                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1643                 emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx);
1644                 break;
1645
1646         case BPF_ST | BPF_MEM | BPF_H:
1647                 emit_imm(RV_REG_T1, imm, ctx);
1648                 if (is_12b_int(off)) {
1649                         emit(rv_sh(rd, off, RV_REG_T1), ctx);
1650                         break;
1651                 }
1652
1653                 emit_imm(RV_REG_T2, off, ctx);
1654                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1655                 emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx);
1656                 break;
1657         case BPF_ST | BPF_MEM | BPF_W:
1658                 emit_imm(RV_REG_T1, imm, ctx);
1659                 if (is_12b_int(off)) {
1660                         emit_sw(rd, off, RV_REG_T1, ctx);
1661                         break;
1662                 }
1663
1664                 emit_imm(RV_REG_T2, off, ctx);
1665                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1666                 emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx);
1667                 break;
1668         case BPF_ST | BPF_MEM | BPF_DW:
1669                 emit_imm(RV_REG_T1, imm, ctx);
1670                 if (is_12b_int(off)) {
1671                         emit_sd(rd, off, RV_REG_T1, ctx);
1672                         break;
1673                 }
1674
1675                 emit_imm(RV_REG_T2, off, ctx);
1676                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1677                 emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
1678                 break;
1679
1680         /* STX: *(size *)(dst + off) = src */
1681         case BPF_STX | BPF_MEM | BPF_B:
1682                 if (is_12b_int(off)) {
1683                         emit(rv_sb(rd, off, rs), ctx);
1684                         break;
1685                 }
1686
1687                 emit_imm(RV_REG_T1, off, ctx);
1688                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1689                 emit(rv_sb(RV_REG_T1, 0, rs), ctx);
1690                 break;
1691         case BPF_STX | BPF_MEM | BPF_H:
1692                 if (is_12b_int(off)) {
1693                         emit(rv_sh(rd, off, rs), ctx);
1694                         break;
1695                 }
1696
1697                 emit_imm(RV_REG_T1, off, ctx);
1698                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1699                 emit(rv_sh(RV_REG_T1, 0, rs), ctx);
1700                 break;
1701         case BPF_STX | BPF_MEM | BPF_W:
1702                 if (is_12b_int(off)) {
1703                         emit_sw(rd, off, rs, ctx);
1704                         break;
1705                 }
1706
1707                 emit_imm(RV_REG_T1, off, ctx);
1708                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1709                 emit_sw(RV_REG_T1, 0, rs, ctx);
1710                 break;
1711         case BPF_STX | BPF_MEM | BPF_DW:
1712                 if (is_12b_int(off)) {
1713                         emit_sd(rd, off, rs, ctx);
1714                         break;
1715                 }
1716
1717                 emit_imm(RV_REG_T1, off, ctx);
1718                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1719                 emit_sd(RV_REG_T1, 0, rs, ctx);
1720                 break;
1721         case BPF_STX | BPF_ATOMIC | BPF_W:
1722         case BPF_STX | BPF_ATOMIC | BPF_DW:
1723                 emit_atomic(rd, rs, off, imm,
1724                             BPF_SIZE(code) == BPF_DW, ctx);
1725                 break;
1726         default:
1727                 pr_err("bpf-jit: unknown opcode %02x\n", code);
1728                 return -EINVAL;
1729         }
1730
1731         return 0;
1732 }
1733
1734 void bpf_jit_build_prologue(struct rv_jit_context *ctx)
1735 {
1736         int i, stack_adjust = 0, store_offset, bpf_stack_adjust;
1737
1738         bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
1739         if (bpf_stack_adjust)
1740                 mark_fp(ctx);
1741
1742         if (seen_reg(RV_REG_RA, ctx))
1743                 stack_adjust += 8;
1744         stack_adjust += 8; /* RV_REG_FP */
1745         if (seen_reg(RV_REG_S1, ctx))
1746                 stack_adjust += 8;
1747         if (seen_reg(RV_REG_S2, ctx))
1748                 stack_adjust += 8;
1749         if (seen_reg(RV_REG_S3, ctx))
1750                 stack_adjust += 8;
1751         if (seen_reg(RV_REG_S4, ctx))
1752                 stack_adjust += 8;
1753         if (seen_reg(RV_REG_S5, ctx))
1754                 stack_adjust += 8;
1755         if (seen_reg(RV_REG_S6, ctx))
1756                 stack_adjust += 8;
1757
1758         stack_adjust = round_up(stack_adjust, 16);
1759         stack_adjust += bpf_stack_adjust;
1760
1761         store_offset = stack_adjust - 8;
1762
1763         /* nops reserved for auipc+jalr pair */
1764         for (i = 0; i < RV_FENTRY_NINSNS; i++)
1765                 emit(rv_nop(), ctx);
1766
1767         /* First instruction is always setting the tail-call-counter
1768          * (TCC) register. This instruction is skipped for tail calls.
1769          * Force using a 4-byte (non-compressed) instruction.
1770          */
1771         emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
1772
1773         emit_addi(RV_REG_SP, RV_REG_SP, -stack_adjust, ctx);
1774
1775         if (seen_reg(RV_REG_RA, ctx)) {
1776                 emit_sd(RV_REG_SP, store_offset, RV_REG_RA, ctx);
1777                 store_offset -= 8;
1778         }
1779         emit_sd(RV_REG_SP, store_offset, RV_REG_FP, ctx);
1780         store_offset -= 8;
1781         if (seen_reg(RV_REG_S1, ctx)) {
1782                 emit_sd(RV_REG_SP, store_offset, RV_REG_S1, ctx);
1783                 store_offset -= 8;
1784         }
1785         if (seen_reg(RV_REG_S2, ctx)) {
1786                 emit_sd(RV_REG_SP, store_offset, RV_REG_S2, ctx);
1787                 store_offset -= 8;
1788         }
1789         if (seen_reg(RV_REG_S3, ctx)) {
1790                 emit_sd(RV_REG_SP, store_offset, RV_REG_S3, ctx);
1791                 store_offset -= 8;
1792         }
1793         if (seen_reg(RV_REG_S4, ctx)) {
1794                 emit_sd(RV_REG_SP, store_offset, RV_REG_S4, ctx);
1795                 store_offset -= 8;
1796         }
1797         if (seen_reg(RV_REG_S5, ctx)) {
1798                 emit_sd(RV_REG_SP, store_offset, RV_REG_S5, ctx);
1799                 store_offset -= 8;
1800         }
1801         if (seen_reg(RV_REG_S6, ctx)) {
1802                 emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx);
1803                 store_offset -= 8;
1804         }
1805
1806         emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx);
1807
1808         if (bpf_stack_adjust)
1809                 emit_addi(RV_REG_S5, RV_REG_SP, bpf_stack_adjust, ctx);
1810
1811         /* Program contains calls and tail calls, so RV_REG_TCC need
1812          * to be saved across calls.
1813          */
1814         if (seen_tail_call(ctx) && seen_call(ctx))
1815                 emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx);
1816
1817         ctx->stack_size = stack_adjust;
1818 }
1819
1820 void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
1821 {
1822         __build_epilogue(false, ctx);
1823 }
1824
1825 bool bpf_jit_supports_kfunc_call(void)
1826 {
1827         return true;
1828 }