Refactor the UDP/TCP handlers slightly to allow skb_steal_sock() to make
the determination of whether the socket is reference counted in the case
where it is prefetched by earlier logic such as early_demux.
Signed-off-by: Joe Stringer <joe@wand.net.nz>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Acked-by: Martin KaFai Lau <kafai@fb.com>
Link: https://lore.kernel.org/bpf/20200329225342.16317-3-joe@wand.net.nz
int iif, int sdif,
bool *refcounted)
{
- struct sock *sk = skb_steal_sock(skb);
+ struct sock *sk = skb_steal_sock(skb, refcounted);
- *refcounted = true;
if (sk)
return sk;
const int sdif,
bool *refcounted)
{
- struct sock *sk = skb_steal_sock(skb);
+ struct sock *sk = skb_steal_sock(skb, refcounted);
const struct iphdr *iph = ip_hdr(skb);
- *refcounted = true;
if (sk)
return sk;
#endif /* CONFIG_INET */
}
-static inline struct sock *skb_steal_sock(struct sk_buff *skb)
+/**
+ * skb_steal_sock
+ * @skb to steal the socket from
+ * @refcounted is set to true if the socket is reference-counted
+ */
+static inline struct sock *
+skb_steal_sock(struct sk_buff *skb, bool *refcounted)
{
if (skb->sk) {
struct sock *sk = skb->sk;
+ *refcounted = true;
skb->destructor = NULL;
skb->sk = NULL;
return sk;
}
+ *refcounted = false;
return NULL;
}
struct rtable *rt = skb_rtable(skb);
__be32 saddr, daddr;
struct net *net = dev_net(skb->dev);
+ bool refcounted;
/*
* Validate the packet.
if (udp4_csum_init(skb, uh, proto))
goto csum_error;
- sk = skb_steal_sock(skb);
+ sk = skb_steal_sock(skb, &refcounted);
if (sk) {
struct dst_entry *dst = skb_dst(skb);
int ret;
udp_sk_rx_dst_set(sk, dst);
ret = udp_unicast_rcv_skb(sk, skb, uh);
- sock_put(sk);
+ if (refcounted)
+ sock_put(sk);
return ret;
}
struct net *net = dev_net(skb->dev);
struct udphdr *uh;
struct sock *sk;
+ bool refcounted;
u32 ulen = 0;
if (!pskb_may_pull(skb, sizeof(struct udphdr)))
goto csum_error;
/* Check if the socket is already available, e.g. due to early demux */
- sk = skb_steal_sock(skb);
+ sk = skb_steal_sock(skb, &refcounted);
if (sk) {
struct dst_entry *dst = skb_dst(skb);
int ret;
udp6_sk_rx_dst_set(sk, dst);
if (!uh->check && !udp_sk(sk)->no_check6_rx) {
- sock_put(sk);
+ if (refcounted)
+ sock_put(sk);
goto report_csum_error;
}
ret = udp6_unicast_rcv_skb(sk, skb, uh);
- sock_put(sk);
+ if (refcounted)
+ sock_put(sk);
return ret;
}