OSDN Git Service

bpf: Enable BPF_PROG_TYPE_SK_REUSEPORT bpf prog in reuseport selection
author Martin KaFai Lau <kafai@fb.com>
Wed, 8 Aug 2018 08:01:26 +0000 (01:01 -0700)
committer Daniel Borkmann <daniel@iogearbox.net>
Fri, 10 Aug 2018 23:58:46 +0000 (01:58 +0200)
This patch allows a BPF_PROG_TYPE_SK_REUSEPORT bpf prog to select a
SO_REUSEPORT sk from a BPF_MAP_TYPE_REUSEPORT_ARRAY introduced in
the earlier patch.  "bpf_run_sk_reuseport()" will return -ECONNREFUSED
when the BPF_PROG_TYPE_SK_REUSEPORT prog returns SK_DROP.
The callers, in inet[6]_hashtable.c and ipv[46]/udp.c, are modified to
handle this case and return NULL immediately instead of continuing the
sk search from its hashtable.

It re-uses the existing SO_ATTACH_REUSEPORT_EBPF setsockopt to attach
BPF_PROG_TYPE_SK_REUSEPORT.  The "sk_reuseport_attach_bpf()" will check
if the attaching bpf prog is in the new SK_REUSEPORT or the existing
SOCKET_FILTER type and then check different things accordingly.

One level of "__reuseport_attach_prog()" call is removed.  The
"sk_unhashed() && ..." and "sk->sk_reuseport_cb" tests are pushed
back to "reuseport_attach_prog()" in sock_reuseport.c.  sock_reuseport.c
seems to have more knowledge on those test requirements than filter.c.
In "reuseport_attach_prog()", after new_prog is attached to reuse->prog,
the old_prog (if any) is also directly freed instead of returning the
old_prog to the caller and asking the caller to free.

The sysctl_optmem_max check is moved back to the
"sk_reuseport_attach_filter()" and "sk_reuseport_attach_bpf()".
As with other bpf prog types, the new BPF_PROG_TYPE_SK_REUSEPORT is only
bounded by the usual "bpf_prog_charge_memlock()" during load time
instead of being bounded by both "bpf_prog_charge_memlock()" and
sysctl_optmem_max.

Signed-off-by: Martin KaFai Lau <kafai@fb.com>
Acked-by: Alexei Starovoitov <ast@kernel.org>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
include/linux/filter.h
include/net/sock_reuseport.h
net/core/filter.c
net/core/sock_reuseport.c
net/ipv4/inet_hashtables.c
net/ipv4/udp.c
net/ipv6/inet6_hashtables.c
net/ipv6/udp.c

index 70e9d57..5d565c5 100644 (file)
@@ -753,6 +753,7 @@ int sk_attach_filter(struct sock_fprog *fprog, struct sock *sk);
 int sk_attach_bpf(u32 ufd, struct sock *sk);
 int sk_reuseport_attach_filter(struct sock_fprog *fprog, struct sock *sk);
 int sk_reuseport_attach_bpf(u32 ufd, struct sock *sk);
+void sk_reuseport_prog_free(struct bpf_prog *prog);
 int sk_detach_filter(struct sock *sk);
 int sk_get_filter(struct sock *sk, struct sock_filter __user *filter,
                  unsigned int len);
index 73b5695..8a5f70c 100644 (file)
@@ -34,8 +34,7 @@ extern struct sock *reuseport_select_sock(struct sock *sk,
                                          u32 hash,
                                          struct sk_buff *skb,
                                          int hdr_len);
-extern struct bpf_prog *reuseport_attach_prog(struct sock *sk,
-                                             struct bpf_prog *prog);
+extern int reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog);
 int reuseport_get_id(struct sock_reuseport *reuse);
 
 #endif  /* _SOCK_REUSEPORT_H */
index 142595b..22906b3 100644 (file)
@@ -1453,30 +1453,6 @@ static int __sk_attach_prog(struct bpf_prog *prog, struct sock *sk)
        return 0;
 }
 
-static int __reuseport_attach_prog(struct bpf_prog *prog, struct sock *sk)
-{
-       struct bpf_prog *old_prog;
-       int err;
-
-       if (bpf_prog_size(prog->len) > sysctl_optmem_max)
-               return -ENOMEM;
-
-       if (sk_unhashed(sk) && sk->sk_reuseport) {
-               err = reuseport_alloc(sk, false);
-               if (err)
-                       return err;
-       } else if (!rcu_access_pointer(sk->sk_reuseport_cb)) {
-               /* The socket wasn't bound with SO_REUSEPORT */
-               return -EINVAL;
-       }
-
-       old_prog = reuseport_attach_prog(sk, prog);
-       if (old_prog)
-               bpf_prog_destroy(old_prog);
-
-       return 0;
-}
-
 static
 struct bpf_prog *__get_filter(struct sock_fprog *fprog, struct sock *sk)
 {
@@ -1550,13 +1526,15 @@ int sk_reuseport_attach_filter(struct sock_fprog *fprog, struct sock *sk)
        if (IS_ERR(prog))
                return PTR_ERR(prog);
 
-       err = __reuseport_attach_prog(prog, sk);
-       if (err < 0) {
+       if (bpf_prog_size(prog->len) > sysctl_optmem_max)
+               err = -ENOMEM;
+       else
+               err = reuseport_attach_prog(sk, prog);
+
+       if (err)
                __bpf_prog_release(prog);
-               return err;
-       }
 
-       return 0;
+       return err;
 }
 
 static struct bpf_prog *__get_bpf(u32 ufd, struct sock *sk)
@@ -1586,19 +1564,58 @@ int sk_attach_bpf(u32 ufd, struct sock *sk)
 
 int sk_reuseport_attach_bpf(u32 ufd, struct sock *sk)
 {
-       struct bpf_prog *prog = __get_bpf(ufd, sk);
+       struct bpf_prog *prog;
        int err;
 
+       if (sock_flag(sk, SOCK_FILTER_LOCKED))
+               return -EPERM;
+
+       prog = bpf_prog_get_type(ufd, BPF_PROG_TYPE_SOCKET_FILTER);
+       if (IS_ERR(prog) && PTR_ERR(prog) == -EINVAL)
+               prog = bpf_prog_get_type(ufd, BPF_PROG_TYPE_SK_REUSEPORT);
        if (IS_ERR(prog))
                return PTR_ERR(prog);
 
-       err = __reuseport_attach_prog(prog, sk);
-       if (err < 0) {
-               bpf_prog_put(prog);
-               return err;
+       if (prog->type == BPF_PROG_TYPE_SK_REUSEPORT) {
+               /* Like other non BPF_PROG_TYPE_SOCKET_FILTER
+                * bpf prog (e.g. sockmap).  It depends on the
+                * limitation imposed by bpf_prog_load().
+                * Hence, sysctl_optmem_max is not checked.
+                */
+               if ((sk->sk_type != SOCK_STREAM &&
+                    sk->sk_type != SOCK_DGRAM) ||
+                   (sk->sk_protocol != IPPROTO_UDP &&
+                    sk->sk_protocol != IPPROTO_TCP) ||
+                   (sk->sk_family != AF_INET &&
+                    sk->sk_family != AF_INET6)) {
+                       err = -ENOTSUPP;
+                       goto err_prog_put;
+               }
+       } else {
+               /* BPF_PROG_TYPE_SOCKET_FILTER */
+               if (bpf_prog_size(prog->len) > sysctl_optmem_max) {
+                       err = -ENOMEM;
+                       goto err_prog_put;
+               }
        }
 
-       return 0;
+       err = reuseport_attach_prog(sk, prog);
+err_prog_put:
+       if (err)
+               bpf_prog_put(prog);
+
+       return err;
+}
+
+void sk_reuseport_prog_free(struct bpf_prog *prog)
+{
+       if (!prog)
+               return;
+
+       if (prog->type == BPF_PROG_TYPE_SK_REUSEPORT)
+               bpf_prog_put(prog);
+       else
+               bpf_prog_destroy(prog);
 }
 
 struct bpf_scratchpad {
index d260167..ba5cba5 100644 (file)
@@ -9,6 +9,7 @@
 #include <net/sock_reuseport.h>
 #include <linux/bpf.h>
 #include <linux/idr.h>
+#include <linux/filter.h>
 #include <linux/rcupdate.h>
 
 #define INIT_SOCKS 128
@@ -133,8 +134,7 @@ static void reuseport_free_rcu(struct rcu_head *head)
        struct sock_reuseport *reuse;
 
        reuse = container_of(head, struct sock_reuseport, rcu);
-       if (reuse->prog)
-               bpf_prog_destroy(reuse->prog);
+       sk_reuseport_prog_free(rcu_dereference_protected(reuse->prog, 1));
        if (reuse->reuseport_id)
                ida_simple_remove(&reuseport_ida, reuse->reuseport_id);
        kfree(reuse);
@@ -219,9 +219,9 @@ void reuseport_detach_sock(struct sock *sk)
 }
 EXPORT_SYMBOL(reuseport_detach_sock);
 
-static struct sock *run_bpf(struct sock_reuseport *reuse, u16 socks,
-                           struct bpf_prog *prog, struct sk_buff *skb,
-                           int hdr_len)
+static struct sock *run_bpf_filter(struct sock_reuseport *reuse, u16 socks,
+                                  struct bpf_prog *prog, struct sk_buff *skb,
+                                  int hdr_len)
 {
        struct sk_buff *nskb = NULL;
        u32 index;
@@ -282,9 +282,15 @@ struct sock *reuseport_select_sock(struct sock *sk,
                /* paired with smp_wmb() in reuseport_add_sock() */
                smp_rmb();
 
-               if (prog && skb)
-                       sk2 = run_bpf(reuse, socks, prog, skb, hdr_len);
+               if (!prog || !skb)
+                       goto select_by_hash;
+
+               if (prog->type == BPF_PROG_TYPE_SK_REUSEPORT)
+                       sk2 = bpf_run_sk_reuseport(reuse, sk, prog, skb, hash);
+               else
+                       sk2 = run_bpf_filter(reuse, socks, prog, skb, hdr_len);
 
+select_by_hash:
                /* no bpf or invalid bpf result: fall back to hash usage */
                if (!sk2)
                        sk2 = reuse->socks[reciprocal_scale(hash, socks)];
@@ -296,12 +302,21 @@ out:
 }
 EXPORT_SYMBOL(reuseport_select_sock);
 
-struct bpf_prog *
-reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog)
+int reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog)
 {
        struct sock_reuseport *reuse;
        struct bpf_prog *old_prog;
 
+       if (sk_unhashed(sk) && sk->sk_reuseport) {
+               int err = reuseport_alloc(sk, false);
+
+               if (err)
+                       return err;
+       } else if (!rcu_access_pointer(sk->sk_reuseport_cb)) {
+               /* The socket wasn't bound with SO_REUSEPORT */
+               return -EINVAL;
+       }
+
        spin_lock_bh(&reuseport_lock);
        reuse = rcu_dereference_protected(sk->sk_reuseport_cb,
                                          lockdep_is_held(&reuseport_lock));
@@ -310,6 +325,7 @@ reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog)
        rcu_assign_pointer(reuse->prog, prog);
        spin_unlock_bh(&reuseport_lock);
 
-       return old_prog;
+       sk_reuseport_prog_free(old_prog);
+       return 0;
 }
 EXPORT_SYMBOL(reuseport_attach_prog);
index 370e244..f5c9ef2 100644 (file)
@@ -328,7 +328,7 @@ struct sock *__inet_lookup_listener(struct net *net,
                                    saddr, sport, daddr, hnum,
                                    dif, sdif);
        if (result)
-               return result;
+               goto done;
 
        /* Lookup lhash2 with INADDR_ANY */
 
@@ -337,9 +337,10 @@ struct sock *__inet_lookup_listener(struct net *net,
        if (ilb2->count > ilb->count)
                goto port_lookup;
 
-       return inet_lhash2_lookup(net, ilb2, skb, doff,
-                                 saddr, sport, daddr, hnum,
-                                 dif, sdif);
+       result = inet_lhash2_lookup(net, ilb2, skb, doff,
+                                   saddr, sport, daddr, hnum,
+                                   dif, sdif);
+       goto done;
 
 port_lookup:
        sk_for_each_rcu(sk, &ilb->head) {
@@ -352,12 +353,15 @@ port_lookup:
                                result = reuseport_select_sock(sk, phash,
                                                               skb, doff);
                                if (result)
-                                       return result;
+                                       goto done;
                        }
                        result = sk;
                        hiscore = score;
                }
        }
+done:
+       if (unlikely(IS_ERR(result)))
+               return NULL;
        return result;
 }
 EXPORT_SYMBOL_GPL(__inet_lookup_listener);
index 038dd79..f4e35b2 100644 (file)
@@ -499,6 +499,8 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
                                                  daddr, hnum, dif, sdif,
                                                  exact_dif, hslot2, skb);
                }
+               if (unlikely(IS_ERR(result)))
+                       return NULL;
                return result;
        }
 begin:
@@ -513,6 +515,8 @@ begin:
                                                   saddr, sport);
                                result = reuseport_select_sock(sk, hash, skb,
                                                        sizeof(struct udphdr));
+                               if (unlikely(IS_ERR(result)))
+                                       return NULL;
                                if (result)
                                        return result;
                        }
index 595ad40..3d7c746 100644 (file)
@@ -191,7 +191,7 @@ struct sock *inet6_lookup_listener(struct net *net,
                                     saddr, sport, daddr, hnum,
                                     dif, sdif);
        if (result)
-               return result;
+               goto done;
 
        /* Lookup lhash2 with in6addr_any */
 
@@ -200,9 +200,10 @@ struct sock *inet6_lookup_listener(struct net *net,
        if (ilb2->count > ilb->count)
                goto port_lookup;
 
-       return inet6_lhash2_lookup(net, ilb2, skb, doff,
-                                  saddr, sport, daddr, hnum,
-                                  dif, sdif);
+       result = inet6_lhash2_lookup(net, ilb2, skb, doff,
+                                    saddr, sport, daddr, hnum,
+                                    dif, sdif);
+       goto done;
 
 port_lookup:
        sk_for_each(sk, &ilb->head) {
@@ -214,12 +215,15 @@ port_lookup:
                                result = reuseport_select_sock(sk, phash,
                                                               skb, doff);
                                if (result)
-                                       return result;
+                                       goto done;
                        }
                        result = sk;
                        hiscore = score;
                }
        }
+done:
+       if (unlikely(IS_ERR(result)))
+               return NULL;
        return result;
 }
 EXPORT_SYMBOL_GPL(inet6_lookup_listener);
index f6b9695..83f4c77 100644 (file)
@@ -235,6 +235,8 @@ struct sock *__udp6_lib_lookup(struct net *net,
                                                  exact_dif, hslot2,
                                                  skb);
                }
+               if (unlikely(IS_ERR(result)))
+                       return NULL;
                return result;
        }
 begin:
@@ -249,6 +251,8 @@ begin:
                                                    saddr, sport);
                                result = reuseport_select_sock(sk, hash, skb,
                                                        sizeof(struct udphdr));
+                               if (unlikely(IS_ERR(result)))
+                                       return NULL;
                                if (result)
                                        return result;
                        }