OSDN Git Service

io_uring: pick up link work on submit reference drop
[tomoyo/tomoyo-test1.git] / fs / io_uring.c
index 77f22c3..f79ca49 100644 (file)
@@ -75,6 +75,7 @@
 #include <linux/fsnotify.h>
 #include <linux/fadvise.h>
 #include <linux/eventpoll.h>
+#include <linux/fs_struct.h>
 
 #define CREATE_TRACE_POINTS
 #include <trace/events/io_uring.h>
@@ -204,11 +205,11 @@ struct io_ring_ctx {
 
        struct {
                unsigned int            flags;
-               int                     compat: 1;
-               int                     account_mem: 1;
-               int                     cq_overflow_flushed: 1;
-               int                     drain_next: 1;
-               int                     eventfd_async: 1;
+               unsigned int            compat: 1;
+               unsigned int            account_mem: 1;
+               unsigned int            cq_overflow_flushed: 1;
+               unsigned int            drain_next: 1;
+               unsigned int            eventfd_async: 1;
 
                /*
                 * Ring buffer of indices into array of io_uring_sqe, which is
@@ -441,6 +442,7 @@ struct io_async_msghdr {
        struct iovec                    *iov;
        struct sockaddr __user          *uaddr;
        struct msghdr                   msg;
+       struct sockaddr_storage         addr;
 };
 
 struct io_async_rw {
@@ -450,17 +452,12 @@ struct io_async_rw {
        ssize_t                         size;
 };
 
-struct io_async_open {
-       struct filename                 *filename;
-};
-
 struct io_async_ctx {
        union {
                struct io_async_rw      rw;
                struct io_async_msghdr  msg;
                struct io_async_connect connect;
                struct io_timeout_data  timeout;
-               struct io_async_open    open;
        };
 };
 
@@ -483,6 +480,8 @@ enum {
        REQ_F_MUST_PUNT_BIT,
        REQ_F_TIMEOUT_NOSEQ_BIT,
        REQ_F_COMP_LOCKED_BIT,
+       REQ_F_NEED_CLEANUP_BIT,
+       REQ_F_OVERFLOW_BIT,
 };
 
 enum {
@@ -521,6 +520,10 @@ enum {
        REQ_F_TIMEOUT_NOSEQ     = BIT(REQ_F_TIMEOUT_NOSEQ_BIT),
        /* completion under lock */
        REQ_F_COMP_LOCKED       = BIT(REQ_F_COMP_LOCKED_BIT),
+       /* needs cleanup */
+       REQ_F_NEED_CLEANUP      = BIT(REQ_F_NEED_CLEANUP_BIT),
+       /* in overflow list */
+       REQ_F_OVERFLOW          = BIT(REQ_F_OVERFLOW_BIT),
 };
 
 /*
@@ -553,7 +556,6 @@ struct io_kiocb {
         * llist_node is only used for poll deferred completions
         */
        struct llist_node               llist_node;
-       bool                            has_user;
        bool                            in_async;
        bool                            needs_fixed_file;
        u8                              opcode;
@@ -614,6 +616,8 @@ struct io_op_def {
        unsigned                not_supported : 1;
        /* needs file table */
        unsigned                file_table : 1;
+       /* needs ->fs */
+       unsigned                needs_fs : 1;
 };
 
 static const struct io_op_def io_op_defs[] = {
@@ -656,12 +660,14 @@ static const struct io_op_def io_op_defs[] = {
                .needs_mm               = 1,
                .needs_file             = 1,
                .unbound_nonreg_file    = 1,
+               .needs_fs               = 1,
        },
        [IORING_OP_RECVMSG] = {
                .async_ctx              = 1,
                .needs_mm               = 1,
                .needs_file             = 1,
                .unbound_nonreg_file    = 1,
+               .needs_fs               = 1,
        },
        [IORING_OP_TIMEOUT] = {
                .async_ctx              = 1,
@@ -692,6 +698,7 @@ static const struct io_op_def io_op_defs[] = {
                .needs_file             = 1,
                .fd_non_neg             = 1,
                .file_table             = 1,
+               .needs_fs               = 1,
        },
        [IORING_OP_CLOSE] = {
                .needs_file             = 1,
@@ -705,6 +712,7 @@ static const struct io_op_def io_op_defs[] = {
                .needs_mm               = 1,
                .needs_file             = 1,
                .fd_non_neg             = 1,
+               .needs_fs               = 1,
        },
        [IORING_OP_READ] = {
                .needs_mm               = 1,
@@ -736,6 +744,7 @@ static const struct io_op_def io_op_defs[] = {
                .needs_file             = 1,
                .fd_non_neg             = 1,
                .file_table             = 1,
+               .needs_fs               = 1,
        },
        [IORING_OP_EPOLL_CTL] = {
                .unbound_nonreg_file    = 1,
@@ -754,6 +763,7 @@ static int __io_sqe_files_update(struct io_ring_ctx *ctx,
                                 unsigned nr_args);
 static int io_grab_files(struct io_kiocb *req);
 static void io_ring_file_ref_flush(struct fixed_file_data *data);
+static void io_cleanup_req(struct io_kiocb *req);
 
 static struct kmem_cache *req_cachep;
 
@@ -909,6 +919,18 @@ static inline void io_req_work_grab_env(struct io_kiocb *req,
        }
        if (!req->work.creds)
                req->work.creds = get_current_cred();
+       if (!req->work.fs && def->needs_fs) {
+               spin_lock(&current->fs->lock);
+               if (!current->fs->in_exec) {
+                       req->work.fs = current->fs;
+                       req->work.fs->users++;
+               } else {
+                       req->work.flags |= IO_WQ_WORK_CANCEL;
+               }
+               spin_unlock(&current->fs->lock);
+       }
+       if (!req->work.task_pid)
+               req->work.task_pid = task_pid_vnr(current);
 }
 
 static inline void io_req_work_drop_env(struct io_kiocb *req)
@@ -921,6 +943,16 @@ static inline void io_req_work_drop_env(struct io_kiocb *req)
                put_cred(req->work.creds);
                req->work.creds = NULL;
        }
+       if (req->work.fs) {
+               struct fs_struct *fs = req->work.fs;
+
+               spin_lock(&req->work.fs->lock);
+               if (--fs->users)
+                       fs = NULL;
+               spin_unlock(&req->work.fs->lock);
+               if (fs)
+                       free_fs_struct(fs);
+       }
 }
 
 static inline bool io_prep_async_work(struct io_kiocb *req,
@@ -1074,6 +1106,7 @@ static bool io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force)
                req = list_first_entry(&ctx->cq_overflow_list, struct io_kiocb,
                                                list);
                list_move(&req->list, &list);
+               req->flags &= ~REQ_F_OVERFLOW;
                if (cqe) {
                        WRITE_ONCE(cqe->user_data, req->user_data);
                        WRITE_ONCE(cqe->res, req->result);
@@ -1126,6 +1159,7 @@ static void io_cqring_fill_event(struct io_kiocb *req, long res)
                        set_bit(0, &ctx->sq_check_overflow);
                        set_bit(0, &ctx->cq_check_overflow);
                }
+               req->flags |= REQ_F_OVERFLOW;
                refcount_inc(&req->refs);
                req->result = res;
                list_add_tail(&req->list, &ctx->cq_overflow_list);
@@ -1226,6 +1260,9 @@ static void __io_req_aux_free(struct io_kiocb *req)
 {
        struct io_ring_ctx *ctx = req->ctx;
 
+       if (req->flags & REQ_F_NEED_CLEANUP)
+               io_cleanup_req(req);
+
        kfree(req->io);
        if (req->file) {
                if (req->flags & REQ_F_FIXED_FILE)
@@ -1446,10 +1483,10 @@ static void io_free_req(struct io_kiocb *req)
 __attribute__((nonnull))
 static void io_put_req_find_next(struct io_kiocb *req, struct io_kiocb **nxtptr)
 {
-       io_req_find_next(req, nxtptr);
-
-       if (refcount_dec_and_test(&req->refs))
+       if (refcount_dec_and_test(&req->refs)) {
+               io_req_find_next(req, nxtptr);
                __io_free_req(req);
+       }
 }
 
 static void io_put_req(struct io_kiocb *req)
@@ -1635,11 +1672,17 @@ static void io_iopoll_reap_events(struct io_ring_ctx *ctx)
        mutex_unlock(&ctx->uring_lock);
 }
 
-static int __io_iopoll_check(struct io_ring_ctx *ctx, unsigned *nr_events,
-                           long min)
+static int io_iopoll_check(struct io_ring_ctx *ctx, unsigned *nr_events,
+                          long min)
 {
        int iters = 0, ret = 0;
 
+       /*
+        * We disallow the app entering submit/complete with polling, but we
+        * still need to lock the ring to prevent racing with polled issue
+        * that got punted to a workqueue.
+        */
+       mutex_lock(&ctx->uring_lock);
        do {
                int tmin = 0;
 
@@ -1675,21 +1718,6 @@ static int __io_iopoll_check(struct io_ring_ctx *ctx, unsigned *nr_events,
                ret = 0;
        } while (min && !*nr_events && !need_resched());
 
-       return ret;
-}
-
-static int io_iopoll_check(struct io_ring_ctx *ctx, unsigned *nr_events,
-                          long min)
-{
-       int ret;
-
-       /*
-        * We disallow the app entering submit/complete with polling, but we
-        * still need to lock the ring to prevent racing with polled issue
-        * that got punted to a workqueue.
-        */
-       mutex_lock(&ctx->uring_lock);
-       ret = __io_iopoll_check(ctx, nr_events, min);
        mutex_unlock(&ctx->uring_lock);
        return ret;
 }
@@ -1793,6 +1821,10 @@ static void io_iopoll_req_issued(struct io_kiocb *req)
                list_add(&req->list, &ctx->poll_list);
        else
                list_add_tail(&req->list, &ctx->poll_list);
+
+       if ((ctx->flags & IORING_SETUP_SQPOLL) &&
+           wq_has_sleeper(&ctx->sqo_wait))
+               wake_up(&ctx->sqo_wait);
 }
 
 static void io_file_put(struct io_submit_state *state)
@@ -2056,9 +2088,6 @@ static ssize_t io_import_iovec(int rw, struct io_kiocb *req,
                return iorw->size;
        }
 
-       if (!req->has_user)
-               return -EFAULT;
-
 #ifdef CONFIG_COMPAT
        if (req->ctx->compat)
                return compat_import_iovec(rw, buf, sqe_len, UIO_FASTIOV,
@@ -2137,6 +2166,8 @@ static void io_req_map_rw(struct io_kiocb *req, ssize_t io_size,
                req->io->rw.iov = req->io->rw.fast_iov;
                memcpy(req->io->rw.iov, fast_iov,
                        sizeof(struct iovec) * iter->nr_segs);
+       } else {
+               req->flags |= REQ_F_NEED_CLEANUP;
        }
 }
 
@@ -2148,17 +2179,6 @@ static int io_alloc_async_ctx(struct io_kiocb *req)
        return req->io == NULL;
 }
 
-static void io_rw_async(struct io_wq_work **workptr)
-{
-       struct io_kiocb *req = container_of(*workptr, struct io_kiocb, work);
-       struct iovec *iov = NULL;
-
-       if (req->io->rw.iov != req->io->rw.fast_iov)
-               iov = req->io->rw.iov;
-       io_wq_submit_work(workptr);
-       kfree(iov);
-}
-
 static int io_setup_async_rw(struct io_kiocb *req, ssize_t io_size,
                             struct iovec *iovec, struct iovec *fast_iov,
                             struct iov_iter *iter)
@@ -2171,7 +2191,6 @@ static int io_setup_async_rw(struct io_kiocb *req, ssize_t io_size,
 
                io_req_map_rw(req, io_size, iovec, fast_iov, iter);
        }
-       req->work.func = io_rw_async;
        return 0;
 }
 
@@ -2189,7 +2208,8 @@ static int io_read_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe,
        if (unlikely(!(req->file->f_mode & FMODE_READ)))
                return -EBADF;
 
-       if (!req->io)
+       /* either don't need iovec imported or already have it */
+       if (!req->io || req->flags & REQ_F_NEED_CLEANUP)
                return 0;
 
        io = req->io;
@@ -2258,8 +2278,8 @@ copy_iov:
                }
        }
 out_free:
-       if (!io_wq_current_is_worker())
-               kfree(iovec);
+       kfree(iovec);
+       req->flags &= ~REQ_F_NEED_CLEANUP;
        return ret;
 }
 
@@ -2277,7 +2297,8 @@ static int io_write_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe,
        if (unlikely(!(req->file->f_mode & FMODE_WRITE)))
                return -EBADF;
 
-       if (!req->io)
+       /* either don't need iovec imported or already have it */
+       if (!req->io || req->flags & REQ_F_NEED_CLEANUP)
                return 0;
 
        io = req->io;
@@ -2352,6 +2373,12 @@ static int io_write(struct io_kiocb *req, struct io_kiocb **nxt,
                        ret2 = call_write_iter(req->file, kiocb, &iter);
                else
                        ret2 = loop_rw_iter(WRITE, req->file, kiocb, &iter);
+               /*
+                * Raw bdev writes will -EOPNOTSUPP for IOCB_NOWAIT. Just
+                * retry them without IOCB_NOWAIT.
+                */
+               if (ret2 == -EOPNOTSUPP && (kiocb->ki_flags & IOCB_NOWAIT))
+                       ret2 = -EAGAIN;
                if (!force_nonblock || ret2 != -EAGAIN) {
                        kiocb_done(kiocb, ret2, nxt, req->in_async);
                } else {
@@ -2364,8 +2391,8 @@ copy_iov:
                }
        }
 out_free:
-       if (!io_wq_current_is_worker())
-               kfree(iovec);
+       req->flags &= ~REQ_F_NEED_CLEANUP;
+       kfree(iovec);
        return ret;
 }
 
@@ -2485,6 +2512,9 @@ static void io_fallocate_finish(struct io_wq_work **workptr)
        struct io_kiocb *nxt = NULL;
        int ret;
 
+       if (io_req_cancelled(req))
+               return;
+
        ret = vfs_fallocate(req->file, req->sync.mode, req->sync.off,
                                req->sync.len);
        if (ret < 0)
@@ -2534,6 +2564,10 @@ static int io_openat_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 
        if (sqe->ioprio || sqe->buf_index)
                return -EINVAL;
+       if (sqe->flags & IOSQE_FIXED_FILE)
+               return -EBADF;
+       if (req->flags & REQ_F_NEED_CLEANUP)
+               return 0;
 
        req->open.dfd = READ_ONCE(sqe->fd);
        req->open.how.mode = READ_ONCE(sqe->len);
@@ -2547,6 +2581,7 @@ static int io_openat_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
                return ret;
        }
 
+       req->flags |= REQ_F_NEED_CLEANUP;
        return 0;
 }
 
@@ -2559,6 +2594,10 @@ static int io_openat2_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 
        if (sqe->ioprio || sqe->buf_index)
                return -EINVAL;
+       if (sqe->flags & IOSQE_FIXED_FILE)
+               return -EBADF;
+       if (req->flags & REQ_F_NEED_CLEANUP)
+               return 0;
 
        req->open.dfd = READ_ONCE(sqe->fd);
        fname = u64_to_user_ptr(READ_ONCE(sqe->addr));
@@ -2583,6 +2622,7 @@ static int io_openat2_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
                return ret;
        }
 
+       req->flags |= REQ_F_NEED_CLEANUP;
        return 0;
 }
 
@@ -2614,6 +2654,7 @@ static int io_openat2(struct io_kiocb *req, struct io_kiocb **nxt,
        }
 err:
        putname(req->open.filename);
+       req->flags &= ~REQ_F_NEED_CLEANUP;
        if (ret < 0)
                req_set_fail_links(req);
        io_cqring_add_event(req, ret);
@@ -2754,6 +2795,10 @@ static int io_statx_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 
        if (sqe->ioprio || sqe->buf_index)
                return -EINVAL;
+       if (sqe->flags & IOSQE_FIXED_FILE)
+               return -EBADF;
+       if (req->flags & REQ_F_NEED_CLEANUP)
+               return 0;
 
        req->open.dfd = READ_ONCE(sqe->fd);
        req->open.mask = READ_ONCE(sqe->len);
@@ -2771,6 +2816,7 @@ static int io_statx_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
                return ret;
        }
 
+       req->flags |= REQ_F_NEED_CLEANUP;
        return 0;
 }
 
@@ -2808,6 +2854,7 @@ retry:
                ret = cp_statx(&stat, ctx->buffer);
 err:
        putname(ctx->filename);
+       req->flags &= ~REQ_F_NEED_CLEANUP;
        if (ret < 0)
                req_set_fail_links(req);
        io_cqring_add_event(req, ret);
@@ -2827,7 +2874,7 @@ static int io_close_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
            sqe->rw_flags || sqe->buf_index)
                return -EINVAL;
        if (sqe->flags & IOSQE_FIXED_FILE)
-               return -EINVAL;
+               return -EBADF;
 
        req->close.fd = READ_ONCE(sqe->fd);
        if (req->file->f_op == &io_uring_fops ||
@@ -2837,24 +2884,26 @@ static int io_close_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
        return 0;
 }
 
+/* only called when __close_fd_get_file() is done */
+static void __io_close_finish(struct io_kiocb *req, struct io_kiocb **nxt)
+{
+       int ret;
+
+       ret = filp_close(req->close.put_file, req->work.files);
+       if (ret < 0)
+               req_set_fail_links(req);
+       io_cqring_add_event(req, ret);
+       fput(req->close.put_file);
+       io_put_req_find_next(req, nxt);
+}
+
 static void io_close_finish(struct io_wq_work **workptr)
 {
        struct io_kiocb *req = container_of(*workptr, struct io_kiocb, work);
        struct io_kiocb *nxt = NULL;
 
-       /* Invoked with files, we need to do the close */
-       if (req->work.files) {
-               int ret;
-
-               ret = filp_close(req->close.put_file, req->work.files);
-               if (ret < 0)
-                       req_set_fail_links(req);
-               io_cqring_add_event(req, ret);
-       }
-
-       fput(req->close.put_file);
-
-       io_put_req_find_next(req, &nxt);
+       /* not cancellable, don't do io_req_cancelled() */
+       __io_close_finish(req, &nxt);
        if (nxt)
                io_wq_assign_next(workptr, nxt);
 }
@@ -2877,22 +2926,8 @@ static int io_close(struct io_kiocb *req, struct io_kiocb **nxt,
         * No ->flush(), safely close from here and just punt the
         * fput() to async context.
         */
-       ret = filp_close(req->close.put_file, current->files);
-
-       if (ret < 0)
-               req_set_fail_links(req);
-       io_cqring_add_event(req, ret);
-
-       if (io_wq_current_is_worker()) {
-               struct io_wq_work *old_work, *work;
-
-               old_work = work = &req->work;
-               io_close_finish(&work);
-               if (work && work != old_work)
-                       *nxt = container_of(work, struct io_kiocb, work);
-               return 0;
-       }
-
+       __io_close_finish(req, nxt);
+       return 0;
 eagain:
        req->work.func = io_close_finish;
        /*
@@ -2960,24 +2995,12 @@ static int io_sync_file_range(struct io_kiocb *req, struct io_kiocb **nxt,
        return 0;
 }
 
-#if defined(CONFIG_NET)
-static void io_sendrecv_async(struct io_wq_work **workptr)
-{
-       struct io_kiocb *req = container_of(*workptr, struct io_kiocb, work);
-       struct iovec *iov = NULL;
-
-       if (req->io->rw.iov != req->io->rw.fast_iov)
-               iov = req->io->msg.iov;
-       io_wq_submit_work(workptr);
-       kfree(iov);
-}
-#endif
-
 static int io_sendmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 {
 #if defined(CONFIG_NET)
        struct io_sr_msg *sr = &req->sr_msg;
        struct io_async_ctx *io = req->io;
+       int ret;
 
        sr->msg_flags = READ_ONCE(sqe->msg_flags);
        sr->msg = u64_to_user_ptr(READ_ONCE(sqe->addr));
@@ -2985,10 +3008,16 @@ static int io_sendmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 
        if (!io || req->opcode == IORING_OP_SEND)
                return 0;
+       /* iovec is already imported */
+       if (req->flags & REQ_F_NEED_CLEANUP)
+               return 0;
 
        io->msg.iov = io->msg.fast_iov;
-       return sendmsg_copy_msghdr(&io->msg.msg, sr->msg, sr->msg_flags,
+       ret = sendmsg_copy_msghdr(&io->msg.msg, sr->msg, sr->msg_flags,
                                        &io->msg.iov);
+       if (!ret)
+               req->flags |= REQ_F_NEED_CLEANUP;
+       return ret;
 #else
        return -EOPNOTSUPP;
 #endif
@@ -3008,12 +3037,11 @@ static int io_sendmsg(struct io_kiocb *req, struct io_kiocb **nxt,
        sock = sock_from_file(req->file, &ret);
        if (sock) {
                struct io_async_ctx io;
-               struct sockaddr_storage addr;
                unsigned flags;
 
                if (req->io) {
                        kmsg = &req->io->msg;
-                       kmsg->msg.msg_name = &addr;
+                       kmsg->msg.msg_name = &req->io->msg.addr;
                        /* if iov is set, it's allocated already */
                        if (!kmsg->iov)
                                kmsg->iov = kmsg->fast_iov;
@@ -3022,7 +3050,7 @@ static int io_sendmsg(struct io_kiocb *req, struct io_kiocb **nxt,
                        struct io_sr_msg *sr = &req->sr_msg;
 
                        kmsg = &io.msg;
-                       kmsg->msg.msg_name = &addr;
+                       kmsg->msg.msg_name = &io.msg.addr;
 
                        io.msg.iov = io.msg.fast_iov;
                        ret = sendmsg_copy_msghdr(&io.msg.msg, sr->msg,
@@ -3041,18 +3069,22 @@ static int io_sendmsg(struct io_kiocb *req, struct io_kiocb **nxt,
                if (force_nonblock && ret == -EAGAIN) {
                        if (req->io)
                                return -EAGAIN;
-                       if (io_alloc_async_ctx(req))
+                       if (io_alloc_async_ctx(req)) {
+                               if (kmsg->iov != kmsg->fast_iov)
+                                       kfree(kmsg->iov);
                                return -ENOMEM;
+                       }
+                       req->flags |= REQ_F_NEED_CLEANUP;
                        memcpy(&req->io->msg, &io.msg, sizeof(io.msg));
-                       req->work.func = io_sendrecv_async;
                        return -EAGAIN;
                }
                if (ret == -ERESTARTSYS)
                        ret = -EINTR;
        }
 
-       if (!io_wq_current_is_worker() && kmsg && kmsg->iov != kmsg->fast_iov)
+       if (kmsg && kmsg->iov != kmsg->fast_iov)
                kfree(kmsg->iov);
+       req->flags &= ~REQ_F_NEED_CLEANUP;
        io_cqring_add_event(req, ret);
        if (ret < 0)
                req_set_fail_links(req);
@@ -3120,6 +3152,7 @@ static int io_recvmsg_prep(struct io_kiocb *req,
 #if defined(CONFIG_NET)
        struct io_sr_msg *sr = &req->sr_msg;
        struct io_async_ctx *io = req->io;
+       int ret;
 
        sr->msg_flags = READ_ONCE(sqe->msg_flags);
        sr->msg = u64_to_user_ptr(READ_ONCE(sqe->addr));
@@ -3127,10 +3160,16 @@ static int io_recvmsg_prep(struct io_kiocb *req,
 
        if (!io || req->opcode == IORING_OP_RECV)
                return 0;
+       /* iovec is already imported */
+       if (req->flags & REQ_F_NEED_CLEANUP)
+               return 0;
 
        io->msg.iov = io->msg.fast_iov;
-       return recvmsg_copy_msghdr(&io->msg.msg, sr->msg, sr->msg_flags,
+       ret = recvmsg_copy_msghdr(&io->msg.msg, sr->msg, sr->msg_flags,
                                        &io->msg.uaddr, &io->msg.iov);
+       if (!ret)
+               req->flags |= REQ_F_NEED_CLEANUP;
+       return ret;
 #else
        return -EOPNOTSUPP;
 #endif
@@ -3150,12 +3189,11 @@ static int io_recvmsg(struct io_kiocb *req, struct io_kiocb **nxt,
        sock = sock_from_file(req->file, &ret);
        if (sock) {
                struct io_async_ctx io;
-               struct sockaddr_storage addr;
                unsigned flags;
 
                if (req->io) {
                        kmsg = &req->io->msg;
-                       kmsg->msg.msg_name = &addr;
+                       kmsg->msg.msg_name = &req->io->msg.addr;
                        /* if iov is set, it's allocated already */
                        if (!kmsg->iov)
                                kmsg->iov = kmsg->fast_iov;
@@ -3164,7 +3202,7 @@ static int io_recvmsg(struct io_kiocb *req, struct io_kiocb **nxt,
                        struct io_sr_msg *sr = &req->sr_msg;
 
                        kmsg = &io.msg;
-                       kmsg->msg.msg_name = &addr;
+                       kmsg->msg.msg_name = &io.msg.addr;
 
                        io.msg.iov = io.msg.fast_iov;
                        ret = recvmsg_copy_msghdr(&io.msg.msg, sr->msg,
@@ -3185,18 +3223,22 @@ static int io_recvmsg(struct io_kiocb *req, struct io_kiocb **nxt,
                if (force_nonblock && ret == -EAGAIN) {
                        if (req->io)
                                return -EAGAIN;
-                       if (io_alloc_async_ctx(req))
+                       if (io_alloc_async_ctx(req)) {
+                               if (kmsg->iov != kmsg->fast_iov)
+                                       kfree(kmsg->iov);
                                return -ENOMEM;
+                       }
                        memcpy(&req->io->msg, &io.msg, sizeof(io.msg));
-                       req->work.func = io_sendrecv_async;
+                       req->flags |= REQ_F_NEED_CLEANUP;
                        return -EAGAIN;
                }
                if (ret == -ERESTARTSYS)
                        ret = -EINTR;
        }
 
-       if (!io_wq_current_is_worker() && kmsg && kmsg->iov != kmsg->fast_iov)
+       if (kmsg && kmsg->iov != kmsg->fast_iov)
                kfree(kmsg->iov);
+       req->flags &= ~REQ_F_NEED_CLEANUP;
        io_cqring_add_event(req, ret);
        if (ret < 0)
                req_set_fail_links(req);
@@ -4207,6 +4249,35 @@ static int io_req_defer(struct io_kiocb *req, const struct io_uring_sqe *sqe)
        return -EIOCBQUEUED;
 }
 
+static void io_cleanup_req(struct io_kiocb *req)
+{
+       struct io_async_ctx *io = req->io;
+
+       switch (req->opcode) {
+       case IORING_OP_READV:
+       case IORING_OP_READ_FIXED:
+       case IORING_OP_READ:
+       case IORING_OP_WRITEV:
+       case IORING_OP_WRITE_FIXED:
+       case IORING_OP_WRITE:
+               if (io->rw.iov != io->rw.fast_iov)
+                       kfree(io->rw.iov);
+               break;
+       case IORING_OP_SENDMSG:
+       case IORING_OP_RECVMSG:
+               if (io->msg.iov != io->msg.fast_iov)
+                       kfree(io->msg.iov);
+               break;
+       case IORING_OP_OPENAT:
+       case IORING_OP_OPENAT2:
+       case IORING_OP_STATX:
+               putname(req->open.filename);
+               break;
+       }
+
+       req->flags &= ~REQ_F_NEED_CLEANUP;
+}
+
 static int io_issue_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe,
                        struct io_kiocb **nxt, bool force_nonblock)
 {
@@ -4446,7 +4517,6 @@ static void io_wq_submit_work(struct io_wq_work **workptr)
        }
 
        if (!ret) {
-               req->has_user = (work->flags & IO_WQ_WORK_HAS_MM) != 0;
                req->in_async = true;
                do {
                        ret = io_issue_sqe(req, NULL, &nxt, false);
@@ -4479,7 +4549,7 @@ static int io_req_needs_file(struct io_kiocb *req, int fd)
 {
        if (!io_op_defs[req->opcode].needs_file)
                return 0;
-       if (fd == -1 && io_op_defs[req->opcode].fd_non_neg)
+       if ((fd == -1 || fd == AT_FDCWD) && io_op_defs[req->opcode].fd_non_neg)
                return 0;
        return 1;
 }
@@ -4639,11 +4709,21 @@ static void __io_queue_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 {
        struct io_kiocb *linked_timeout;
        struct io_kiocb *nxt = NULL;
+       const struct cred *old_creds = NULL;
        int ret;
 
 again:
        linked_timeout = io_prep_linked_timeout(req);
 
+       if (req->work.creds && req->work.creds != current_cred()) {
+               if (old_creds)
+                       revert_creds(old_creds);
+               if (old_creds == req->work.creds)
+                       old_creds = NULL; /* restored original creds */
+               else
+                       old_creds = override_creds(req->work.creds);
+       }
+
        ret = io_issue_sqe(req, sqe, &nxt, true);
 
        /*
@@ -4669,7 +4749,7 @@ punt:
 
 err:
        /* drop submission reference */
-       io_put_req(req);
+       io_put_req_find_next(req, &nxt);
 
        if (linked_timeout) {
                if (!ret)
@@ -4693,6 +4773,8 @@ done_req:
                        goto punt;
                goto again;
        }
+       if (old_creds)
+               revert_creds(old_creds);
 }
 
 static void io_queue_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe)
@@ -4737,7 +4819,6 @@ static inline void io_queue_link_head(struct io_kiocb *req)
 static bool io_submit_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe,
                          struct io_submit_state *state, struct io_kiocb **link)
 {
-       const struct cred *old_creds = NULL;
        struct io_ring_ctx *ctx = req->ctx;
        unsigned int sqe_flags;
        int ret, id;
@@ -4752,14 +4833,12 @@ static bool io_submit_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe,
 
        id = READ_ONCE(sqe->personality);
        if (id) {
-               const struct cred *personality_creds;
-
-               personality_creds = idr_find(&ctx->personality_idr, id);
-               if (unlikely(!personality_creds)) {
+               req->work.creds = idr_find(&ctx->personality_idr, id);
+               if (unlikely(!req->work.creds)) {
                        ret = -EINVAL;
                        goto err_req;
                }
-               old_creds = override_creds(personality_creds);
+               get_cred(req->work.creds);
        }
 
        /* same numerical values with corresponding REQ_F_*, safe to copy */
@@ -4771,8 +4850,6 @@ static bool io_submit_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe,
 err_req:
                io_cqring_add_event(req, ret);
                io_double_put_req(req);
-               if (old_creds)
-                       revert_creds(old_creds);
                return false;
        }
 
@@ -4833,8 +4910,6 @@ err_req:
                }
        }
 
-       if (old_creds)
-               revert_creds(old_creds);
        return true;
 }
 
@@ -4950,6 +5025,7 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr,
        for (i = 0; i < nr; i++) {
                const struct io_uring_sqe *sqe;
                struct io_kiocb *req;
+               int err;
 
                req = io_get_req(ctx, statep);
                if (unlikely(!req)) {
@@ -4966,20 +5042,23 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr,
                submitted++;
 
                if (unlikely(req->opcode >= IORING_OP_LAST)) {
-                       io_cqring_add_event(req, -EINVAL);
+                       err = -EINVAL;
+fail_req:
+                       io_cqring_add_event(req, err);
                        io_double_put_req(req);
                        break;
                }
 
                if (io_op_defs[req->opcode].needs_mm && !*mm) {
                        mm_fault = mm_fault || !mmget_not_zero(ctx->sqo_mm);
-                       if (!mm_fault) {
-                               use_mm(ctx->sqo_mm);
-                               *mm = ctx->sqo_mm;
+                       if (unlikely(mm_fault)) {
+                               err = -EFAULT;
+                               goto fail_req;
                        }
+                       use_mm(ctx->sqo_mm);
+                       *mm = ctx->sqo_mm;
                }
 
-               req->has_user = *mm != NULL;
                req->in_async = async;
                req->needs_fixed_file = async;
                trace_io_uring_submit_sqe(ctx, req->opcode, req->user_data,
@@ -5011,9 +5090,8 @@ static int io_sq_thread(void *data)
        const struct cred *old_cred;
        mm_segment_t old_fs;
        DEFINE_WAIT(wait);
-       unsigned inflight;
        unsigned long timeout;
-       int ret;
+       int ret = 0;
 
        complete(&ctx->completions[1]);
 
@@ -5021,39 +5099,19 @@ static int io_sq_thread(void *data)
        set_fs(USER_DS);
        old_cred = override_creds(ctx->creds);
 
-       ret = timeout = inflight = 0;
+       timeout = jiffies + ctx->sq_thread_idle;
        while (!kthread_should_park()) {
                unsigned int to_submit;
 
-               if (inflight) {
+               if (!list_empty(&ctx->poll_list)) {
                        unsigned nr_events = 0;
 
-                       if (ctx->flags & IORING_SETUP_IOPOLL) {
-                               /*
-                                * inflight is the count of the maximum possible
-                                * entries we submitted, but it can be smaller
-                                * if we dropped some of them. If we don't have
-                                * poll entries available, then we know that we
-                                * have nothing left to poll for. Reset the
-                                * inflight count to zero in that case.
-                                */
-                               mutex_lock(&ctx->uring_lock);
-                               if (!list_empty(&ctx->poll_list))
-                                       __io_iopoll_check(ctx, &nr_events, 0);
-                               else
-                                       inflight = 0;
-                               mutex_unlock(&ctx->uring_lock);
-                       } else {
-                               /*
-                                * Normal IO, just pretend everything completed.
-                                * We don't have to poll completions for that.
-                                */
-                               nr_events = inflight;
-                       }
-
-                       inflight -= nr_events;
-                       if (!inflight)
+                       mutex_lock(&ctx->uring_lock);
+                       if (!list_empty(&ctx->poll_list))
+                               io_iopoll_getevents(ctx, &nr_events, 0);
+                       else
                                timeout = jiffies + ctx->sq_thread_idle;
+                       mutex_unlock(&ctx->uring_lock);
                }
 
                to_submit = io_sqring_entries(ctx);
@@ -5064,34 +5122,47 @@ static int io_sq_thread(void *data)
                 */
                if (!to_submit || ret == -EBUSY) {
                        /*
+                        * Drop cur_mm before scheduling, we can't hold it for
+                        * long periods (or over schedule()). Do this before
+                        * adding ourselves to the waitqueue, as the unuse/drop
+                        * may sleep.
+                        */
+                       if (cur_mm) {
+                               unuse_mm(cur_mm);
+                               mmput(cur_mm);
+                               cur_mm = NULL;
+                       }
+
+                       /*
                         * We're polling. If we're within the defined idle
                         * period, then let us spin without work before going
                         * to sleep. The exception is if we got EBUSY doing
                         * more IO, we should wait for the application to
                         * reap events and wake us up.
                         */
-                       if (inflight ||
+                       if (!list_empty(&ctx->poll_list) ||
                            (!time_after(jiffies, timeout) && ret != -EBUSY &&
                            !percpu_ref_is_dying(&ctx->refs))) {
                                cond_resched();
                                continue;
                        }
 
+                       prepare_to_wait(&ctx->sqo_wait, &wait,
+                                               TASK_INTERRUPTIBLE);
+
                        /*
-                        * Drop cur_mm before scheduling, we can't hold it for
-                        * long periods (or over schedule()). Do this before
-                        * adding ourselves to the waitqueue, as the unuse/drop
-                        * may sleep.
+                        * While doing polled IO, check poll_list once more
+                        * before going to sleep: requests that were punted to
+                        * an io worker are only added to poll_list later, so
+                        * new entries may have shown up since the check
+                        * above.
                         */
-                       if (cur_mm) {
-                               unuse_mm(cur_mm);
-                               mmput(cur_mm);
-                               cur_mm = NULL;
+                       if ((ctx->flags & IORING_SETUP_IOPOLL) &&
+                           !list_empty_careful(&ctx->poll_list)) {
+                               finish_wait(&ctx->sqo_wait, &wait);
+                               continue;
                        }
 
-                       prepare_to_wait(&ctx->sqo_wait, &wait,
-                                               TASK_INTERRUPTIBLE);
-
                        /* Tell userspace we may need a wakeup call */
                        ctx->rings->sq_flags |= IORING_SQ_NEED_WAKEUP;
                        /* make sure to read SQ tail after writing flags */
@@ -5119,8 +5190,7 @@ static int io_sq_thread(void *data)
                mutex_lock(&ctx->uring_lock);
                ret = io_submit_sqes(ctx, to_submit, NULL, -1, &cur_mm, true);
                mutex_unlock(&ctx->uring_lock);
-               if (ret > 0)
-                       inflight += ret;
+               timeout = jiffies + ctx->sq_thread_idle;
        }
 
        set_fs(old_fs);
@@ -6264,6 +6334,7 @@ static void io_ring_ctx_free(struct io_ring_ctx *ctx)
        io_sqe_buffer_unregister(ctx);
        io_sqe_files_unregister(ctx);
        io_eventfd_unregister(ctx);
+       idr_destroy(&ctx->personality_idr);
 
 #if defined(CONFIG_UNIX)
        if (ctx->ring_sock) {
@@ -6301,7 +6372,7 @@ static __poll_t io_uring_poll(struct file *file, poll_table *wait)
        if (READ_ONCE(ctx->rings->sq.tail) - ctx->cached_sq_head !=
            ctx->rings->sq_ring_entries)
                mask |= EPOLLOUT | EPOLLWRNORM;
-       if (READ_ONCE(ctx->rings->cq.head) != ctx->cached_cq_tail)
+       if (io_cqring_events(ctx, false))
                mask |= EPOLLIN | EPOLLRDNORM;
 
        return mask;
@@ -6393,6 +6464,29 @@ static void io_uring_cancel_files(struct io_ring_ctx *ctx,
                if (!cancel_req)
                        break;
 
+               if (cancel_req->flags & REQ_F_OVERFLOW) {
+                       spin_lock_irq(&ctx->completion_lock);
+                       list_del(&cancel_req->list);
+                       cancel_req->flags &= ~REQ_F_OVERFLOW;
+                       if (list_empty(&ctx->cq_overflow_list)) {
+                               clear_bit(0, &ctx->sq_check_overflow);
+                               clear_bit(0, &ctx->cq_check_overflow);
+                       }
+                       spin_unlock_irq(&ctx->completion_lock);
+
+                       WRITE_ONCE(ctx->rings->cq_overflow,
+                               atomic_inc_return(&ctx->cached_cq_overflow));
+
+                       /*
+                        * Put inflight ref and overflow ref. If that's
+                        * all we had, then we're done with this request.
+                        */
+                       if (refcount_sub_and_test(2, &cancel_req->refs)) {
+                               io_put_req(cancel_req);
+                               continue;
+                       }
+               }
+
                io_wq_cancel_work(ctx->io_wq, &cancel_req->work);
                io_put_req(cancel_req);
                schedule();
@@ -6405,6 +6499,13 @@ static int io_uring_flush(struct file *file, void *data)
        struct io_ring_ctx *ctx = file->private_data;
 
        io_uring_cancel_files(ctx, data);
+
+       /*
+        * If the task is going away, cancel work it may have pending
+        */
+       if (fatal_signal_pending(current) || (current->flags & PF_EXITING))
+               io_wq_cancel_pid(ctx->io_wq, task_pid_vnr(current));
+
        return 0;
 }