bpf, verifier: remove unneeded flow key in check_helper_mem_access
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index bb07e74..4f727c9 100644
@@ -1,5 +1,6 @@
 /* Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
  * Copyright (c) 2016 Facebook
+ * Copyright (c) 2018 Covalent IO, Inc. http://covalent.io
  *
  * This program is free software; you can redistribute it and/or
  * modify it under the terms of version 2 of the GNU General Public
@@ -80,8 +81,8 @@ static const struct bpf_verifier_ops * const bpf_verifier_ops[] = {
  * (like pointer plus pointer becomes SCALAR_VALUE type)
  *
  * When verifier sees load or store instructions the type of base register
- * can be: PTR_TO_MAP_VALUE, PTR_TO_CTX, PTR_TO_STACK. These are three pointer
- * types recognized by check_mem_access() function.
+ * can be: PTR_TO_MAP_VALUE, PTR_TO_CTX, PTR_TO_STACK, PTR_TO_SOCKET. These are
+ * four pointer types recognized by check_mem_access() function.
  *
  * PTR_TO_MAP_VALUE means that this register is pointing to 'map element value'
  * and the range of [ptr, ptr + map's value_size) is accessible.
@@ -140,6 +141,24 @@ static const struct bpf_verifier_ops * const bpf_verifier_ops[] = {
  *
  * After the call R0 is set to return type of the function and registers R1-R5
  * are set to NOT_INIT to indicate that they are no longer readable.
+ *
+ * The following reference types represent a potential reference to a kernel
+ * resource which, after first being allocated, must be checked and freed by
+ * the BPF program:
+ * - PTR_TO_SOCKET_OR_NULL, PTR_TO_SOCKET
+ *
+ * When the verifier sees a helper call return a reference type, it allocates a
+ * pointer id for the reference and stores it in the current function state.
+ * Similar to the way that PTR_TO_MAP_VALUE_OR_NULL is converted into
+ * PTR_TO_MAP_VALUE, PTR_TO_SOCKET_OR_NULL becomes PTR_TO_SOCKET when the type
+ * passes through a NULL-check conditional. For the branch wherein the state is
+ * changed to CONST_IMM, the verifier releases the reference.
+ *
+ * For each helper function that allocates a reference, such as
+ * bpf_sk_lookup_tcp(), there is a corresponding release function, such as
+ * bpf_sk_release(). When a reference type passes into the release function,
+ * the verifier also releases the reference. If any unchecked or unreleased
+ * reference remains at the end of the program, the verifier rejects it.
  */
 
 /* verifier_state + insn_idx are pushed to stack when branch is encountered */
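As a hedged sketch of the program shape the comment above describes (not taken from the patch; it assumes the bpf_sk_lookup_tcp()/bpf_sk_release() helpers and struct bpf_sock_tuple added elsewhere in this series, plus the usual SEC() macro from bpf_helpers.h; tuple contents, netns/flags values and the section name are illustrative):

#include <linux/bpf.h>
/* SEC() and the helper declarations are assumed to come from bpf_helpers.h */

SEC("classifier")
int sk_lookup_example(struct __sk_buff *skb)
{
        struct bpf_sock_tuple tuple = {};       /* would be filled from the packet */
        struct bpf_sock *sk;

        sk = bpf_sk_lookup_tcp(skb, &tuple, sizeof(tuple.ipv4), 0, 0);
        if (!sk)                /* PTR_TO_SOCKET_OR_NULL -> PTR_TO_SOCKET */
                return 0;
        /* read-only loads through sk are allowed here */
        bpf_sk_release(sk);     /* omitting this makes the verifier reject the program */
        return 0;
}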
@@ -189,6 +208,7 @@ struct bpf_call_arg_meta {
        int access_size;
        s64 msize_smax_value;
        u64 msize_umax_value;
+       int ptr_id;
 };
 
 static DEFINE_MUTEX(bpf_verifier_lock);
@@ -249,6 +269,46 @@ static bool type_is_pkt_pointer(enum bpf_reg_type type)
               type == PTR_TO_PACKET_META;
 }
 
+static bool reg_type_may_be_null(enum bpf_reg_type type)
+{
+       return type == PTR_TO_MAP_VALUE_OR_NULL ||
+              type == PTR_TO_SOCKET_OR_NULL;
+}
+
+static bool type_is_refcounted(enum bpf_reg_type type)
+{
+       return type == PTR_TO_SOCKET;
+}
+
+static bool type_is_refcounted_or_null(enum bpf_reg_type type)
+{
+       return type == PTR_TO_SOCKET || type == PTR_TO_SOCKET_OR_NULL;
+}
+
+static bool reg_is_refcounted(const struct bpf_reg_state *reg)
+{
+       return type_is_refcounted(reg->type);
+}
+
+static bool reg_is_refcounted_or_null(const struct bpf_reg_state *reg)
+{
+       return type_is_refcounted_or_null(reg->type);
+}
+
+static bool arg_type_is_refcounted(enum bpf_arg_type type)
+{
+       return type == ARG_PTR_TO_SOCKET;
+}
+
+/* Determine whether the function releases some resources allocated by another
+ * function call. The first reference type argument will be assumed to be
+ * released by release_reference().
+ */
+static bool is_release_function(enum bpf_func_id func_id)
+{
+       return func_id == BPF_FUNC_sk_release;
+}
+
 /* string representation of 'enum bpf_reg_type' */
 static const char * const reg_type_str[] = {
        [NOT_INIT]              = "?",
@@ -261,6 +321,16 @@ static const char * const reg_type_str[] = {
        [PTR_TO_PACKET]         = "pkt",
        [PTR_TO_PACKET_META]    = "pkt_meta",
        [PTR_TO_PACKET_END]     = "pkt_end",
+       [PTR_TO_FLOW_KEYS]      = "flow_keys",
+       [PTR_TO_SOCKET]         = "sock",
+       [PTR_TO_SOCKET_OR_NULL] = "sock_or_null",
+};
+
+static char slot_type_char[] = {
+       [STACK_INVALID] = '?',
+       [STACK_SPILL]   = 'r',
+       [STACK_MISC]    = 'm',
+       [STACK_ZERO]    = '0',
 };
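With slot_type_char[], print_verifier_state() below renders each stack slot as one character per byte ('?' invalid, 'r' spilled register, 'm' misc, '0' known zero). An illustrative slot with four known-zero bytes and four bytes of miscellaneous data would show up roughly as (liveness annotation omitted):

        fp-8=0000mmmm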
 
 static void print_liveness(struct bpf_verifier_env *env,
@@ -349,72 +419,179 @@ static void print_verifier_state(struct bpf_verifier_env *env,
                }
        }
        for (i = 0; i < state->allocated_stack / BPF_REG_SIZE; i++) {
-               if (state->stack[i].slot_type[0] == STACK_SPILL) {
-                       verbose(env, " fp%d",
-                               (-i - 1) * BPF_REG_SIZE);
-                       print_liveness(env, state->stack[i].spilled_ptr.live);
+               char types_buf[BPF_REG_SIZE + 1];
+               bool valid = false;
+               int j;
+
+               for (j = 0; j < BPF_REG_SIZE; j++) {
+                       if (state->stack[i].slot_type[j] != STACK_INVALID)
+                               valid = true;
+                       types_buf[j] = slot_type_char[
+                                       state->stack[i].slot_type[j]];
+               }
+               types_buf[BPF_REG_SIZE] = 0;
+               if (!valid)
+                       continue;
+               verbose(env, " fp%d", (-i - 1) * BPF_REG_SIZE);
+               print_liveness(env, state->stack[i].spilled_ptr.live);
+               if (state->stack[i].slot_type[0] == STACK_SPILL)
                        verbose(env, "=%s",
                                reg_type_str[state->stack[i].spilled_ptr.type]);
-               }
-               if (state->stack[i].slot_type[0] == STACK_ZERO)
-                       verbose(env, " fp%d=0", (-i - 1) * BPF_REG_SIZE);
+               else
+                       verbose(env, "=%s", types_buf);
+       }
+       if (state->acquired_refs && state->refs[0].id) {
+               verbose(env, " refs=%d", state->refs[0].id);
+               for (i = 1; i < state->acquired_refs; i++)
+                       if (state->refs[i].id)
+                               verbose(env, ",%d", state->refs[i].id);
        }
        verbose(env, "\n");
 }
 
-static int copy_stack_state(struct bpf_func_state *dst,
-                           const struct bpf_func_state *src)
-{
-       if (!src->stack)
-               return 0;
-       if (WARN_ON_ONCE(dst->allocated_stack < src->allocated_stack)) {
-               /* internal bug, make state invalid to reject the program */
-               memset(dst, 0, sizeof(*dst));
-               return -EFAULT;
-       }
-       memcpy(dst->stack, src->stack,
-              sizeof(*src->stack) * (src->allocated_stack / BPF_REG_SIZE));
-       return 0;
-}
+#define COPY_STATE_FN(NAME, COUNT, FIELD, SIZE)                                \
+static int copy_##NAME##_state(struct bpf_func_state *dst,             \
+                              const struct bpf_func_state *src)        \
+{                                                                      \
+       if (!src->FIELD)                                                \
+               return 0;                                               \
+       if (WARN_ON_ONCE(dst->COUNT < src->COUNT)) {                    \
+               /* internal bug, make state invalid to reject the program */ \
+               memset(dst, 0, sizeof(*dst));                           \
+               return -EFAULT;                                         \
+       }                                                               \
+       memcpy(dst->FIELD, src->FIELD,                                  \
+              sizeof(*src->FIELD) * (src->COUNT / SIZE));              \
+       return 0;                                                       \
+}
+/* copy_reference_state() */
+COPY_STATE_FN(reference, acquired_refs, refs, 1)
+/* copy_stack_state() */
+COPY_STATE_FN(stack, allocated_stack, stack, BPF_REG_SIZE)
+#undef COPY_STATE_FN
+
+#define REALLOC_STATE_FN(NAME, COUNT, FIELD, SIZE)                     \
+static int realloc_##NAME##_state(struct bpf_func_state *state, int size, \
+                                 bool copy_old)                        \
+{                                                                      \
+       u32 old_size = state->COUNT;                                    \
+       struct bpf_##NAME##_state *new_##FIELD;                         \
+       int slot = size / SIZE;                                         \
+                                                                       \
+       if (size <= old_size || !size) {                                \
+               if (copy_old)                                           \
+                       return 0;                                       \
+               state->COUNT = slot * SIZE;                             \
+               if (!size && old_size) {                                \
+                       kfree(state->FIELD);                            \
+                       state->FIELD = NULL;                            \
+               }                                                       \
+               return 0;                                               \
+       }                                                               \
+       new_##FIELD = kmalloc_array(slot, sizeof(struct bpf_##NAME##_state), \
+                                   GFP_KERNEL);                        \
+       if (!new_##FIELD)                                               \
+               return -ENOMEM;                                         \
+       if (copy_old) {                                                 \
+               if (state->FIELD)                                       \
+                       memcpy(new_##FIELD, state->FIELD,               \
+                              sizeof(*new_##FIELD) * (old_size / SIZE)); \
+               memset(new_##FIELD + old_size / SIZE, 0,                \
+                      sizeof(*new_##FIELD) * (size - old_size) / SIZE); \
+       }                                                               \
+       state->COUNT = slot * SIZE;                                     \
+       kfree(state->FIELD);                                            \
+       state->FIELD = new_##FIELD;                                     \
+       return 0;                                                       \
+}
+/* realloc_reference_state() */
+REALLOC_STATE_FN(reference, acquired_refs, refs, 1)
+/* realloc_stack_state() */
+REALLOC_STATE_FN(stack, allocated_stack, stack, BPF_REG_SIZE)
+#undef REALLOC_STATE_FN
 
 /* do_check() starts with zero-sized stack in struct bpf_verifier_state to
  * make it consume minimal amount of memory. check_stack_write() access from
  * the program calls into realloc_func_state() to grow the stack size.
  * Note there is a non-zero 'parent' pointer inside bpf_verifier_state
- * which this function copies over. It points to previous bpf_verifier_state
- * which is never reallocated
+ * which realloc_stack_state() copies over. It points to previous
+ * bpf_verifier_state which is never reallocated.
+ */
+static int realloc_func_state(struct bpf_func_state *state, int stack_size,
+                             int refs_size, bool copy_old)
+{
+       int err = realloc_reference_state(state, refs_size, copy_old);
+       if (err)
+               return err;
+       return realloc_stack_state(state, stack_size, copy_old);
+}
+
+/* Acquire a pointer id from the env and update the state->refs to include
+ * this new pointer reference.
+ * On success, returns a valid pointer id to associate with the register
+ * On failure, returns a negative errno.
  */
-static int realloc_func_state(struct bpf_func_state *state, int size,
-                             bool copy_old)
+static int acquire_reference_state(struct bpf_verifier_env *env, int insn_idx)
 {
-       u32 old_size = state->allocated_stack;
-       struct bpf_stack_state *new_stack;
-       int slot = size / BPF_REG_SIZE;
+       struct bpf_func_state *state = cur_func(env);
+       int new_ofs = state->acquired_refs;
+       int id, err;
+
+       err = realloc_reference_state(state, state->acquired_refs + 1, true);
+       if (err)
+               return err;
+       id = ++env->id_gen;
+       state->refs[new_ofs].id = id;
+       state->refs[new_ofs].insn_idx = insn_idx;
 
-       if (size <= old_size || !size) {
-               if (copy_old)
+       return id;
+}
+
+/* release function corresponding to acquire_reference_state(). Idempotent. */
+static int __release_reference_state(struct bpf_func_state *state, int ptr_id)
+{
+       int i, last_idx;
+
+       if (!ptr_id)
+               return -EFAULT;
+
+       last_idx = state->acquired_refs - 1;
+       for (i = 0; i < state->acquired_refs; i++) {
+               if (state->refs[i].id == ptr_id) {
+                       if (last_idx && i != last_idx)
+                               memcpy(&state->refs[i], &state->refs[last_idx],
+                                      sizeof(*state->refs));
+                       memset(&state->refs[last_idx], 0, sizeof(*state->refs));
+                       state->acquired_refs--;
                        return 0;
-               state->allocated_stack = slot * BPF_REG_SIZE;
-               if (!size && old_size) {
-                       kfree(state->stack);
-                       state->stack = NULL;
                }
-               return 0;
        }
-       new_stack = kmalloc_array(slot, sizeof(struct bpf_stack_state),
-                                 GFP_KERNEL);
-       if (!new_stack)
-               return -ENOMEM;
-       if (copy_old) {
-               if (state->stack)
-                       memcpy(new_stack, state->stack,
-                              sizeof(*new_stack) * (old_size / BPF_REG_SIZE));
-               memset(new_stack + old_size / BPF_REG_SIZE, 0,
-                      sizeof(*new_stack) * (size - old_size) / BPF_REG_SIZE);
-       }
-       state->allocated_stack = slot * BPF_REG_SIZE;
-       kfree(state->stack);
-       state->stack = new_stack;
+       return -EFAULT;
+}
+
+/* variation on the above for cases where we expect that there must be an
+ * outstanding reference for the specified ptr_id.
+ */
+static int release_reference_state(struct bpf_verifier_env *env, int ptr_id)
+{
+       struct bpf_func_state *state = cur_func(env);
+       int err;
+
+       err = __release_reference_state(state, ptr_id);
+       if (WARN_ON_ONCE(err != 0))
+               verbose(env, "verifier internal error: can't release reference\n");
+       return err;
+}
+
+static int transfer_reference_state(struct bpf_func_state *dst,
+                                   struct bpf_func_state *src)
+{
+       int err = realloc_reference_state(dst, src->acquired_refs, false);
+       if (err)
+               return err;
+       err = copy_reference_state(dst, src);
+       if (err)
+               return err;
        return 0;
 }
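For reference, a hand expansion of COPY_STATE_FN(reference, acquired_refs, refs, 1) above (the generated copy_stack_state() is analogous, with BPF_REG_SIZE as the divisor):

static int copy_reference_state(struct bpf_func_state *dst,
                                const struct bpf_func_state *src)
{
        if (!src->refs)
                return 0;
        if (WARN_ON_ONCE(dst->acquired_refs < src->acquired_refs)) {
                /* internal bug, make state invalid to reject the program */
                memset(dst, 0, sizeof(*dst));
                return -EFAULT;
        }
        memcpy(dst->refs, src->refs,
               sizeof(*src->refs) * (src->acquired_refs / 1));
        return 0;
}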
 
@@ -422,6 +599,7 @@ static void free_func_state(struct bpf_func_state *state)
 {
        if (!state)
                return;
+       kfree(state->refs);
        kfree(state->stack);
        kfree(state);
 }
@@ -447,10 +625,14 @@ static int copy_func_state(struct bpf_func_state *dst,
 {
        int err;
 
-       err = realloc_func_state(dst, src->allocated_stack, false);
+       err = realloc_func_state(dst, src->allocated_stack, src->acquired_refs,
+                                false);
+       if (err)
+               return err;
+       memcpy(dst, src, offsetof(struct bpf_func_state, acquired_refs));
+       err = copy_reference_state(dst, src);
        if (err)
                return err;
-       memcpy(dst, src, offsetof(struct bpf_func_state, allocated_stack));
        return copy_stack_state(dst, src);
 }
 
@@ -466,7 +648,6 @@ static int copy_verifier_state(struct bpf_verifier_state *dst_state,
                dst_state->frame[i] = NULL;
        }
        dst_state->curframe = src->curframe;
-       dst_state->parent = src->parent;
        for (i = 0; i <= src->curframe; i++) {
                dst = dst_state->frame[i];
                if (!dst) {
@@ -553,7 +734,9 @@ static void __mark_reg_not_init(struct bpf_reg_state *reg);
  */
 static void __mark_reg_known(struct bpf_reg_state *reg, u64 imm)
 {
-       reg->id = 0;
+       /* Clear id, off, and union(map_ptr, range) */
+       memset(((u8 *)reg) + sizeof(reg->type), 0,
+              offsetof(struct bpf_reg_state, var_off) - sizeof(reg->type));
        reg->var_off = tnum_const(imm);
        reg->smin_value = (s64)imm;
        reg->smax_value = (s64)imm;
@@ -572,7 +755,6 @@ static void __mark_reg_known_zero(struct bpf_reg_state *reg)
 static void __mark_reg_const_zero(struct bpf_reg_state *reg)
 {
        __mark_reg_known(reg, 0);
-       reg->off = 0;
        reg->type = SCALAR_VALUE;
 }
 
@@ -683,9 +865,12 @@ static void __mark_reg_unbounded(struct bpf_reg_state *reg)
 /* Mark a register as having a completely unknown (scalar) value. */
 static void __mark_reg_unknown(struct bpf_reg_state *reg)
 {
+       /*
+        * Clear type, id, off, and union(map_ptr, range) and
+        * padding between 'type' and union
+        */
+       memset(reg, 0, offsetof(struct bpf_reg_state, var_off));
        reg->type = SCALAR_VALUE;
-       reg->id = 0;
-       reg->off = 0;
        reg->var_off = tnum_unknown;
        reg->frameno = 0;
        __mark_reg_unbounded(reg);
@@ -732,6 +917,7 @@ static void init_reg_state(struct bpf_verifier_env *env,
        for (i = 0; i < MAX_BPF_REG; i++) {
                mark_reg_not_init(env, regs, i);
                regs[i].live = REG_LIVE_NONE;
+               regs[i].parent = NULL;
        }
 
        /* frame pointer */
@@ -823,10 +1009,6 @@ static int check_subprogs(struct bpf_verifier_env *env)
                        verbose(env, "function calls to other bpf functions are allowed for root only\n");
                        return -EPERM;
                }
-               if (bpf_prog_is_dev_bound(env->prog->aux)) {
-                       verbose(env, "function calls in offloaded programs are not supported yet\n");
-                       return -EINVAL;
-               }
                ret = add_subprog(env, i + insn[i].imm + 1);
                if (ret < 0)
                        return ret;
@@ -876,74 +1058,21 @@ next:
        return 0;
 }
 
-static
-struct bpf_verifier_state *skip_callee(struct bpf_verifier_env *env,
-                                      const struct bpf_verifier_state *state,
-                                      struct bpf_verifier_state *parent,
-                                      u32 regno)
-{
-       struct bpf_verifier_state *tmp = NULL;
-
-       /* 'parent' could be a state of caller and
-        * 'state' could be a state of callee. In such case
-        * parent->curframe < state->curframe
-        * and it's ok for r1 - r5 registers
-        *
-        * 'parent' could be a callee's state after it bpf_exit-ed.
-        * In such case parent->curframe > state->curframe
-        * and it's ok for r0 only
-        */
-       if (parent->curframe == state->curframe ||
-           (parent->curframe < state->curframe &&
-            regno >= BPF_REG_1 && regno <= BPF_REG_5) ||
-           (parent->curframe > state->curframe &&
-              regno == BPF_REG_0))
-               return parent;
-
-       if (parent->curframe > state->curframe &&
-           regno >= BPF_REG_6) {
-               /* for callee saved regs we have to skip the whole chain
-                * of states that belong to callee and mark as LIVE_READ
-                * the registers before the call
-                */
-               tmp = parent;
-               while (tmp && tmp->curframe != state->curframe) {
-                       tmp = tmp->parent;
-               }
-               if (!tmp)
-                       goto bug;
-               parent = tmp;
-       } else {
-               goto bug;
-       }
-       return parent;
-bug:
-       verbose(env, "verifier bug regno %d tmp %p\n", regno, tmp);
-       verbose(env, "regno %d parent frame %d current frame %d\n",
-               regno, parent->curframe, state->curframe);
-       return NULL;
-}
-
+/* Parentage chain of this register (or stack slot) should take care of all
+ * issues like callee-saved registers, stack slot allocation time, etc.
+ */
 static int mark_reg_read(struct bpf_verifier_env *env,
-                        const struct bpf_verifier_state *state,
-                        struct bpf_verifier_state *parent,
-                        u32 regno)
+                        const struct bpf_reg_state *state,
+                        struct bpf_reg_state *parent)
 {
        bool writes = parent == state->parent; /* Observe write marks */
 
-       if (regno == BPF_REG_FP)
-               /* We don't need to worry about FP liveness because it's read-only */
-               return 0;
-
        while (parent) {
                /* if read wasn't screened by an earlier write ... */
-               if (writes && state->frame[state->curframe]->regs[regno].live & REG_LIVE_WRITTEN)
+               if (writes && state->live & REG_LIVE_WRITTEN)
                        break;
-               parent = skip_callee(env, state, parent, regno);
-               if (!parent)
-                       return -EFAULT;
                /* ... then we depend on parent's value */
-               parent->frame[parent->curframe]->regs[regno].live |= REG_LIVE_READ;
+               parent->live |= REG_LIVE_READ;
                state = parent;
                parent = state->parent;
                writes = true;
@@ -969,7 +1098,10 @@ static int check_reg_arg(struct bpf_verifier_env *env, u32 regno,
                        verbose(env, "R%d !read_ok\n", regno);
                        return -EACCES;
                }
-               return mark_reg_read(env, vstate, vstate->parent, regno);
+               /* We don't need to worry about FP liveness because it's read-only */
+               if (regno != BPF_REG_FP)
+                       return mark_reg_read(env, &regs[regno],
+                                            regs[regno].parent);
        } else {
                /* check whether register used as dest operand can be written to */
                if (regno == BPF_REG_FP) {
@@ -993,7 +1125,10 @@ static bool is_spillable_regtype(enum bpf_reg_type type)
        case PTR_TO_PACKET:
        case PTR_TO_PACKET_META:
        case PTR_TO_PACKET_END:
+       case PTR_TO_FLOW_KEYS:
        case CONST_PTR_TO_MAP:
+       case PTR_TO_SOCKET:
+       case PTR_TO_SOCKET_OR_NULL:
                return true;
        default:
                return false;
@@ -1018,7 +1153,7 @@ static int check_stack_write(struct bpf_verifier_env *env,
        enum bpf_reg_type type;
 
        err = realloc_func_state(state, round_up(slot + 1, BPF_REG_SIZE),
-                                true);
+                                state->acquired_refs, true);
        if (err)
                return err;
        /* caller checked that off % size == 0 and -MAX_BPF_STACK <= off < 0,
@@ -1080,8 +1215,8 @@ static int check_stack_write(struct bpf_verifier_env *env,
        } else {
                u8 type = STACK_MISC;
 
-               /* regular write of data into stack */
-               state->stack[spi].spilled_ptr = (struct bpf_reg_state) {};
+               /* regular write of data into stack destroys any spilled ptr */
+               state->stack[spi].spilled_ptr.type = NOT_INIT;
 
                /* only mark the slot as written if all 8 bytes were written
                 * otherwise read propagation may incorrectly stop too soon
@@ -1106,61 +1241,6 @@ static int check_stack_write(struct bpf_verifier_env *env,
        return 0;
 }
 
-/* registers of every function are unique and mark_reg_read() propagates
- * the liveness in the following cases:
- * - from callee into caller for R1 - R5 that were used as arguments
- * - from caller into callee for R0 that used as result of the call
- * - from caller to the same caller skipping states of the callee for R6 - R9,
- *   since R6 - R9 are callee saved by implicit function prologue and
- *   caller's R6 != callee's R6, so when we propagate liveness up to
- *   parent states we need to skip callee states for R6 - R9.
- *
- * stack slot marking is different, since stacks of caller and callee are
- * accessible in both (since caller can pass a pointer to caller's stack to
- * callee which can pass it to another function), hence mark_stack_slot_read()
- * has to propagate the stack liveness to all parent states at given frame number.
- * Consider code:
- * f1() {
- *   ptr = fp - 8;
- *   *ptr = ctx;
- *   call f2 {
- *      .. = *ptr;
- *   }
- *   .. = *ptr;
- * }
- * First *ptr is reading from f1's stack and mark_stack_slot_read() has
- * to mark liveness at the f1's frame and not f2's frame.
- * Second *ptr is also reading from f1's stack and mark_stack_slot_read() has
- * to propagate liveness to f2 states at f1's frame level and further into
- * f1 states at f1's frame level until write into that stack slot
- */
-static void mark_stack_slot_read(struct bpf_verifier_env *env,
-                                const struct bpf_verifier_state *state,
-                                struct bpf_verifier_state *parent,
-                                int slot, int frameno)
-{
-       bool writes = parent == state->parent; /* Observe write marks */
-
-       while (parent) {
-               if (parent->frame[frameno]->allocated_stack <= slot * BPF_REG_SIZE)
-                       /* since LIVE_WRITTEN mark is only done for full 8-byte
-                        * write the read marks are conservative and parent
-                        * state may not even have the stack allocated. In such case
-                        * end the propagation, since the loop reached beginning
-                        * of the function
-                        */
-                       break;
-               /* if read wasn't screened by an earlier write ... */
-               if (writes && state->frame[frameno]->stack[slot].spilled_ptr.live & REG_LIVE_WRITTEN)
-                       break;
-               /* ... then we depend on parent's value */
-               parent->frame[frameno]->stack[slot].spilled_ptr.live |= REG_LIVE_READ;
-               state = parent;
-               parent = state->parent;
-               writes = true;
-       }
-}
-
 static int check_stack_read(struct bpf_verifier_env *env,
                            struct bpf_func_state *reg_state /* func where register points to */,
                            int off, int size, int value_regno)
@@ -1198,8 +1278,8 @@ static int check_stack_read(struct bpf_verifier_env *env,
                         */
                        state->regs[value_regno].live |= REG_LIVE_WRITTEN;
                }
-               mark_stack_slot_read(env, vstate, vstate->parent, spi,
-                                    reg_state->frameno);
+               mark_reg_read(env, &reg_state->stack[spi].spilled_ptr,
+                             reg_state->stack[spi].spilled_ptr.parent);
                return 0;
        } else {
                int zeros = 0;
@@ -1215,8 +1295,8 @@ static int check_stack_read(struct bpf_verifier_env *env,
                                off, i, size);
                        return -EACCES;
                }
-               mark_stack_slot_read(env, vstate, vstate->parent, spi,
-                                    reg_state->frameno);
+               mark_reg_read(env, &reg_state->stack[spi].spilled_ptr,
+                             reg_state->stack[spi].spilled_ptr.parent);
                if (value_regno >= 0) {
                        if (zeros == size) {
                                /* any size read into register is zero extended,
@@ -1321,6 +1401,7 @@ static bool may_access_direct_pkt_data(struct bpf_verifier_env *env,
        case BPF_PROG_TYPE_LWT_XMIT:
        case BPF_PROG_TYPE_SK_SKB:
        case BPF_PROG_TYPE_SK_MSG:
+       case BPF_PROG_TYPE_FLOW_DISSECTOR:
                if (meta)
                        return meta->pkt_access;
 
@@ -1404,6 +1485,40 @@ static int check_ctx_access(struct bpf_verifier_env *env, int insn_idx, int off,
        return -EACCES;
 }
 
+static int check_flow_keys_access(struct bpf_verifier_env *env, int off,
+                                 int size)
+{
+       if (size < 0 || off < 0 ||
+           (u64)off + size > sizeof(struct bpf_flow_keys)) {
+               verbose(env, "invalid access to flow keys off=%d size=%d\n",
+                       off, size);
+               return -EACCES;
+       }
+       return 0;
+}
+
+static int check_sock_access(struct bpf_verifier_env *env, u32 regno, int off,
+                            int size, enum bpf_access_type t)
+{
+       struct bpf_reg_state *regs = cur_regs(env);
+       struct bpf_reg_state *reg = &regs[regno];
+       struct bpf_insn_access_aux info;
+
+       if (reg->smin_value < 0) {
+               verbose(env, "R%d min value is negative, either use unsigned index or do a if (index >=0) check.\n",
+                       regno);
+               return -EACCES;
+       }
+
+       if (!bpf_sock_is_valid_access(off, size, t, &info)) {
+               verbose(env, "invalid bpf_sock access off=%d size=%d\n",
+                       off, size);
+               return -EACCES;
+       }
+
+       return 0;
+}
+
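check_flow_keys_access() above only bounds-checks the offset against sizeof(struct bpf_flow_keys). As a hedged illustration of the access it governs (assuming the BPF_PROG_TYPE_FLOW_DISSECTOR series that exposes skb->flow_keys; the field, section and return-code names are taken from that series and may differ):

#include <linux/bpf.h>
#include <linux/if_ether.h>
/* SEC() is assumed to come from bpf_helpers.h */

SEC("flow_dissector")
int flow_example(struct __sk_buff *skb)
{
        struct bpf_flow_keys *keys = skb->flow_keys;    /* PTR_TO_FLOW_KEYS */

        /* in-bounds loads and stores pass check_flow_keys_access() */
        keys->nhoff += sizeof(struct ethhdr);
        /* an out-of-bounds offset is rejected with
         * "invalid access to flow keys off=%d size=%d"
         */
        return BPF_OK;
}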
 static bool __is_pointer_value(bool allow_ptr_leaks,
                               const struct bpf_reg_state *reg)
 {
@@ -1413,25 +1528,39 @@ static bool __is_pointer_value(bool allow_ptr_leaks,
        return reg->type != SCALAR_VALUE;
 }
 
+static struct bpf_reg_state *reg_state(struct bpf_verifier_env *env, int regno)
+{
+       return cur_regs(env) + regno;
+}
+
 static bool is_pointer_value(struct bpf_verifier_env *env, int regno)
 {
-       return __is_pointer_value(env->allow_ptr_leaks, cur_regs(env) + regno);
+       return __is_pointer_value(env->allow_ptr_leaks, reg_state(env, regno));
 }
 
 static bool is_ctx_reg(struct bpf_verifier_env *env, int regno)
 {
-       const struct bpf_reg_state *reg = cur_regs(env) + regno;
+       const struct bpf_reg_state *reg = reg_state(env, regno);
 
-       return reg->type == PTR_TO_CTX;
+       return reg->type == PTR_TO_CTX ||
+              reg->type == PTR_TO_SOCKET;
 }
 
 static bool is_pkt_reg(struct bpf_verifier_env *env, int regno)
 {
-       const struct bpf_reg_state *reg = cur_regs(env) + regno;
+       const struct bpf_reg_state *reg = reg_state(env, regno);
 
        return type_is_pkt_pointer(reg->type);
 }
 
+static bool is_flow_key_reg(struct bpf_verifier_env *env, int regno)
+{
+       const struct bpf_reg_state *reg = reg_state(env, regno);
+
+       /* Separate to is_ctx_reg() since we still want to allow BPF_ST here. */
+       return reg->type == PTR_TO_FLOW_KEYS;
+}
+
 static int check_pkt_ptr_alignment(struct bpf_verifier_env *env,
                                   const struct bpf_reg_state *reg,
                                   int off, int size, bool strict)
@@ -1505,6 +1634,9 @@ static int check_ptr_alignment(struct bpf_verifier_env *env,
                 * right in front, treat it the very same way.
                 */
                return check_pkt_ptr_alignment(env, reg, off, size, strict);
+       case PTR_TO_FLOW_KEYS:
+               pointer_desc = "flow keys ";
+               break;
        case PTR_TO_MAP_VALUE:
                pointer_desc = "value ";
                break;
@@ -1519,6 +1651,9 @@ static int check_ptr_alignment(struct bpf_verifier_env *env,
                 */
                strict = true;
                break;
+       case PTR_TO_SOCKET:
+               pointer_desc = "sock ";
+               break;
        default:
                break;
        }
@@ -1727,9 +1862,6 @@ static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regn
                        else
                                mark_reg_known_zero(env, regs,
                                                    value_regno);
-                       regs[value_regno].id = 0;
-                       regs[value_regno].off = 0;
-                       regs[value_regno].range = 0;
                        regs[value_regno].type = reg_type;
                }
 
@@ -1778,6 +1910,25 @@ static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regn
                err = check_packet_access(env, regno, off, size, false);
                if (!err && t == BPF_READ && value_regno >= 0)
                        mark_reg_unknown(env, regs, value_regno);
+       } else if (reg->type == PTR_TO_FLOW_KEYS) {
+               if (t == BPF_WRITE && value_regno >= 0 &&
+                   is_pointer_value(env, value_regno)) {
+                       verbose(env, "R%d leaks addr into flow keys\n",
+                               value_regno);
+                       return -EACCES;
+               }
+
+               err = check_flow_keys_access(env, off, size);
+               if (!err && t == BPF_READ && value_regno >= 0)
+                       mark_reg_unknown(env, regs, value_regno);
+       } else if (reg->type == PTR_TO_SOCKET) {
+               if (t == BPF_WRITE) {
+                       verbose(env, "cannot write into socket\n");
+                       return -EACCES;
+               }
+               err = check_sock_access(env, regno, off, size, t);
+               if (!err && value_regno >= 0)
+                       mark_reg_unknown(env, regs, value_regno);
        } else {
                verbose(env, "R%d invalid mem access '%s'\n", regno,
                        reg_type_str[reg->type]);
@@ -1818,10 +1969,11 @@ static int check_xadd(struct bpf_verifier_env *env, int insn_idx, struct bpf_ins
        }
 
        if (is_ctx_reg(env, insn->dst_reg) ||
-           is_pkt_reg(env, insn->dst_reg)) {
+           is_pkt_reg(env, insn->dst_reg) ||
+           is_flow_key_reg(env, insn->dst_reg)) {
                verbose(env, "BPF_XADD stores into R%d %s is not allowed\n",
-                       insn->dst_reg, is_ctx_reg(env, insn->dst_reg) ?
-                       "context" : "packet");
+                       insn->dst_reg,
+                       reg_type_str[reg_state(env, insn->dst_reg)->type]);
                return -EACCES;
        }
 
@@ -1846,7 +1998,7 @@ static int check_stack_boundary(struct bpf_verifier_env *env, int regno,
                                int access_size, bool zero_size_allowed,
                                struct bpf_call_arg_meta *meta)
 {
-       struct bpf_reg_state *reg = cur_regs(env) + regno;
+       struct bpf_reg_state *reg = reg_state(env, regno);
        struct bpf_func_state *state = func(env, reg);
        int off, i, slot, spi;
 
@@ -1908,8 +2060,8 @@ mark:
                /* reading any byte out of 8-byte 'spill_slot' will cause
                 * the whole slot to be marked as 'read'
                 */
-               mark_stack_slot_read(env, env->cur_state, env->cur_state->parent,
-                                    spi, state->frameno);
+               mark_reg_read(env, &state->stack[spi].spilled_ptr,
+                             state->stack[spi].spilled_ptr.parent);
        }
        return update_stack_depth(env, state, off);
 }
@@ -1978,7 +2130,8 @@ static int check_func_arg(struct bpf_verifier_env *env, u32 regno,
        }
 
        if (arg_type == ARG_PTR_TO_MAP_KEY ||
-           arg_type == ARG_PTR_TO_MAP_VALUE) {
+           arg_type == ARG_PTR_TO_MAP_VALUE ||
+           arg_type == ARG_PTR_TO_UNINIT_MAP_VALUE) {
                expected_type = PTR_TO_STACK;
                if (!type_is_pkt_pointer(type) && type != PTR_TO_MAP_VALUE &&
                    type != expected_type)
@@ -1999,6 +2152,16 @@ static int check_func_arg(struct bpf_verifier_env *env, u32 regno,
                err = check_ctx_reg(env, reg, regno);
                if (err < 0)
                        return err;
+       } else if (arg_type == ARG_PTR_TO_SOCKET) {
+               expected_type = PTR_TO_SOCKET;
+               if (type != expected_type)
+                       goto err_type;
+               if (meta->ptr_id || !reg->id) {
+                       verbose(env, "verifier internal error: mismatched references meta=%d, reg=%d\n",
+                               meta->ptr_id, reg->id);
+                       return -EFAULT;
+               }
+               meta->ptr_id = reg->id;
        } else if (arg_type_is_mem_ptr(arg_type)) {
                expected_type = PTR_TO_STACK;
                /* One exception here. In case function allows for NULL to be
@@ -2038,7 +2201,8 @@ static int check_func_arg(struct bpf_verifier_env *env, u32 regno,
                err = check_helper_mem_access(env, regno,
                                              meta->map_ptr->key_size, false,
                                              NULL);
-       } else if (arg_type == ARG_PTR_TO_MAP_VALUE) {
+       } else if (arg_type == ARG_PTR_TO_MAP_VALUE ||
+                  arg_type == ARG_PTR_TO_UNINIT_MAP_VALUE) {
                /* bpf_map_xxx(..., map_ptr, ..., value) call:
                 * check [value, value + map->value_size) validity
                 */
@@ -2047,9 +2211,10 @@ static int check_func_arg(struct bpf_verifier_env *env, u32 regno,
                        verbose(env, "invalid map_ptr to access map->value\n");
                        return -EACCES;
                }
+               meta->raw_mode = (arg_type == ARG_PTR_TO_UNINIT_MAP_VALUE);
                err = check_helper_mem_access(env, regno,
                                              meta->map_ptr->value_size, false,
-                                             NULL);
+                                             meta);
        } else if (arg_type_is_mem_size(arg_type)) {
                bool zero_size_allowed = (arg_type == ARG_CONST_SIZE_OR_ZERO);
 
@@ -2129,6 +2294,7 @@ static int check_map_func_compatibility(struct bpf_verifier_env *env,
                        goto error;
                break;
        case BPF_MAP_TYPE_CGROUP_STORAGE:
+       case BPF_MAP_TYPE_PERCPU_CGROUP_STORAGE:
                if (func_id != BPF_FUNC_get_local_storage)
                        goto error;
                break;
@@ -2171,6 +2337,13 @@ static int check_map_func_compatibility(struct bpf_verifier_env *env,
                if (func_id != BPF_FUNC_sk_select_reuseport)
                        goto error;
                break;
+       case BPF_MAP_TYPE_QUEUE:
+       case BPF_MAP_TYPE_STACK:
+               if (func_id != BPF_FUNC_map_peek_elem &&
+                   func_id != BPF_FUNC_map_pop_elem &&
+                   func_id != BPF_FUNC_map_push_elem)
+                       goto error;
+               break;
        default:
                break;
        }
@@ -2219,13 +2392,21 @@ static int check_map_func_compatibility(struct bpf_verifier_env *env,
                        goto error;
                break;
        case BPF_FUNC_get_local_storage:
-               if (map->map_type != BPF_MAP_TYPE_CGROUP_STORAGE)
+               if (map->map_type != BPF_MAP_TYPE_CGROUP_STORAGE &&
+                   map->map_type != BPF_MAP_TYPE_PERCPU_CGROUP_STORAGE)
                        goto error;
                break;
        case BPF_FUNC_sk_select_reuseport:
                if (map->map_type != BPF_MAP_TYPE_REUSEPORT_SOCKARRAY)
                        goto error;
                break;
+       case BPF_FUNC_map_peek_elem:
+       case BPF_FUNC_map_pop_elem:
+       case BPF_FUNC_map_push_elem:
+               if (map->map_type != BPF_MAP_TYPE_QUEUE &&
+                   map->map_type != BPF_MAP_TYPE_STACK)
+                       goto error;
+               break;
        default:
                break;
        }
@@ -2286,10 +2467,32 @@ static bool check_arg_pair_ok(const struct bpf_func_proto *fn)
        return true;
 }
 
+static bool check_refcount_ok(const struct bpf_func_proto *fn)
+{
+       int count = 0;
+
+       if (arg_type_is_refcounted(fn->arg1_type))
+               count++;
+       if (arg_type_is_refcounted(fn->arg2_type))
+               count++;
+       if (arg_type_is_refcounted(fn->arg3_type))
+               count++;
+       if (arg_type_is_refcounted(fn->arg4_type))
+               count++;
+       if (arg_type_is_refcounted(fn->arg5_type))
+               count++;
+
+       /* We only support one arg being unreferenced at the moment,
+        * which is sufficient for the helper functions we have right now.
+        */
+       return count <= 1;
+}
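A helper proto with a single ARG_PTR_TO_SOCKET argument passes check_refcount_ok() and is handled by the release path below; as a sketch, this mirrors bpf_sk_release_proto from net/core/filter.c in the same series (shown here only for illustration):

static const struct bpf_func_proto bpf_sk_release_proto = {
        .func           = bpf_sk_release,
        .gpl_only       = false,
        .ret_type       = RET_INTEGER,
        .arg1_type      = ARG_PTR_TO_SOCKET,
};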
+
 static int check_func_proto(const struct bpf_func_proto *fn)
 {
        return check_raw_mode_ok(fn) &&
-              check_arg_pair_ok(fn) ? 0 : -EINVAL;
+              check_arg_pair_ok(fn) &&
+              check_refcount_ok(fn) ? 0 : -EINVAL;
 }
 
 /* Packet data might have moved, any old PTR_TO_PACKET[_META,_END]
@@ -2305,10 +2508,9 @@ static void __clear_all_pkt_pointers(struct bpf_verifier_env *env,
                if (reg_is_pkt_pointer_any(&regs[i]))
                        mark_reg_unknown(env, regs, i);
 
-       for (i = 0; i < state->allocated_stack / BPF_REG_SIZE; i++) {
-               if (state->stack[i].slot_type[0] != STACK_SPILL)
+       bpf_for_each_spilled_reg(i, state, reg) {
+               if (!reg)
                        continue;
-               reg = &state->stack[i].spilled_ptr;
                if (reg_is_pkt_pointer_any(reg))
                        __mark_reg_unknown(reg);
        }
@@ -2323,12 +2525,45 @@ static void clear_all_pkt_pointers(struct bpf_verifier_env *env)
                __clear_all_pkt_pointers(env, vstate->frame[i]);
 }
 
+static void release_reg_references(struct bpf_verifier_env *env,
+                                  struct bpf_func_state *state, int id)
+{
+       struct bpf_reg_state *regs = state->regs, *reg;
+       int i;
+
+       for (i = 0; i < MAX_BPF_REG; i++)
+               if (regs[i].id == id)
+                       mark_reg_unknown(env, regs, i);
+
+       bpf_for_each_spilled_reg(i, state, reg) {
+               if (!reg)
+                       continue;
+               if (reg_is_refcounted(reg) && reg->id == id)
+                       __mark_reg_unknown(reg);
+       }
+}
+
+/* The pointer with the specified id has released its reference to kernel
+ * resources. Identify all copies of the same pointer and clear the reference.
+ */
+static int release_reference(struct bpf_verifier_env *env,
+                            struct bpf_call_arg_meta *meta)
+{
+       struct bpf_verifier_state *vstate = env->cur_state;
+       int i;
+
+       for (i = 0; i <= vstate->curframe; i++)
+               release_reg_references(env, vstate->frame[i], meta->ptr_id);
+
+       return release_reference_state(env, meta->ptr_id);
+}
+
 static int check_func_call(struct bpf_verifier_env *env, struct bpf_insn *insn,
                           int *insn_idx)
 {
        struct bpf_verifier_state *state = env->cur_state;
        struct bpf_func_state *caller, *callee;
-       int i, subprog, target_insn;
+       int i, err, subprog, target_insn;
 
        if (state->curframe + 1 >= MAX_CALL_FRAMES) {
                verbose(env, "the call stack of %d frames is too deep\n",
@@ -2366,11 +2601,18 @@ static int check_func_call(struct bpf_verifier_env *env, struct bpf_insn *insn,
                        state->curframe + 1 /* frameno within this callchain */,
                        subprog /* subprog number within this prog */);
 
-       /* copy r1 - r5 args that callee can access */
+       /* Transfer references to the callee */
+       err = transfer_reference_state(callee, caller);
+       if (err)
+               return err;
+
+       /* copy r1 - r5 args that callee can access.  The copy includes parent
+        * pointers, which connects us up to the liveness chain
+        */
        for (i = BPF_REG_1; i <= BPF_REG_5; i++)
                callee->regs[i] = caller->regs[i];
 
-       /* after the call regsiters r0 - r5 were scratched */
+       /* after the call registers r0 - r5 were scratched */
        for (i = 0; i < CALLER_SAVED_REGS; i++) {
                mark_reg_not_init(env, caller->regs, caller_saved[i]);
                check_reg_arg(env, caller_saved[i], DST_OP_NO_MARK);
@@ -2396,6 +2638,7 @@ static int prepare_func_exit(struct bpf_verifier_env *env, int *insn_idx)
        struct bpf_verifier_state *state = env->cur_state;
        struct bpf_func_state *caller, *callee;
        struct bpf_reg_state *r0;
+       int err;
 
        callee = state->frame[state->curframe];
        r0 = &callee->regs[BPF_REG_0];
@@ -2415,6 +2658,11 @@ static int prepare_func_exit(struct bpf_verifier_env *env, int *insn_idx)
        /* return to the caller whatever r0 had in the callee */
        caller->regs[BPF_REG_0] = *r0;
 
+       /* Transfer references to the caller */
+       err = transfer_reference_state(caller, callee);
+       if (err)
+               return err;
+
        *insn_idx = callee->callsite + 1;
        if (env->log.level) {
                verbose(env, "returning from callee:\n");
@@ -2454,7 +2702,10 @@ record_func_map(struct bpf_verifier_env *env, struct bpf_call_arg_meta *meta,
        if (func_id != BPF_FUNC_tail_call &&
            func_id != BPF_FUNC_map_lookup_elem &&
            func_id != BPF_FUNC_map_update_elem &&
-           func_id != BPF_FUNC_map_delete_elem)
+           func_id != BPF_FUNC_map_delete_elem &&
+           func_id != BPF_FUNC_map_push_elem &&
+           func_id != BPF_FUNC_map_pop_elem &&
+           func_id != BPF_FUNC_map_peek_elem)
                return 0;
 
        if (meta->map_ptr == NULL) {
@@ -2471,6 +2722,18 @@ record_func_map(struct bpf_verifier_env *env, struct bpf_call_arg_meta *meta,
        return 0;
 }
 
+static int check_reference_leak(struct bpf_verifier_env *env)
+{
+       struct bpf_func_state *state = cur_func(env);
+       int i;
+
+       for (i = 0; i < state->acquired_refs; i++) {
+               verbose(env, "Unreleased reference id=%d alloc_insn=%d\n",
+                       state->refs[i].id, state->refs[i].insn_idx);
+       }
+       return state->acquired_refs ? -EINVAL : 0;
+}
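As the header comment at the top of this patch notes, any reference still outstanding when the program ends (or, below, when it tail-calls or uses BPF_LD_[ABS|IND]) causes a rejection with output of the form "Unreleased reference id=1 alloc_insn=7"; the id and instruction index here are illustrative.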
+
 static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn_idx)
 {
        const struct bpf_func_proto *fn = NULL;
@@ -2549,6 +2812,18 @@ static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
                        return err;
        }
 
+       if (func_id == BPF_FUNC_tail_call) {
+               err = check_reference_leak(env);
+               if (err) {
+                       verbose(env, "tail_call would lead to reference leak\n");
+                       return err;
+               }
+       } else if (is_release_function(func_id)) {
+               err = release_reference(env, &meta);
+               if (err)
+                       return err;
+       }
+
        regs = cur_regs(env);
 
        /* check that flags argument in get_local_storage(map, flags) is 0,
@@ -2580,7 +2855,6 @@ static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
                        regs[BPF_REG_0].type = PTR_TO_MAP_VALUE_OR_NULL;
                /* There is no offset yet applied, variable or fixed */
                mark_reg_known_zero(env, regs, BPF_REG_0);
-               regs[BPF_REG_0].off = 0;
                /* remember map_ptr, so that check_map_access()
                 * can check 'value_size' boundary of memory access
                 * to map element returned from bpf_map_lookup_elem()
@@ -2592,6 +2866,13 @@ static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
                }
                regs[BPF_REG_0].map_ptr = meta.map_ptr;
                regs[BPF_REG_0].id = ++env->id_gen;
+       } else if (fn->ret_type == RET_PTR_TO_SOCKET_OR_NULL) {
+               int id = acquire_reference_state(env, insn_idx);
+               if (id < 0)
+                       return id;
+               mark_reg_known_zero(env, regs, BPF_REG_0);
+               regs[BPF_REG_0].type = PTR_TO_SOCKET_OR_NULL;
+               regs[BPF_REG_0].id = id;
        } else {
                verbose(env, "unknown return type %d of func %s#%d\n",
                        fn->ret_type, func_id_name(func_id), func_id);
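The acquiring side is a helper whose ret_type is RET_PTR_TO_SOCKET_OR_NULL; as a sketch, this mirrors bpf_sk_lookup_tcp_proto from net/core/filter.c in the same series (field values shown for illustration):

static const struct bpf_func_proto bpf_sk_lookup_tcp_proto = {
        .func           = bpf_sk_lookup_tcp,
        .gpl_only       = false,
        .pkt_access     = true,
        .ret_type       = RET_PTR_TO_SOCKET_OR_NULL,
        .arg1_type      = ARG_PTR_TO_CTX,
        .arg2_type      = ARG_PTR_TO_MEM,
        .arg3_type      = ARG_CONST_SIZE,
        .arg4_type      = ARG_ANYTHING,
        .arg5_type      = ARG_ANYTHING,
};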
@@ -2722,20 +3003,20 @@ static int adjust_ptr_min_max_vals(struct bpf_verifier_env *env,
                return -EACCES;
        }
 
-       if (ptr_reg->type == PTR_TO_MAP_VALUE_OR_NULL) {
-               verbose(env, "R%d pointer arithmetic on PTR_TO_MAP_VALUE_OR_NULL prohibited, null-check it first\n",
-                       dst);
-               return -EACCES;
-       }
-       if (ptr_reg->type == CONST_PTR_TO_MAP) {
-               verbose(env, "R%d pointer arithmetic on CONST_PTR_TO_MAP prohibited\n",
-                       dst);
+       switch (ptr_reg->type) {
+       case PTR_TO_MAP_VALUE_OR_NULL:
+               verbose(env, "R%d pointer arithmetic on %s prohibited, null-check it first\n",
+                       dst, reg_type_str[ptr_reg->type]);
                return -EACCES;
-       }
-       if (ptr_reg->type == PTR_TO_PACKET_END) {
-               verbose(env, "R%d pointer arithmetic on PTR_TO_PACKET_END prohibited\n",
-                       dst);
+       case CONST_PTR_TO_MAP:
+       case PTR_TO_PACKET_END:
+       case PTR_TO_SOCKET:
+       case PTR_TO_SOCKET_OR_NULL:
+               verbose(env, "R%d pointer arithmetic on %s prohibited\n",
+                       dst, reg_type_str[ptr_reg->type]);
                return -EACCES;
+       default:
+               break;
        }
 
        /* In case of 'scalar += pointer', dst_reg inherits pointer type and id.
@@ -2896,6 +3177,15 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
        u64 umin_val, umax_val;
        u64 insn_bitness = (BPF_CLASS(insn->code) == BPF_ALU64) ? 64 : 32;
 
+       if (insn_bitness == 32) {
+               /* Relevant for 32-bit RSH: Information can propagate towards
+                * LSB, so it isn't sufficient to only truncate the output to
+                * 32 bits.
+                */
+               coerce_reg_to_size(dst_reg, 4);
+               coerce_reg_to_size(&src_reg, 4);
+       }
+
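A worked example (illustrative, not from the patch) of why the inputs and not just the output must be truncated: suppose the verifier tracks r6 as the constant 0x100000000 (bit 32 set, low 32 bits zero) and the program executes the 32-bit shift w6 >>= 1. At runtime the ALU sees only the low 32 bits of r6, so the result is 0. Truncating only the output would instead give 0x100000000 >> 1 = 0x80000000, and the verifier would wrongly treat w6 as the known constant 0x80000000. Coercing both operands to 32 bits before computing the bounds avoids this.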
        smin_val = src_reg.smin_value;
        smax_val = src_reg.smax_value;
        umin_val = src_reg.umin_value;
@@ -3131,7 +3421,6 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
        if (BPF_CLASS(insn->code) != BPF_ALU64) {
                /* 32-bit ALU ops are (32,32)->32 */
                coerce_reg_to_size(dst_reg, 4);
-               coerce_reg_to_size(&src_reg, 4);
        }
 
        __reg_deduce_bounds(dst_reg);
@@ -3447,10 +3736,9 @@ static void find_good_pkt_pointers(struct bpf_verifier_state *vstate,
 
        for (j = 0; j <= vstate->curframe; j++) {
                state = vstate->frame[j];
-               for (i = 0; i < state->allocated_stack / BPF_REG_SIZE; i++) {
-                       if (state->stack[i].slot_type[0] != STACK_SPILL)
+               bpf_for_each_spilled_reg(i, state, reg) {
+                       if (!reg)
                                continue;
-                       reg = &state->stack[i].spilled_ptr;
                        if (reg->type == type && reg->id == dst_reg->id)
                                reg->range = max(reg->range, new_range);
                }
@@ -3656,12 +3944,11 @@ static void reg_combine_min_max(struct bpf_reg_state *true_src,
        }
 }
 
-static void mark_map_reg(struct bpf_reg_state *regs, u32 regno, u32 id,
-                        bool is_null)
+static void mark_ptr_or_null_reg(struct bpf_func_state *state,
+                                struct bpf_reg_state *reg, u32 id,
+                                bool is_null)
 {
-       struct bpf_reg_state *reg = &regs[regno];
-
-       if (reg->type == PTR_TO_MAP_VALUE_OR_NULL && reg->id == id) {
+       if (reg_type_may_be_null(reg->type) && reg->id == id) {
                /* Old offset (both fixed and variable parts) should
                 * have been known-zero, because we don't allow pointer
                 * arithmetic on pointers that might be NULL.
@@ -3674,40 +3961,49 @@ static void mark_map_reg(struct bpf_reg_state *regs, u32 regno, u32 id,
                }
                if (is_null) {
                        reg->type = SCALAR_VALUE;
-               } else if (reg->map_ptr->inner_map_meta) {
-                       reg->type = CONST_PTR_TO_MAP;
-                       reg->map_ptr = reg->map_ptr->inner_map_meta;
-               } else {
-                       reg->type = PTR_TO_MAP_VALUE;
+               } else if (reg->type == PTR_TO_MAP_VALUE_OR_NULL) {
+                       if (reg->map_ptr->inner_map_meta) {
+                               reg->type = CONST_PTR_TO_MAP;
+                               reg->map_ptr = reg->map_ptr->inner_map_meta;
+                       } else {
+                               reg->type = PTR_TO_MAP_VALUE;
+                       }
+               } else if (reg->type == PTR_TO_SOCKET_OR_NULL) {
+                       reg->type = PTR_TO_SOCKET;
+               }
+               if (is_null || !reg_is_refcounted(reg)) {
+                       /* We don't need id from this point onwards anymore,
+                        * thus we should better reset it, so that state
+                        * pruning has chances to take effect.
+                        */
+                       reg->id = 0;
                }
-               /* We don't need id from this point onwards anymore, thus we
-                * should better reset it, so that state pruning has chances
-                * to take effect.
-                */
-               reg->id = 0;
        }
 }
 
 /* The logic is similar to find_good_pkt_pointers(), both could eventually
  * be folded together at some point.
  */
-static void mark_map_regs(struct bpf_verifier_state *vstate, u32 regno,
-                         bool is_null)
+static void mark_ptr_or_null_regs(struct bpf_verifier_state *vstate, u32 regno,
+                                 bool is_null)
 {
        struct bpf_func_state *state = vstate->frame[vstate->curframe];
-       struct bpf_reg_state *regs = state->regs;
+       struct bpf_reg_state *reg, *regs = state->regs;
        u32 id = regs[regno].id;
        int i, j;
 
+       if (reg_is_refcounted_or_null(&regs[regno]) && is_null)
+               __release_reference_state(state, id);
+
        for (i = 0; i < MAX_BPF_REG; i++)
-               mark_map_reg(regs, i, id, is_null);
+               mark_ptr_or_null_reg(state, &regs[i], id, is_null);
 
        for (j = 0; j <= vstate->curframe; j++) {
                state = vstate->frame[j];
-               for (i = 0; i < state->allocated_stack / BPF_REG_SIZE; i++) {
-                       if (state->stack[i].slot_type[0] != STACK_SPILL)
+               bpf_for_each_spilled_reg(i, state, reg) {
+                       if (!reg)
                                continue;
-                       mark_map_reg(&state->stack[i].spilled_ptr, 0, id, is_null);
+                       mark_ptr_or_null_reg(state, reg, id, is_null);
                }
        }
 }
@@ -3909,12 +4205,14 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
        /* detect if R == 0 where R is returned from bpf_map_lookup_elem() */
        if (BPF_SRC(insn->code) == BPF_K &&
            insn->imm == 0 && (opcode == BPF_JEQ || opcode == BPF_JNE) &&
-           dst_reg->type == PTR_TO_MAP_VALUE_OR_NULL) {
-               /* Mark all identical map registers in each branch as either
+           reg_type_may_be_null(dst_reg->type)) {
+               /* Mark all identical registers in each branch as either
                 * safe or unknown depending R == 0 or R != 0 conditional.
                 */
-               mark_map_regs(this_branch, insn->dst_reg, opcode == BPF_JNE);
-               mark_map_regs(other_branch, insn->dst_reg, opcode == BPF_JEQ);
+               mark_ptr_or_null_regs(this_branch, insn->dst_reg,
+                                     opcode == BPF_JNE);
+               mark_ptr_or_null_regs(other_branch, insn->dst_reg,
+                                     opcode == BPF_JEQ);
        } else if (!try_match_pkt_pointers(insn, dst_reg, &regs[insn->src_reg],
                                           this_branch, other_branch) &&
                   is_pointer_value(env, insn->dst_reg)) {
@@ -4037,6 +4335,16 @@ static int check_ld_abs(struct bpf_verifier_env *env, struct bpf_insn *insn)
        if (err)
                return err;
 
+       /* Disallow usage of BPF_LD_[ABS|IND] with reference tracking, as
+        * gen_ld_abs() may terminate the program at runtime, leading to a
+        * reference leak.
+        */
+       err = check_reference_leak(env);
+       if (err) {
+               verbose(env, "BPF_LD_[ABS|IND] cannot be mixed with socket references\n");
+               return err;
+       }
+
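
As a hedged illustration (the load_byte() wrapper, headers and section name are borrowed from the selftests and are assumptions, not part of this patch), the new check rejects a program of the following shape even though it does release the socket, because the reference is still held at the point the LD_ABS instruction executes:

#include <linux/bpf.h>
#include "bpf_helpers.h"

SEC("classifier")
int ld_abs_with_ref(struct __sk_buff *skb)
{
	struct bpf_sock_tuple tuple = {};
	struct bpf_sock *sk;
	__u8 b = 0;

	sk = bpf_sk_lookup_tcp(skb, &tuple, sizeof(tuple.ipv4), 0, 0);
	if (!sk)
		return 0;
	/* load_byte() lowers to BPF_LD | BPF_ABS; gen_ld_abs() may terminate
	 * the program at runtime before the release below, so the verifier
	 * reports "BPF_LD_[ABS|IND] cannot be mixed with socket references".
	 */
	b = load_byte(skb, 0);
	bpf_sk_release(sk);
	return b;
}
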
        if (regs[BPF_REG_6].type != PTR_TO_CTX) {
                verbose(env,
                        "at the time of BPF_LD_ABS|IND R6 != pointer to skb\n");
@@ -4370,7 +4678,7 @@ static bool regsafe(struct bpf_reg_state *rold, struct bpf_reg_state *rcur,
                /* explored state didn't use this */
                return true;
 
-       equal = memcmp(rold, rcur, offsetof(struct bpf_reg_state, frameno)) == 0;
+       equal = memcmp(rold, rcur, offsetof(struct bpf_reg_state, parent)) == 0;
 
        if (rold->type == PTR_TO_STACK)
                /* two stack pointers are equal only if they're pointing to
@@ -4451,6 +4759,9 @@ static bool regsafe(struct bpf_reg_state *rold, struct bpf_reg_state *rcur,
        case PTR_TO_CTX:
        case CONST_PTR_TO_MAP:
        case PTR_TO_PACKET_END:
+       case PTR_TO_FLOW_KEYS:
+       case PTR_TO_SOCKET:
+       case PTR_TO_SOCKET_OR_NULL:
                /* Only valid matches are exact, which memcmp() above
                 * would have accepted
                 */
@@ -4526,6 +4837,14 @@ static bool stacksafe(struct bpf_func_state *old,
        return true;
 }
 
+static bool refsafe(struct bpf_func_state *old, struct bpf_func_state *cur)
+{
+       if (old->acquired_refs != cur->acquired_refs)
+               return false;
+       return !memcmp(old->refs, cur->refs,
+                      sizeof(*old->refs) * old->acquired_refs);
+}
+
 /* compare two verifier states
  *
  * all states stored in state_list are known to be valid, since
@@ -4571,6 +4890,9 @@ static bool func_states_equal(struct bpf_func_state *old,
 
        if (!stacksafe(old, cur, idmap))
                goto out_free;
+
+       if (!refsafe(old, cur))
+               goto out_free;
        ret = true;
 out_free:
        kfree(idmap);
@@ -4603,7 +4925,7 @@ static bool states_equal(struct bpf_verifier_env *env,
  * equivalent state (jump target or such) we didn't arrive by the straight-line
  * code, so read marks in the state must propagate to the parent regardless
  * of the state's write marks. That's what 'parent == state->parent' comparison
- * in mark_reg_read() and mark_stack_slot_read() is for.
+ * in mark_reg_read() is for.
  */
 static int propagate_liveness(struct bpf_verifier_env *env,
                              const struct bpf_verifier_state *vstate,
@@ -4624,7 +4946,8 @@ static int propagate_liveness(struct bpf_verifier_env *env,
                if (vparent->frame[vparent->curframe]->regs[i].live & REG_LIVE_READ)
                        continue;
                if (vstate->frame[vstate->curframe]->regs[i].live & REG_LIVE_READ) {
-                       err = mark_reg_read(env, vstate, vparent, i);
+                       err = mark_reg_read(env, &vstate->frame[vstate->curframe]->regs[i],
+                                           &vparent->frame[vstate->curframe]->regs[i]);
                        if (err)
                                return err;
                }
@@ -4639,7 +4962,8 @@ static int propagate_liveness(struct bpf_verifier_env *env,
                        if (parent->stack[i].spilled_ptr.live & REG_LIVE_READ)
                                continue;
                        if (state->stack[i].spilled_ptr.live & REG_LIVE_READ)
-                               mark_stack_slot_read(env, vstate, vparent, i, frame);
+                               mark_reg_read(env, &state->stack[i].spilled_ptr,
+                                             &parent->stack[i].spilled_ptr);
                }
        }
        return err;
@@ -4649,7 +4973,7 @@ static int is_state_visited(struct bpf_verifier_env *env, int insn_idx)
 {
        struct bpf_verifier_state_list *new_sl;
        struct bpf_verifier_state_list *sl;
-       struct bpf_verifier_state *cur = env->cur_state;
+       struct bpf_verifier_state *cur = env->cur_state, *new;
        int i, j, err;
 
        sl = env->explored_states[insn_idx];
@@ -4691,16 +5015,18 @@ static int is_state_visited(struct bpf_verifier_env *env, int insn_idx)
                return -ENOMEM;
 
        /* add new state to the head of linked list */
-       err = copy_verifier_state(&new_sl->state, cur);
+       new = &new_sl->state;
+       err = copy_verifier_state(new, cur);
        if (err) {
-               free_verifier_state(&new_sl->state, false);
+               free_verifier_state(new, false);
                kfree(new_sl);
                return err;
        }
        new_sl->next = env->explored_states[insn_idx];
        env->explored_states[insn_idx] = new_sl;
        /* connect new state to parentage chain */
-       cur->parent = &new_sl->state;
+       for (i = 0; i < BPF_REG_FP; i++)
+               cur_regs(env)[i].parent = &new->frame[new->curframe]->regs[i];
        /* clear write marks in current state: the writes we did are not writes
         * our child did, so they don't screen off its reads from us.
         * (There are no read marks in current state, because reads always mark
@@ -4713,13 +5039,48 @@ static int is_state_visited(struct bpf_verifier_env *env, int insn_idx)
        /* all stack frames are accessible from callee, clear them all */
        for (j = 0; j <= cur->curframe; j++) {
                struct bpf_func_state *frame = cur->frame[j];
+               struct bpf_func_state *newframe = new->frame[j];
 
-               for (i = 0; i < frame->allocated_stack / BPF_REG_SIZE; i++)
+               for (i = 0; i < frame->allocated_stack / BPF_REG_SIZE; i++) {
                        frame->stack[i].spilled_ptr.live = REG_LIVE_NONE;
+                       frame->stack[i].spilled_ptr.parent =
+                                               &newframe->stack[i].spilled_ptr;
+               }
        }
        return 0;
 }
 
+/* Return true if it's OK to have the same insn return a different type. */
+static bool reg_type_mismatch_ok(enum bpf_reg_type type)
+{
+       switch (type) {
+       case PTR_TO_CTX:
+       case PTR_TO_SOCKET:
+       case PTR_TO_SOCKET_OR_NULL:
+               return false;
+       default:
+               return true;
+       }
+}
+
+/* If an instruction was previously used with particular pointer types, then we
+ * need to be careful to avoid cases such as the one below, where it may be ok
+ * for one branch to access the pointer, but not ok for the other branch:
+ *
+ * R1 = sock_ptr
+ * goto X;
+ * ...
+ * R1 = some_other_valid_ptr;
+ * goto X;
+ * ...
+ * R2 = *(u32 *)(R1 + 0);
+ */
+static bool reg_type_mismatch(enum bpf_reg_type src, enum bpf_reg_type prev)
+{
+       return src != prev && (!reg_type_mismatch_ok(src) ||
+                              !reg_type_mismatch_ok(prev));
+}
+
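
A hedged BPF C sketch of the shape reg_type_mismatch() rejects for sockets follows; the names, program type, and whether the compiler really emits a single shared load are assumptions, but the shape mirrors the pseudo-code in the comment above:

#include <linux/bpf.h>
#include "bpf_helpers.h"

SEC("classifier")
int mixed_pointer_load(struct __sk_buff *skb)
{
	struct bpf_sock_tuple tuple = {};
	struct bpf_sock *sk;
	__u32 scratch = 0;
	__u32 *p;

	sk = bpf_sk_lookup_tcp(skb, &tuple, sizeof(tuple.ipv4), 0, 0);
	if (!sk)
		return 0;
	if (skb->mark)
		p = &sk->bound_dev_if;	/* pointer derived from PTR_TO_SOCKET */
	else
		p = &scratch;		/* PTR_TO_STACK */
	/* one load insn reached with two pointer types:
	 * "same insn cannot be used with different pointers"
	 */
	scratch = *p;
	bpf_sk_release(sk);
	return scratch;
}
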
 static int do_check(struct bpf_verifier_env *env)
 {
        struct bpf_verifier_state *state;
@@ -4734,7 +5095,6 @@ static int do_check(struct bpf_verifier_env *env)
        if (!state)
                return -ENOMEM;
        state->curframe = 0;
-       state->parent = NULL;
        state->frame[0] = kzalloc(sizeof(struct bpf_func_state), GFP_KERNEL);
        if (!state->frame[0]) {
                kfree(state);
@@ -4814,6 +5174,7 @@ static int do_check(struct bpf_verifier_env *env)
 
                regs = cur_regs(env);
                env->insn_aux_data[insn_idx].seen = true;
+
                if (class == BPF_ALU || class == BPF_ALU64) {
                        err = check_alu_op(env, insn);
                        if (err)
@@ -4853,9 +5214,7 @@ static int do_check(struct bpf_verifier_env *env)
                                 */
                                *prev_src_type = src_reg_type;
 
-                       } else if (src_reg_type != *prev_src_type &&
-                                  (src_reg_type == PTR_TO_CTX ||
-                                   *prev_src_type == PTR_TO_CTX)) {
+                       } else if (reg_type_mismatch(src_reg_type, *prev_src_type)) {
                                /* Abuser program is trying to use the same insn
                                 * dst_reg = *(u32*) (src_reg + off)
                                 * with different pointer types:
@@ -4900,9 +5259,7 @@ static int do_check(struct bpf_verifier_env *env)
 
                        if (*prev_dst_type == NOT_INIT) {
                                *prev_dst_type = dst_reg_type;
-                       } else if (dst_reg_type != *prev_dst_type &&
-                                  (dst_reg_type == PTR_TO_CTX ||
-                                   *prev_dst_type == PTR_TO_CTX)) {
+                       } else if (reg_type_mismatch(dst_reg_type, *prev_dst_type)) {
                                verbose(env, "same insn cannot be used with different pointers\n");
                                return -EINVAL;
                        }
@@ -4919,8 +5276,9 @@ static int do_check(struct bpf_verifier_env *env)
                                return err;
 
                        if (is_ctx_reg(env, insn->dst_reg)) {
-                               verbose(env, "BPF_ST stores into R%d context is not allowed\n",
-                                       insn->dst_reg);
+                               verbose(env, "BPF_ST stores into R%d %s is not allowed\n",
+                                       insn->dst_reg,
+                                       reg_type_str[reg_state(env, insn->dst_reg)->type]);
                                return -EACCES;
                        }
 
@@ -4982,6 +5340,10 @@ static int do_check(struct bpf_verifier_env *env)
                                        continue;
                                }
 
+                               err = check_reference_leak(env);
+                               if (err)
+                                       return err;
+
                                /* eBPF calling convention is such that R0 is used
                                 * to return the value from the eBPF program.
                                 * Make sure that it's readable at this time
@@ -5095,6 +5457,12 @@ static int check_map_prog_compatibility(struct bpf_verifier_env *env,
        return 0;
 }
 
+static bool bpf_map_is_cgroup_storage(struct bpf_map *map)
+{
+       return (map->map_type == BPF_MAP_TYPE_CGROUP_STORAGE ||
+               map->map_type == BPF_MAP_TYPE_PERCPU_CGROUP_STORAGE);
+}
+
 /* look for pseudo eBPF instructions that access map FDs and
  * replace them with actual map pointers
  */
@@ -5185,10 +5553,9 @@ static int replace_map_fd_with_map_ptr(struct bpf_verifier_env *env)
                        }
                        env->used_maps[env->used_map_cnt++] = map;
 
-                       if (map->map_type == BPF_MAP_TYPE_CGROUP_STORAGE &&
+                       if (bpf_map_is_cgroup_storage(map) &&
                            bpf_cgroup_storage_assign(env->prog, map)) {
-                               verbose(env,
-                                       "only one cgroup storage is allowed\n");
+                               verbose(env, "only one cgroup storage of each type is allowed\n");
                                fdput(f);
                                return -EBUSY;
                        }
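
With bpf_map_is_cgroup_storage() covering both flavours, a program may now carry one cgroup storage map of each type, while a second map of the same type still fails with the message above. A hedged sketch of such maps and a user, using the bpf_map_def and helper conventions of the selftests of this era (assumptions, not part of this patch):

#include <linux/bpf.h>
#include "bpf_helpers.h"

struct bpf_map_def SEC("maps") cg_storage = {
	.type = BPF_MAP_TYPE_CGROUP_STORAGE,
	.key_size = sizeof(struct bpf_cgroup_storage_key),
	.value_size = sizeof(__u64),
};

struct bpf_map_def SEC("maps") percpu_cg_storage = {
	.type = BPF_MAP_TYPE_PERCPU_CGROUP_STORAGE,
	.key_size = sizeof(struct bpf_cgroup_storage_key),
	.value_size = sizeof(__u64),
};

SEC("cgroup/skb")
int count_bytes(struct __sk_buff *skb)
{
	__u64 *bytes;

	/* per-cgroup counter backed by the shared storage map */
	bytes = bpf_get_local_storage(&cg_storage, 0);
	if (bytes)
		*bytes += skb->len;
	return 1;
}
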
@@ -5217,11 +5584,15 @@ next_insn:
 /* drop refcnt of maps used by the rejected program */
 static void release_maps(struct bpf_verifier_env *env)
 {
+       enum bpf_cgroup_storage_type stype;
        int i;
 
-       if (env->prog->aux->cgroup_storage)
+       for_each_cgroup_storage_type(stype) {
+               if (!env->prog->aux->cgroup_storage[stype])
+                       continue;
                bpf_cgroup_storage_release(env->prog,
-                                          env->prog->aux->cgroup_storage);
+                       env->prog->aux->cgroup_storage[stype]);
+       }
 
        for (i = 0; i < env->used_map_cnt; i++)
                bpf_map_put(env->used_maps[i]);
@@ -5319,8 +5690,10 @@ static void sanitize_dead_code(struct bpf_verifier_env *env)
        }
 }
 
-/* convert load instructions that access fields of 'struct __sk_buff'
- * into sequence of instructions that access fields of 'struct sk_buff'
+/* convert load instructions that access fields of a context type into a
+ * sequence of instructions that access fields of the underlying structure:
+ *     struct __sk_buff    -> struct sk_buff
+ *     struct bpf_sock_ops -> struct sock
  */
 static int convert_ctx_accesses(struct bpf_verifier_env *env)
 {
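
For illustration, a hedged sketch of a load that the new PTR_TO_SOCKET case rewrites: a field read through the pointer returned by bpf_sk_lookup_tcp() is converted by bpf_sock_convert_ctx_access() into a load from the underlying socket structure. The field choice, headers and section name are assumptions and depend on which struct bpf_sock fields the program type may read:

#include <linux/bpf.h>
#include "bpf_helpers.h"

SEC("classifier")
int read_sock_field(struct __sk_buff *skb)
{
	struct bpf_sock_tuple tuple = {};
	struct bpf_sock *sk;
	__u32 bound = 0;

	sk = bpf_sk_lookup_tcp(skb, &tuple, sizeof(tuple.ipv4), 0, 0);
	if (!sk)
		return 0;
	/* This load goes through a PTR_TO_SOCKET register; convert_ctx_accesses()
	 * rewrites it via bpf_sock_convert_ctx_access() rather than the program
	 * type's own convert_ctx_access callback.
	 */
	bound = sk->bound_dev_if;
	bpf_sk_release(sk);
	return bound;
}
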
@@ -5349,12 +5722,14 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
                }
        }
 
-       if (!ops->convert_ctx_access || bpf_prog_is_dev_bound(env->prog->aux))
+       if (bpf_prog_is_dev_bound(env->prog->aux))
                return 0;
 
        insn = env->prog->insnsi + delta;
 
        for (i = 0; i < insn_cnt; i++, insn++) {
+               bpf_convert_ctx_access_t convert_ctx_access;
+
                if (insn->code == (BPF_LDX | BPF_MEM | BPF_B) ||
                    insn->code == (BPF_LDX | BPF_MEM | BPF_H) ||
                    insn->code == (BPF_LDX | BPF_MEM | BPF_W) ||
@@ -5396,8 +5771,18 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
                        continue;
                }
 
-               if (env->insn_aux_data[i + delta].ptr_type != PTR_TO_CTX)
+               switch (env->insn_aux_data[i + delta].ptr_type) {
+               case PTR_TO_CTX:
+                       if (!ops->convert_ctx_access)
+                               continue;
+                       convert_ctx_access = ops->convert_ctx_access;
+                       break;
+               case PTR_TO_SOCKET:
+                       convert_ctx_access = bpf_sock_convert_ctx_access;
+                       break;
+               default:
                        continue;
+               }
 
                ctx_field_size = env->insn_aux_data[i + delta].ctx_field_size;
                size = BPF_LDST_BYTES(insn);
@@ -5429,8 +5814,8 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
                }
 
                target_size = 0;
-               cnt = ops->convert_ctx_access(type, insn, insn_buf, env->prog,
-                                             &target_size);
+               cnt = convert_ctx_access(type, insn, insn_buf, env->prog,
+                                        &target_size);
                if (cnt == 0 || cnt >= ARRAY_SIZE(insn_buf) ||
                    (ctx_field_size && !target_size)) {
                        verbose(env, "bpf verifier is misconfigured\n");
@@ -5621,10 +6006,10 @@ static int fixup_call_args(struct bpf_verifier_env *env)
        struct bpf_insn *insn = prog->insnsi;
        int i, depth;
 #endif
-       int err;
+       int err = 0;
 
-       err = 0;
-       if (env->prog->jit_requested) {
+       if (env->prog->jit_requested &&
+           !bpf_prog_is_dev_bound(env->prog->aux)) {
                err = jit_subprogs(env);
                if (err == 0)
                        return 0;
@@ -5962,6 +6347,9 @@ int bpf_check(struct bpf_prog **prog, union bpf_attr *attr)
                env->cur_state = NULL;
        }
 
+       if (ret == 0 && bpf_prog_is_dev_bound(env->prog->aux))
+               ret = bpf_prog_offload_finalize(env);
+
 skip_full_check:
        while (!pop_stack(env, NULL, NULL));
        free_states(env);