OSDN Git Service

bpf: Support new sign-extension mov insns
authorYonghong Song <yonghong.song@linux.dev>
Fri, 28 Jul 2023 01:12:02 +0000 (18:12 -0700)
committerAlexei Starovoitov <ast@kernel.org>
Fri, 28 Jul 2023 01:52:33 +0000 (18:52 -0700)
Add interpreter/jit support for new sign-extension mov insns.
The original 'MOV' insn is extended to support reg-to-reg
signed version for both ALU and ALU64 operations. For ALU mode,
the insn->off value of 8 or 16 indicates sign-extension
from 8- or 16-bit value to 32-bit value. For ALU64 mode,
the insn->off value of 8/16/32 indicates sign-extension
from 8-, 16- or 32-bit value to 64-bit value.

Acked-by: Eduard Zingerman <eddyz87@gmail.com>
Signed-off-by: Yonghong Song <yonghong.song@linux.dev>
Link: https://lore.kernel.org/r/20230728011202.3712300-1-yonghong.song@linux.dev
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
arch/x86/net/bpf_jit_comp.c
kernel/bpf/core.c
kernel/bpf/verifier.c

index 54478a9..031ef3c 100644 (file)
@@ -701,6 +701,38 @@ static void emit_mov_reg(u8 **pprog, bool is64, u32 dst_reg, u32 src_reg)
        *pprog = prog;
 }
 
+static void emit_movsx_reg(u8 **pprog, int num_bits, bool is64, u32 dst_reg,
+                          u32 src_reg)
+{
+       u8 *prog = *pprog;
+
+       if (is64) {
+               /* movs[b,w,l]q dst, src */
+               if (num_bits == 8)
+                       EMIT4(add_2mod(0x48, src_reg, dst_reg), 0x0f, 0xbe,
+                             add_2reg(0xC0, src_reg, dst_reg));
+               else if (num_bits == 16)
+                       EMIT4(add_2mod(0x48, src_reg, dst_reg), 0x0f, 0xbf,
+                             add_2reg(0xC0, src_reg, dst_reg));
+               else if (num_bits == 32)
+                       EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x63,
+                             add_2reg(0xC0, src_reg, dst_reg));
+       } else {
+               /* movs[b,w]l dst, src */
+               if (num_bits == 8) {
+                       EMIT4(add_2mod(0x40, src_reg, dst_reg), 0x0f, 0xbe,
+                             add_2reg(0xC0, src_reg, dst_reg));
+               } else if (num_bits == 16) {
+                       if (is_ereg(dst_reg) || is_ereg(src_reg))
+                               EMIT1(add_2mod(0x40, src_reg, dst_reg));
+                       EMIT3(add_2mod(0x0f, src_reg, dst_reg), 0xbf,
+                             add_2reg(0xC0, src_reg, dst_reg));
+               }
+       }
+
+       *pprog = prog;
+}
+
 /* Emit the suffix (ModR/M etc) for addressing *(ptr_reg + off) and val_reg */
 static void emit_insn_suffix(u8 **pprog, u32 ptr_reg, u32 val_reg, int off)
 {
@@ -1051,9 +1083,14 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image
 
                case BPF_ALU64 | BPF_MOV | BPF_X:
                case BPF_ALU | BPF_MOV | BPF_X:
-                       emit_mov_reg(&prog,
-                                    BPF_CLASS(insn->code) == BPF_ALU64,
-                                    dst_reg, src_reg);
+                       if (insn->off == 0)
+                               emit_mov_reg(&prog,
+                                            BPF_CLASS(insn->code) == BPF_ALU64,
+                                            dst_reg, src_reg);
+                       else
+                               emit_movsx_reg(&prog, insn->off,
+                                              BPF_CLASS(insn->code) == BPF_ALU64,
+                                              dst_reg, src_reg);
                        break;
 
                        /* neg dst */
index 01b72fc..c37c454 100644 (file)
@@ -61,6 +61,7 @@
 #define AX     regs[BPF_REG_AX]
 #define ARG1   regs[BPF_REG_ARG1]
 #define CTX    regs[BPF_REG_CTX]
+#define OFF    insn->off
 #define IMM    insn->imm
 
 struct bpf_mem_alloc bpf_global_ma;
@@ -1739,13 +1740,36 @@ select_insn:
                DST = -DST;
                CONT;
        ALU_MOV_X:
-               DST = (u32) SRC;
+               switch (OFF) {
+               case 0:
+                       DST = (u32) SRC;
+                       break;
+               case 8:
+                       DST = (u32)(s8) SRC;
+                       break;
+               case 16:
+                       DST = (u32)(s16) SRC;
+                       break;
+               }
                CONT;
        ALU_MOV_K:
                DST = (u32) IMM;
                CONT;
        ALU64_MOV_X:
-               DST = SRC;
+               switch (OFF) {
+               case 0:
+                       DST = SRC;
+                       break;
+               case 8:
+                       DST = (s8) SRC;
+                       break;
+               case 16:
+                       DST = (s16) SRC;
+                       break;
+               case 32:
+                       DST = (s32) SRC;
+                       break;
+               }
                CONT;
        ALU64_MOV_K:
                DST = IMM;
index b154854..2f3eebc 100644 (file)
@@ -3421,7 +3421,7 @@ static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
                        return 0;
                if (opcode == BPF_MOV) {
                        if (BPF_SRC(insn->code) == BPF_X) {
-                               /* dreg = sreg
+                               /* dreg = sreg or dreg = (s8, s16, s32)sreg
                                 * dreg needs precision after this insn
                                 * sreg needs precision before this insn
                                 */
@@ -5905,6 +5905,69 @@ out:
        set_sext64_default_val(reg, size);
 }
 
+static void set_sext32_default_val(struct bpf_reg_state *reg, int size)
+{
+       if (size == 1) {
+               reg->s32_min_value = S8_MIN;
+               reg->s32_max_value = S8_MAX;
+       } else {
+               /* size == 2 */
+               reg->s32_min_value = S16_MIN;
+               reg->s32_max_value = S16_MAX;
+       }
+       reg->u32_min_value = 0;
+       reg->u32_max_value = U32_MAX;
+}
+
+static void coerce_subreg_to_size_sx(struct bpf_reg_state *reg, int size)
+{
+       s32 init_s32_max, init_s32_min, s32_max, s32_min, u32_val;
+       u32 top_smax_value, top_smin_value;
+       u32 num_bits = size * 8;
+
+       if (tnum_is_const(reg->var_off)) {
+               u32_val = reg->var_off.value;
+               if (size == 1)
+                       reg->var_off = tnum_const((s8)u32_val);
+               else
+                       reg->var_off = tnum_const((s16)u32_val);
+
+               u32_val = reg->var_off.value;
+               reg->s32_min_value = reg->s32_max_value = u32_val;
+               reg->u32_min_value = reg->u32_max_value = u32_val;
+               return;
+       }
+
+       top_smax_value = ((u32)reg->s32_max_value >> num_bits) << num_bits;
+       top_smin_value = ((u32)reg->s32_min_value >> num_bits) << num_bits;
+
+       if (top_smax_value != top_smin_value)
+               goto out;
+
+       /* find the s32_min and s32_min after sign extension */
+       if (size == 1) {
+               init_s32_max = (s8)reg->s32_max_value;
+               init_s32_min = (s8)reg->s32_min_value;
+       } else {
+               /* size == 2 */
+               init_s32_max = (s16)reg->s32_max_value;
+               init_s32_min = (s16)reg->s32_min_value;
+       }
+       s32_max = max(init_s32_max, init_s32_min);
+       s32_min = min(init_s32_max, init_s32_min);
+
+       if (s32_min >= 0 == s32_max >= 0) {
+               reg->s32_min_value = s32_min;
+               reg->s32_max_value = s32_max;
+               reg->u32_min_value = (u32)s32_min;
+               reg->u32_max_value = (u32)s32_max;
+               return;
+       }
+
+out:
+       set_sext32_default_val(reg, size);
+}
+
 static bool bpf_map_is_rdonly(const struct bpf_map *map)
 {
        /* A map is considered read-only if the following condition are true:
@@ -13038,11 +13101,24 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)
        } else if (opcode == BPF_MOV) {
 
                if (BPF_SRC(insn->code) == BPF_X) {
-                       if (insn->imm != 0 || insn->off != 0) {
+                       if (insn->imm != 0) {
                                verbose(env, "BPF_MOV uses reserved fields\n");
                                return -EINVAL;
                        }
 
+                       if (BPF_CLASS(insn->code) == BPF_ALU) {
+                               if (insn->off != 0 && insn->off != 8 && insn->off != 16) {
+                                       verbose(env, "BPF_MOV uses reserved fields\n");
+                                       return -EINVAL;
+                               }
+                       } else {
+                               if (insn->off != 0 && insn->off != 8 && insn->off != 16 &&
+                                   insn->off != 32) {
+                                       verbose(env, "BPF_MOV uses reserved fields\n");
+                                       return -EINVAL;
+                               }
+                       }
+
                        /* check src operand */
                        err = check_reg_arg(env, insn->src_reg, SRC_OP);
                        if (err)
@@ -13066,18 +13142,33 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)
                                       !tnum_is_const(src_reg->var_off);
 
                        if (BPF_CLASS(insn->code) == BPF_ALU64) {
-                               /* case: R1 = R2
-                                * copy register state to dest reg
-                                */
-                               if (need_id)
-                                       /* Assign src and dst registers the same ID
-                                        * that will be used by find_equal_scalars()
-                                        * to propagate min/max range.
+                               if (insn->off == 0) {
+                                       /* case: R1 = R2
+                                        * copy register state to dest reg
                                         */
-                                       src_reg->id = ++env->id_gen;
-                               copy_register_state(dst_reg, src_reg);
-                               dst_reg->live |= REG_LIVE_WRITTEN;
-                               dst_reg->subreg_def = DEF_NOT_SUBREG;
+                                       if (need_id)
+                                               /* Assign src and dst registers the same ID
+                                                * that will be used by find_equal_scalars()
+                                                * to propagate min/max range.
+                                                */
+                                               src_reg->id = ++env->id_gen;
+                                       copy_register_state(dst_reg, src_reg);
+                                       dst_reg->live |= REG_LIVE_WRITTEN;
+                                       dst_reg->subreg_def = DEF_NOT_SUBREG;
+                               } else {
+                                       /* case: R1 = (s8, s16 s32)R2 */
+                                       bool no_sext;
+
+                                       no_sext = src_reg->umax_value < (1ULL << (insn->off - 1));
+                                       if (no_sext && need_id)
+                                               src_reg->id = ++env->id_gen;
+                                       copy_register_state(dst_reg, src_reg);
+                                       if (!no_sext)
+                                               dst_reg->id = 0;
+                                       coerce_reg_to_size_sx(dst_reg, insn->off >> 3);
+                                       dst_reg->live |= REG_LIVE_WRITTEN;
+                                       dst_reg->subreg_def = DEF_NOT_SUBREG;
+                               }
                        } else {
                                /* R1 = (u32) R2 */
                                if (is_pointer_value(env, insn->src_reg)) {
@@ -13086,19 +13177,33 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)
                                                insn->src_reg);
                                        return -EACCES;
                                } else if (src_reg->type == SCALAR_VALUE) {
-                                       bool is_src_reg_u32 = src_reg->umax_value <= U32_MAX;
-
-                                       if (is_src_reg_u32 && need_id)
-                                               src_reg->id = ++env->id_gen;
-                                       copy_register_state(dst_reg, src_reg);
-                                       /* Make sure ID is cleared if src_reg is not in u32 range otherwise
-                                        * dst_reg min/max could be incorrectly
-                                        * propagated into src_reg by find_equal_scalars()
-                                        */
-                                       if (!is_src_reg_u32)
-                                               dst_reg->id = 0;
-                                       dst_reg->live |= REG_LIVE_WRITTEN;
-                                       dst_reg->subreg_def = env->insn_idx + 1;
+                                       if (insn->off == 0) {
+                                               bool is_src_reg_u32 = src_reg->umax_value <= U32_MAX;
+
+                                               if (is_src_reg_u32 && need_id)
+                                                       src_reg->id = ++env->id_gen;
+                                               copy_register_state(dst_reg, src_reg);
+                                               /* Make sure ID is cleared if src_reg is not in u32
+                                                * range otherwise dst_reg min/max could be incorrectly
+                                                * propagated into src_reg by find_equal_scalars()
+                                                */
+                                               if (!is_src_reg_u32)
+                                                       dst_reg->id = 0;
+                                               dst_reg->live |= REG_LIVE_WRITTEN;
+                                               dst_reg->subreg_def = env->insn_idx + 1;
+                                       } else {
+                                               /* case: W1 = (s8, s16)W2 */
+                                               bool no_sext = src_reg->umax_value < (1ULL << (insn->off - 1));
+
+                                               if (no_sext && need_id)
+                                                       src_reg->id = ++env->id_gen;
+                                               copy_register_state(dst_reg, src_reg);
+                                               if (!no_sext)
+                                                       dst_reg->id = 0;
+                                               dst_reg->live |= REG_LIVE_WRITTEN;
+                                               dst_reg->subreg_def = env->insn_idx + 1;
+                                               coerce_subreg_to_size_sx(dst_reg, insn->off >> 3);
+                                       }
                                } else {
                                        mark_reg_unknown(env, regs,
                                                         insn->dst_reg);