OSDN Git Service

io_uring: add IOSQE_BUFFER_SELECT support for IORING_OP_RECVMSG
author: Jens Axboe <axboe@kernel.dk>
Thu, 27 Feb 2020 17:15:42 +0000 (10:15 -0700)
committer: Jens Axboe <axboe@kernel.dk>
Tue, 10 Mar 2020 15:12:51 +0000 (09:12 -0600)
Like IORING_OP_READV, this is limited to supporting just a single
segment in the iovec passed in.

Signed-off-by: Jens Axboe <axboe@kernel.dk>
fs/io_uring.c

index 7c855a0..455d53f 100644 (file)
@@ -44,6 +44,7 @@
 #include <linux/errno.h>
 #include <linux/syscalls.h>
 #include <linux/compat.h>
+#include <net/compat.h>
 #include <linux/refcount.h>
 #include <linux/uio.h>
 #include <linux/bits.h>
@@ -729,6 +730,7 @@ static const struct io_op_def io_op_defs[] = {
                .unbound_nonreg_file    = 1,
                .needs_fs               = 1,
                .pollin                 = 1,
+               .buffer_select          = 1,
        },
        [IORING_OP_TIMEOUT] = {
                .async_ctx              = 1,
@@ -3569,6 +3571,92 @@ static int io_send(struct io_kiocb *req, bool force_nonblock)
 #endif
 }
 
+/*
+ * Copy the recvmsg msghdr from user space for the native ABI.
+ *
+ * With IOSQE_BUFFER_SELECT only a single iovec segment is supported
+ * (mirroring IORING_OP_READV, per the commit message): the one segment
+ * is copied into io->msg.iov (which the caller pointed at fast_iov),
+ * its length recorded in sr->len, and the msg iterator initialized over
+ * it.  Otherwise the full user iovec is imported via import_iovec().
+ *
+ * Returns 0 on success or a negative error code.
+ */
+static int __io_recvmsg_copy_hdr(struct io_kiocb *req, struct io_async_ctx *io)
+{
+       struct io_sr_msg *sr = &req->sr_msg;
+       struct iovec __user *uiov;
+       size_t iov_len;
+       int ret;
+
+       ret = __copy_msghdr_from_user(&io->msg.msg, sr->msg, &io->msg.uaddr,
+                                       &uiov, &iov_len);
+       if (ret)
+               return ret;
+
+       if (req->flags & REQ_F_BUFFER_SELECT) {
+               /* buffer selection is limited to a single segment */
+               if (iov_len > 1)
+                       return -EINVAL;
+               if (copy_from_user(io->msg.iov, uiov, sizeof(*uiov)))
+                       return -EFAULT;
+               sr->len = io->msg.iov[0].iov_len;
+               iov_iter_init(&io->msg.msg.msg_iter, READ, io->msg.iov, 1,
+                               sr->len);
+               /*
+                * iov currently points at fast_iov storage; clear it so the
+                * later kfree(io->msg.iov) in cleanup is a harmless no-op.
+                */
+               io->msg.iov = NULL;
+       } else {
+               /* import_iovec() returns the segment count on success */
+               ret = import_iovec(READ, uiov, iov_len, UIO_FASTIOV,
+                                       &io->msg.iov, &io->msg.msg.msg_iter);
+               if (ret > 0)
+                       ret = 0;
+       }
+
+       return ret;
+}
+
+#ifdef CONFIG_COMPAT
+/*
+ * Copy the recvmsg msghdr from user space for the compat (32-bit) ABI.
+ * Mirrors __io_recvmsg_copy_hdr(): with IOSQE_BUFFER_SELECT only a
+ * single segment is allowed, and only its length is pulled from the
+ * user iovec; otherwise the full iovec is imported with
+ * compat_import_iovec().
+ *
+ * Returns 0 on success or a negative error code.
+ */
+static int __io_compat_recvmsg_copy_hdr(struct io_kiocb *req,
+                                       struct io_async_ctx *io)
+{
+       struct compat_msghdr __user *msg_compat;
+       struct io_sr_msg *sr = &req->sr_msg;
+       struct compat_iovec __user *uiov;
+       compat_uptr_t ptr;
+       compat_size_t len;
+       int ret;
+
+       msg_compat = (struct compat_msghdr __user *) sr->msg;
+       ret = __get_compat_msghdr(&io->msg.msg, msg_compat, &io->msg.uaddr,
+                                       &ptr, &len);
+       if (ret)
+               return ret;
+
+       uiov = compat_ptr(ptr);
+       if (req->flags & REQ_F_BUFFER_SELECT) {
+               compat_ssize_t clen;
+
+               /* buffer selection is limited to a single segment */
+               if (len > 1)
+                       return -EINVAL;
+               if (!access_ok(uiov, sizeof(*uiov)))
+                       return -EFAULT;
+               if (__get_user(clen, &uiov->iov_len))
+                       return -EFAULT;
+               if (clen < 0)
+                       return -EINVAL;
+               /*
+                * Use the length read from the user iovec directly: unlike
+                * the native path, io->msg.iov[] is never filled in here, so
+                * reading io->msg.iov[0].iov_len would pick up stale data.
+                */
+               sr->len = clen;
+               /*
+                * iov still points at fast_iov storage; clear it so the
+                * later kfree(io->msg.iov) in cleanup is a harmless no-op.
+                */
+               io->msg.iov = NULL;
+       } else {
+               ret = compat_import_iovec(READ, uiov, len, UIO_FASTIOV,
+                                               &io->msg.iov,
+                                               &io->msg.msg.msg_iter);
+               if (ret < 0)
+                       return ret;
+       }
+
+       return 0;
+}
+#endif
+
+/*
+ * Import the msghdr for a recvmsg request, dispatching to the compat
+ * handler when the ring belongs to a 32-bit (compat) task.  Points
+ * io->msg.iov at the inline fast_iov storage before either path runs.
+ */
+static int io_recvmsg_copy_hdr(struct io_kiocb *req, struct io_async_ctx *io)
+{
+       int ret;
+
+       io->msg.iov = io->msg.fast_iov;
+
+#ifdef CONFIG_COMPAT
+       if (req->ctx->compat)
+               ret = __io_compat_recvmsg_copy_hdr(req, io);
+       else
+               ret = __io_recvmsg_copy_hdr(req, io);
+#else
+       ret = __io_recvmsg_copy_hdr(req, io);
+#endif
+       return ret;
+}
+
 static struct io_buffer *io_recv_buffer_select(struct io_kiocb *req,
                                               int *cflags, bool needs_lock)
 {
@@ -3614,9 +3702,7 @@ static int io_recvmsg_prep(struct io_kiocb *req,
        if (req->flags & REQ_F_NEED_CLEANUP)
                return 0;
 
-       io->msg.iov = io->msg.fast_iov;
-       ret = recvmsg_copy_msghdr(&io->msg.msg, sr->msg, sr->msg_flags,
-                                       &io->msg.uaddr, &io->msg.iov);
+       ret = io_recvmsg_copy_hdr(req, io);
        if (!ret)
                req->flags |= REQ_F_NEED_CLEANUP;
        return ret;
@@ -3630,13 +3716,14 @@ static int io_recvmsg(struct io_kiocb *req, bool force_nonblock)
 #if defined(CONFIG_NET)
        struct io_async_msghdr *kmsg = NULL;
        struct socket *sock;
-       int ret;
+       int ret, cflags = 0;
 
        if (unlikely(req->ctx->flags & IORING_SETUP_IOPOLL))
                return -EINVAL;
 
        sock = sock_from_file(req->file, &ret);
        if (sock) {
+               struct io_buffer *kbuf;
                struct io_async_ctx io;
                unsigned flags;
 
@@ -3648,19 +3735,23 @@ static int io_recvmsg(struct io_kiocb *req, bool force_nonblock)
                                kmsg->iov = kmsg->fast_iov;
                        kmsg->msg.msg_iter.iov = kmsg->iov;
                } else {
-                       struct io_sr_msg *sr = &req->sr_msg;
-
                        kmsg = &io.msg;
                        kmsg->msg.msg_name = &io.msg.addr;
 
-                       io.msg.iov = io.msg.fast_iov;
-                       ret = recvmsg_copy_msghdr(&io.msg.msg, sr->msg,
-                                       sr->msg_flags, &io.msg.uaddr,
-                                       &io.msg.iov);
+                       ret = io_recvmsg_copy_hdr(req, &io);
                        if (ret)
                                return ret;
                }
 
+               kbuf = io_recv_buffer_select(req, &cflags, !force_nonblock);
+               if (IS_ERR(kbuf)) {
+                       return PTR_ERR(kbuf);
+               } else if (kbuf) {
+                       kmsg->fast_iov[0].iov_base = u64_to_user_ptr(kbuf->addr);
+                       iov_iter_init(&kmsg->msg.msg_iter, READ, kmsg->iov,
+                                       1, req->sr_msg.len);
+               }
+
                flags = req->sr_msg.msg_flags;
                if (flags & MSG_DONTWAIT)
                        req->flags |= REQ_F_NOWAIT;
@@ -3678,7 +3769,7 @@ static int io_recvmsg(struct io_kiocb *req, bool force_nonblock)
        if (kmsg && kmsg->iov != kmsg->fast_iov)
                kfree(kmsg->iov);
        req->flags &= ~REQ_F_NEED_CLEANUP;
-       io_cqring_add_event(req, ret);
+       __io_cqring_add_event(req, ret, cflags);
        if (ret < 0)
                req_set_fail_links(req);
        io_put_req(req);
@@ -4789,8 +4880,11 @@ static void io_cleanup_req(struct io_kiocb *req)
                if (io->rw.iov != io->rw.fast_iov)
                        kfree(io->rw.iov);
                break;
-       case IORING_OP_SENDMSG:
        case IORING_OP_RECVMSG:
+               if (req->flags & REQ_F_BUFFER_SELECTED)
+                       kfree(req->sr_msg.kbuf);
+               /* fallthrough */
+       case IORING_OP_SENDMSG:
                if (io->msg.iov != io->msg.fast_iov)
                        kfree(io->msg.iov);
                break;