OSDN Git Service

net: Track socket refcounts in skb_steal_sock()
authorJoe Stringer <joe@wand.net.nz>
Sun, 29 Mar 2020 22:53:39 +0000 (15:53 -0700)
committerAlexei Starovoitov <ast@kernel.org>
Mon, 30 Mar 2020 20:45:04 +0000 (13:45 -0700)
Refactor the UDP/TCP handlers slightly to allow skb_steal_sock() to make
the determination of whether the socket is reference counted in the case
where it is prefetched by earlier logic such as early_demux.

Signed-off-by: Joe Stringer <joe@wand.net.nz>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Acked-by: Martin KaFai Lau <kafai@fb.com>
Link: https://lore.kernel.org/bpf/20200329225342.16317-3-joe@wand.net.nz
include/net/inet6_hashtables.h
include/net/inet_hashtables.h
include/net/sock.h
net/ipv4/udp.c
net/ipv6/udp.c

index fe96bf2..81b9659 100644 (file)
@@ -85,9 +85,8 @@ static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
                                              int iif, int sdif,
                                              bool *refcounted)
 {
-       struct sock *sk = skb_steal_sock(skb);
+       struct sock *sk = skb_steal_sock(skb, refcounted);
 
-       *refcounted = true;
        if (sk)
                return sk;
 
index d0019d3..ad64ba6 100644 (file)
@@ -379,10 +379,9 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
                                             const int sdif,
                                             bool *refcounted)
 {
-       struct sock *sk = skb_steal_sock(skb);
+       struct sock *sk = skb_steal_sock(skb, refcounted);
        const struct iphdr *iph = ip_hdr(skb);
 
-       *refcounted = true;
        if (sk)
                return sk;
 
index dc398ce..f81d528 100644 (file)
@@ -2537,15 +2537,23 @@ skb_sk_is_prefetched(struct sk_buff *skb)
 #endif /* CONFIG_INET */
 }
 
-static inline struct sock *skb_steal_sock(struct sk_buff *skb)
+/**
+ * skb_steal_sock
+ * @skb to steal the socket from
+ * @refcounted is set to true if the socket is reference-counted
+ */
+static inline struct sock *
+skb_steal_sock(struct sk_buff *skb, bool *refcounted)
 {
        if (skb->sk) {
                struct sock *sk = skb->sk;
 
+               *refcounted = true;
                skb->destructor = NULL;
                skb->sk = NULL;
                return sk;
        }
+       *refcounted = false;
        return NULL;
 }
 
index 2633fc2..b403502 100644 (file)
@@ -2288,6 +2288,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
        struct rtable *rt = skb_rtable(skb);
        __be32 saddr, daddr;
        struct net *net = dev_net(skb->dev);
+       bool refcounted;
 
        /*
         *  Validate the packet.
@@ -2313,7 +2314,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
        if (udp4_csum_init(skb, uh, proto))
                goto csum_error;
 
-       sk = skb_steal_sock(skb);
+       sk = skb_steal_sock(skb, &refcounted);
        if (sk) {
                struct dst_entry *dst = skb_dst(skb);
                int ret;
@@ -2322,7 +2323,8 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
                        udp_sk_rx_dst_set(sk, dst);
 
                ret = udp_unicast_rcv_skb(sk, skb, uh);
-               sock_put(sk);
+               if (refcounted)
+                       sock_put(sk);
                return ret;
        }
 
index 5dc439a..7d41517 100644 (file)
@@ -843,6 +843,7 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
        struct net *net = dev_net(skb->dev);
        struct udphdr *uh;
        struct sock *sk;
+       bool refcounted;
        u32 ulen = 0;
 
        if (!pskb_may_pull(skb, sizeof(struct udphdr)))
@@ -879,7 +880,7 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
                goto csum_error;
 
        /* Check if the socket is already available, e.g. due to early demux */
-       sk = skb_steal_sock(skb);
+       sk = skb_steal_sock(skb, &refcounted);
        if (sk) {
                struct dst_entry *dst = skb_dst(skb);
                int ret;
@@ -888,12 +889,14 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
                        udp6_sk_rx_dst_set(sk, dst);
 
                if (!uh->check && !udp_sk(sk)->no_check6_rx) {
-                       sock_put(sk);
+                       if (refcounted)
+                               sock_put(sk);
                        goto report_csum_error;
                }
 
                ret = udp6_unicast_rcv_skb(sk, skb, uh);
-               sock_put(sk);
+               if (refcounted)
+                       sock_put(sk);
                return ret;
        }