OSDN Git Service

bpf: tcp: Move assertions into tcp_bpf_get_proto
authorLorenz Bauer <lmb@cloudflare.com>
Mon, 9 Mar 2020 11:12:34 +0000 (11:12 +0000)
committerDaniel Borkmann <daniel@iogearbox.net>
Mon, 9 Mar 2020 21:34:58 +0000 (22:34 +0100)
We need to ensure that sk->sk_prot uses certain callbacks, so that
code that directly calls e.g. tcp_sendmsg in certain corner cases
works. To avoid spurious asserts, we must to do this only if
sk_psock_update_proto has not yet been called. The same invariants
apply for tcp_bpf_check_v6_needs_rebuild, so move the call as well.

Doing so allows us to merge tcp_bpf_init and tcp_bpf_reinit.

Signed-off-by: Lorenz Bauer <lmb@cloudflare.com>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
Reviewed-by: Jakub Sitnicki <jakub@cloudflare.com>
Acked-by: John Fastabend <john.fastabend@gmail.com>
Link: https://lore.kernel.org/bpf/20200309111243.6982-4-lmb@cloudflare.com
include/net/tcp.h
net/core/sock_map.c
net/ipv4/tcp_bpf.c

index 07f947c..ccf39d8 100644 (file)
@@ -2196,7 +2196,6 @@ struct sk_msg;
 struct sk_psock;
 
 int tcp_bpf_init(struct sock *sk);
-void tcp_bpf_reinit(struct sock *sk);
 int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes,
                          int flags);
 int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
index cb8f740..fafcbd2 100644 (file)
@@ -145,8 +145,8 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
                         struct sock *sk)
 {
        struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
-       bool skb_progs, sk_psock_is_new = false;
        struct sk_psock *psock;
+       bool skb_progs;
        int ret;
 
        skb_verdict = READ_ONCE(progs->skb_verdict);
@@ -191,18 +191,14 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
                        ret = -ENOMEM;
                        goto out_progs;
                }
-               sk_psock_is_new = true;
        }
 
        if (msg_parser)
                psock_set_prog(&psock->progs.msg_parser, msg_parser);
-       if (sk_psock_is_new) {
-               ret = tcp_bpf_init(sk);
-               if (ret < 0)
-                       goto out_drop;
-       } else {
-               tcp_bpf_reinit(sk);
-       }
+
+       ret = tcp_bpf_init(sk);
+       if (ret < 0)
+               goto out_drop;
 
        write_lock_bh(&sk->sk_callback_lock);
        if (skb_progs && !psock->parser.enabled) {
@@ -239,15 +235,12 @@ static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
        if (IS_ERR(psock))
                return PTR_ERR(psock);
 
-       if (psock) {
-               tcp_bpf_reinit(sk);
-               return 0;
+       if (!psock) {
+               psock = sk_psock_init(sk, map->numa_node);
+               if (!psock)
+                       return -ENOMEM;
        }
 
-       psock = sk_psock_init(sk, map->numa_node);
-       if (!psock)
-               return -ENOMEM;
-
        ret = tcp_bpf_init(sk);
        if (ret < 0)
                sk_psock_put(sk, psock);
index 3327afa..ed8a8f3 100644 (file)
@@ -629,14 +629,6 @@ static int __init tcp_bpf_v4_build_proto(void)
 }
 core_initcall(tcp_bpf_v4_build_proto);
 
-static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
-{
-       int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
-       int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
-
-       sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
-}
-
 static int tcp_bpf_assert_proto_ops(struct proto *ops)
 {
        /* In order to avoid retpoline, we make assumptions when we call
@@ -648,34 +640,44 @@ static int tcp_bpf_assert_proto_ops(struct proto *ops)
               ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
 }
 
-void tcp_bpf_reinit(struct sock *sk)
+static struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
 {
-       struct sk_psock *psock;
+       int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
+       int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
 
-       sock_owned_by_me(sk);
+       if (!psock->sk_proto) {
+               struct proto *ops = READ_ONCE(sk->sk_prot);
 
-       rcu_read_lock();
-       psock = sk_psock(sk);
-       tcp_bpf_update_sk_prot(sk, psock);
-       rcu_read_unlock();
+               if (tcp_bpf_assert_proto_ops(ops))
+                       return ERR_PTR(-EINVAL);
+
+               tcp_bpf_check_v6_needs_rebuild(sk, ops);
+       }
+
+       return &tcp_bpf_prots[family][config];
 }
 
 int tcp_bpf_init(struct sock *sk)
 {
-       struct proto *ops = READ_ONCE(sk->sk_prot);
        struct sk_psock *psock;
+       struct proto *prot;
 
        sock_owned_by_me(sk);
 
        rcu_read_lock();
        psock = sk_psock(sk);
-       if (unlikely(!psock || psock->sk_proto ||
-                    tcp_bpf_assert_proto_ops(ops))) {
+       if (unlikely(!psock)) {
                rcu_read_unlock();
                return -EINVAL;
        }
-       tcp_bpf_check_v6_needs_rebuild(sk, ops);
-       tcp_bpf_update_sk_prot(sk, psock);
+
+       prot = tcp_bpf_get_proto(sk, psock);
+       if (IS_ERR(prot)) {
+               rcu_read_unlock();
+               return PTR_ERR(prot);
+       }
+
+       sk_psock_update_proto(sk, psock, prot);
        rcu_read_unlock();
        return 0;
 }