OSDN Git Service

skmsg: Extract __tcp_bpf_recvmsg() and tcp_bpf_wait_data()
author Cong Wang <cong.wang@bytedance.com>
Wed, 31 Mar 2021 02:32:33 +0000 (19:32 -0700)
committer Alexei Starovoitov <ast@kernel.org>
Thu, 1 Apr 2021 17:56:14 +0000 (10:56 -0700)
Although these two functions are only used by TCP, they are not
specific to TCP at all: both operate on skmsg and ingress_msg,
so they fit in net/core/skmsg.c very well.

We will also need them for non-TCP protocols, so rename them, move
them to skmsg.c, and export them to modules.

Signed-off-by: Cong Wang <cong.wang@bytedance.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Link: https://lore.kernel.org/bpf/20210331023237.41094-13-xiyou.wangcong@gmail.com
include/linux/skmsg.h
include/net/tcp.h
net/core/skmsg.c
net/ipv4/tcp_bpf.c
net/tls/tls_sw.c

index 5e800dd..f78e90a 100644 (file)
@@ -125,6 +125,10 @@ int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
                              struct sk_msg *msg, u32 bytes);
 int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
                             struct sk_msg *msg, u32 bytes);
+int sk_msg_wait_data(struct sock *sk, struct sk_psock *psock, int flags,
+                    long timeo, int *err);
+int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
+                  int len, int flags);
 
 static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
 {
index 2efa4e5..31b1696 100644 (file)
@@ -2209,8 +2209,6 @@ void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
 
 int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes,
                          int flags);
-int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
-                     struct msghdr *msg, int len, int flags);
 #endif /* CONFIG_NET_SOCK_MSG */
 
 #if !defined(CONFIG_BPF_SYSCALL) || !defined(CONFIG_NET_SOCK_MSG)
index 9fc83f7..92a83c0 100644 (file)
@@ -399,6 +399,104 @@ out:
 }
 EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
 
+int sk_msg_wait_data(struct sock *sk, struct sk_psock *psock, int flags,
+                    long timeo, int *err)
+{
+       DEFINE_WAIT_FUNC(wait, woken_wake_function);
+       int ret = 0;
+
+       if (sk->sk_shutdown & RCV_SHUTDOWN)
+               return 1;
+
+       if (!timeo)
+               return ret;
+
+       add_wait_queue(sk_sleep(sk), &wait);
+       sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
+       ret = sk_wait_event(sk, &timeo,
+                           !list_empty(&psock->ingress_msg) ||
+                           !skb_queue_empty(&sk->sk_receive_queue), &wait);
+       sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
+       remove_wait_queue(sk_sleep(sk), &wait);
+       return ret;
+}
+EXPORT_SYMBOL_GPL(sk_msg_wait_data);
+
+/* Receive sk_msg from psock->ingress_msg to @msg. */
+int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
+                  int len, int flags)
+{
+       struct iov_iter *iter = &msg->msg_iter;
+       int peek = flags & MSG_PEEK;
+       struct sk_msg *msg_rx;
+       int i, copied = 0;
+
+       msg_rx = sk_psock_peek_msg(psock);
+       while (copied != len) {
+               struct scatterlist *sge;
+
+               if (unlikely(!msg_rx))
+                       break;
+
+               i = msg_rx->sg.start;
+               do {
+                       struct page *page;
+                       int copy;
+
+                       sge = sk_msg_elem(msg_rx, i);
+                       copy = sge->length;
+                       page = sg_page(sge);
+                       if (copied + copy > len)
+                               copy = len - copied;
+                       copy = copy_page_to_iter(page, sge->offset, copy, iter);
+                       if (!copy)
+                               return copied ? copied : -EFAULT;
+
+                       copied += copy;
+                       if (likely(!peek)) {
+                               sge->offset += copy;
+                               sge->length -= copy;
+                               if (!msg_rx->skb)
+                                       sk_mem_uncharge(sk, copy);
+                               msg_rx->sg.size -= copy;
+
+                               if (!sge->length) {
+                                       sk_msg_iter_var_next(i);
+                                       if (!msg_rx->skb)
+                                               put_page(page);
+                               }
+                       } else {
+                               /* Lets not optimize peek case if copy_page_to_iter
+                                * didn't copy the entire length lets just break.
+                                */
+                               if (copy != sge->length)
+                                       return copied;
+                               sk_msg_iter_var_next(i);
+                       }
+
+                       if (copied == len)
+                               break;
+               } while (i != msg_rx->sg.end);
+
+               if (unlikely(peek)) {
+                       msg_rx = sk_psock_next_msg(psock, msg_rx);
+                       if (!msg_rx)
+                               break;
+                       continue;
+               }
+
+               msg_rx->sg.start = i;
+               if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
+                       msg_rx = sk_psock_dequeue_msg(psock);
+                       kfree_sk_msg(msg_rx);
+               }
+               msg_rx = sk_psock_peek_msg(psock);
+       }
+
+       return copied;
+}
+EXPORT_SYMBOL_GPL(sk_msg_recvmsg);
+
 static struct sk_msg *sk_psock_create_ingress_msg(struct sock *sk,
                                                  struct sk_buff *skb)
 {
index ac8cfba..3d622a0 100644 (file)
 #include <net/inet_common.h>
 #include <net/tls.h>
 
-int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
-                     struct msghdr *msg, int len, int flags)
-{
-       struct iov_iter *iter = &msg->msg_iter;
-       int peek = flags & MSG_PEEK;
-       struct sk_msg *msg_rx;
-       int i, copied = 0;
-
-       msg_rx = sk_psock_peek_msg(psock);
-       while (copied != len) {
-               struct scatterlist *sge;
-
-               if (unlikely(!msg_rx))
-                       break;
-
-               i = msg_rx->sg.start;
-               do {
-                       struct page *page;
-                       int copy;
-
-                       sge = sk_msg_elem(msg_rx, i);
-                       copy = sge->length;
-                       page = sg_page(sge);
-                       if (copied + copy > len)
-                               copy = len - copied;
-                       copy = copy_page_to_iter(page, sge->offset, copy, iter);
-                       if (!copy)
-                               return copied ? copied : -EFAULT;
-
-                       copied += copy;
-                       if (likely(!peek)) {
-                               sge->offset += copy;
-                               sge->length -= copy;
-                               if (!msg_rx->skb)
-                                       sk_mem_uncharge(sk, copy);
-                               msg_rx->sg.size -= copy;
-
-                               if (!sge->length) {
-                                       sk_msg_iter_var_next(i);
-                                       if (!msg_rx->skb)
-                                               put_page(page);
-                               }
-                       } else {
-                               /* Lets not optimize peek case if copy_page_to_iter
-                                * didn't copy the entire length lets just break.
-                                */
-                               if (copy != sge->length)
-                                       return copied;
-                               sk_msg_iter_var_next(i);
-                       }
-
-                       if (copied == len)
-                               break;
-               } while (i != msg_rx->sg.end);
-
-               if (unlikely(peek)) {
-                       msg_rx = sk_psock_next_msg(psock, msg_rx);
-                       if (!msg_rx)
-                               break;
-                       continue;
-               }
-
-               msg_rx->sg.start = i;
-               if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
-                       msg_rx = sk_psock_dequeue_msg(psock);
-                       kfree_sk_msg(msg_rx);
-               }
-               msg_rx = sk_psock_peek_msg(psock);
-       }
-
-       return copied;
-}
-EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);
-
 static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
                           struct sk_msg *msg, u32 apply_bytes, int flags)
 {
@@ -237,28 +163,6 @@ static bool tcp_bpf_stream_read(const struct sock *sk)
        return !empty;
 }
 
-static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
-                            int flags, long timeo, int *err)
-{
-       DEFINE_WAIT_FUNC(wait, woken_wake_function);
-       int ret = 0;
-
-       if (sk->sk_shutdown & RCV_SHUTDOWN)
-               return 1;
-
-       if (!timeo)
-               return ret;
-
-       add_wait_queue(sk_sleep(sk), &wait);
-       sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
-       ret = sk_wait_event(sk, &timeo,
-                           !list_empty(&psock->ingress_msg) ||
-                           !skb_queue_empty(&sk->sk_receive_queue), &wait);
-       sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
-       remove_wait_queue(sk_sleep(sk), &wait);
-       return ret;
-}
-
 static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
                    int nonblock, int flags, int *addr_len)
 {
@@ -278,13 +182,13 @@ static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
        }
        lock_sock(sk);
 msg_bytes_ready:
-       copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
+       copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
        if (!copied) {
                int data, err = 0;
                long timeo;
 
                timeo = sock_rcvtimeo(sk, nonblock);
-               data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
+               data = sk_msg_wait_data(sk, psock, flags, timeo, &err);
                if (data) {
                        if (!sk_psock_queue_empty(psock))
                                goto msg_bytes_ready;
index 01d933a..1dcb34d 100644 (file)
@@ -1789,8 +1789,8 @@ int tls_sw_recvmsg(struct sock *sk,
                skb = tls_wait_data(sk, psock, flags, timeo, &err);
                if (!skb) {
                        if (psock) {
-                               int ret = __tcp_bpf_recvmsg(sk, psock,
-                                                           msg, len, flags);
+                               int ret = sk_msg_recvmsg(sk, psock, msg, len,
+                                                        flags);
 
                                if (ret > 0) {
                                        decrypted += ret;