OSDN Git Service

io_uring: pick up link work on submit reference drop
[tomoyo/tomoyo-test1.git] / fs / io_uring.c
index 0977017..f79ca49 100644 (file)
@@ -442,6 +442,7 @@ struct io_async_msghdr {
        struct iovec                    *iov;
        struct sockaddr __user          *uaddr;
        struct msghdr                   msg;
+       struct sockaddr_storage         addr;
 };
 
 struct io_async_rw {
@@ -480,6 +481,7 @@ enum {
        REQ_F_TIMEOUT_NOSEQ_BIT,
        REQ_F_COMP_LOCKED_BIT,
        REQ_F_NEED_CLEANUP_BIT,
+       REQ_F_OVERFLOW_BIT,
 };
 
 enum {
@@ -520,6 +522,8 @@ enum {
        REQ_F_COMP_LOCKED       = BIT(REQ_F_COMP_LOCKED_BIT),
        /* needs cleanup */
        REQ_F_NEED_CLEANUP      = BIT(REQ_F_NEED_CLEANUP_BIT),
+       /* in overflow list */
+       REQ_F_OVERFLOW          = BIT(REQ_F_OVERFLOW_BIT),
 };
 
 /*
@@ -925,6 +929,8 @@ static inline void io_req_work_grab_env(struct io_kiocb *req,
                }
                spin_unlock(&current->fs->lock);
        }
+       if (!req->work.task_pid)
+               req->work.task_pid = task_pid_vnr(current);
 }
 
 static inline void io_req_work_drop_env(struct io_kiocb *req)
@@ -1100,6 +1106,7 @@ static bool io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force)
                req = list_first_entry(&ctx->cq_overflow_list, struct io_kiocb,
                                                list);
                list_move(&req->list, &list);
+               req->flags &= ~REQ_F_OVERFLOW;
                if (cqe) {
                        WRITE_ONCE(cqe->user_data, req->user_data);
                        WRITE_ONCE(cqe->res, req->result);
@@ -1152,6 +1159,7 @@ static void io_cqring_fill_event(struct io_kiocb *req, long res)
                        set_bit(0, &ctx->sq_check_overflow);
                        set_bit(0, &ctx->cq_check_overflow);
                }
+               req->flags |= REQ_F_OVERFLOW;
                refcount_inc(&req->refs);
                req->result = res;
                list_add_tail(&req->list, &ctx->cq_overflow_list);
@@ -1252,6 +1260,9 @@ static void __io_req_aux_free(struct io_kiocb *req)
 {
        struct io_ring_ctx *ctx = req->ctx;
 
+       if (req->flags & REQ_F_NEED_CLEANUP)
+               io_cleanup_req(req);
+
        kfree(req->io);
        if (req->file) {
                if (req->flags & REQ_F_FIXED_FILE)
@@ -1267,9 +1278,6 @@ static void __io_free_req(struct io_kiocb *req)
 {
        __io_req_aux_free(req);
 
-       if (req->flags & REQ_F_NEED_CLEANUP)
-               io_cleanup_req(req);
-
        if (req->flags & REQ_F_INFLIGHT) {
                struct io_ring_ctx *ctx = req->ctx;
                unsigned long flags;
@@ -1475,10 +1483,10 @@ static void io_free_req(struct io_kiocb *req)
 __attribute__((nonnull))
 static void io_put_req_find_next(struct io_kiocb *req, struct io_kiocb **nxtptr)
 {
-       io_req_find_next(req, nxtptr);
-
-       if (refcount_dec_and_test(&req->refs))
+       if (refcount_dec_and_test(&req->refs)) {
+               io_req_find_next(req, nxtptr);
                __io_free_req(req);
+       }
 }
 
 static void io_put_req(struct io_kiocb *req)
@@ -1664,11 +1672,17 @@ static void io_iopoll_reap_events(struct io_ring_ctx *ctx)
        mutex_unlock(&ctx->uring_lock);
 }
 
-static int __io_iopoll_check(struct io_ring_ctx *ctx, unsigned *nr_events,
-                           long min)
+static int io_iopoll_check(struct io_ring_ctx *ctx, unsigned *nr_events,
+                          long min)
 {
        int iters = 0, ret = 0;
 
+       /*
+        * We disallow the app entering submit/complete with polling, but we
+        * still need to lock the ring to prevent racing with polled issue
+        * that got punted to a workqueue.
+        */
+       mutex_lock(&ctx->uring_lock);
        do {
                int tmin = 0;
 
@@ -1704,21 +1718,6 @@ static int __io_iopoll_check(struct io_ring_ctx *ctx, unsigned *nr_events,
                ret = 0;
        } while (min && !*nr_events && !need_resched());
 
-       return ret;
-}
-
-static int io_iopoll_check(struct io_ring_ctx *ctx, unsigned *nr_events,
-                          long min)
-{
-       int ret;
-
-       /*
-        * We disallow the app entering submit/complete with polling, but we
-        * still need to lock the ring to prevent racing with polled issue
-        * that got punted to a workqueue.
-        */
-       mutex_lock(&ctx->uring_lock);
-       ret = __io_iopoll_check(ctx, nr_events, min);
        mutex_unlock(&ctx->uring_lock);
        return ret;
 }
@@ -1822,6 +1821,10 @@ static void io_iopoll_req_issued(struct io_kiocb *req)
                list_add(&req->list, &ctx->poll_list);
        else
                list_add_tail(&req->list, &ctx->poll_list);
+
+       if ((ctx->flags & IORING_SETUP_SQPOLL) &&
+           wq_has_sleeper(&ctx->sqo_wait))
+               wake_up(&ctx->sqo_wait);
 }
 
 static void io_file_put(struct io_submit_state *state)
@@ -2509,6 +2512,9 @@ static void io_fallocate_finish(struct io_wq_work **workptr)
        struct io_kiocb *nxt = NULL;
        int ret;
 
+       if (io_req_cancelled(req))
+               return;
+
        ret = vfs_fallocate(req->file, req->sync.mode, req->sync.off,
                                req->sync.len);
        if (ret < 0)
@@ -2560,6 +2566,8 @@ static int io_openat_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
                return -EINVAL;
        if (sqe->flags & IOSQE_FIXED_FILE)
                return -EBADF;
+       if (req->flags & REQ_F_NEED_CLEANUP)
+               return 0;
 
        req->open.dfd = READ_ONCE(sqe->fd);
        req->open.how.mode = READ_ONCE(sqe->len);
@@ -2588,6 +2596,8 @@ static int io_openat2_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
                return -EINVAL;
        if (sqe->flags & IOSQE_FIXED_FILE)
                return -EBADF;
+       if (req->flags & REQ_F_NEED_CLEANUP)
+               return 0;
 
        req->open.dfd = READ_ONCE(sqe->fd);
        fname = u64_to_user_ptr(READ_ONCE(sqe->addr));
@@ -2787,6 +2797,8 @@ static int io_statx_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
                return -EINVAL;
        if (sqe->flags & IOSQE_FIXED_FILE)
                return -EBADF;
+       if (req->flags & REQ_F_NEED_CLEANUP)
+               return 0;
 
        req->open.dfd = READ_ONCE(sqe->fd);
        req->open.mask = READ_ONCE(sqe->len);
@@ -2890,6 +2902,7 @@ static void io_close_finish(struct io_wq_work **workptr)
        struct io_kiocb *req = container_of(*workptr, struct io_kiocb, work);
        struct io_kiocb *nxt = NULL;
 
+       /* not cancellable, don't do io_req_cancelled() */
        __io_close_finish(req, &nxt);
        if (nxt)
                io_wq_assign_next(workptr, nxt);
@@ -3024,12 +3037,11 @@ static int io_sendmsg(struct io_kiocb *req, struct io_kiocb **nxt,
        sock = sock_from_file(req->file, &ret);
        if (sock) {
                struct io_async_ctx io;
-               struct sockaddr_storage addr;
                unsigned flags;
 
                if (req->io) {
                        kmsg = &req->io->msg;
-                       kmsg->msg.msg_name = &addr;
+                       kmsg->msg.msg_name = &req->io->msg.addr;
                        /* if iov is set, it's allocated already */
                        if (!kmsg->iov)
                                kmsg->iov = kmsg->fast_iov;
@@ -3038,7 +3050,7 @@ static int io_sendmsg(struct io_kiocb *req, struct io_kiocb **nxt,
                        struct io_sr_msg *sr = &req->sr_msg;
 
                        kmsg = &io.msg;
-                       kmsg->msg.msg_name = &addr;
+                       kmsg->msg.msg_name = &io.msg.addr;
 
                        io.msg.iov = io.msg.fast_iov;
                        ret = sendmsg_copy_msghdr(&io.msg.msg, sr->msg,
@@ -3058,7 +3070,7 @@ static int io_sendmsg(struct io_kiocb *req, struct io_kiocb **nxt,
                        if (req->io)
                                return -EAGAIN;
                        if (io_alloc_async_ctx(req)) {
-                               if (kmsg && kmsg->iov != kmsg->fast_iov)
+                               if (kmsg->iov != kmsg->fast_iov)
                                        kfree(kmsg->iov);
                                return -ENOMEM;
                        }
@@ -3177,12 +3189,11 @@ static int io_recvmsg(struct io_kiocb *req, struct io_kiocb **nxt,
        sock = sock_from_file(req->file, &ret);
        if (sock) {
                struct io_async_ctx io;
-               struct sockaddr_storage addr;
                unsigned flags;
 
                if (req->io) {
                        kmsg = &req->io->msg;
-                       kmsg->msg.msg_name = &addr;
+                       kmsg->msg.msg_name = &req->io->msg.addr;
                        /* if iov is set, it's allocated already */
                        if (!kmsg->iov)
                                kmsg->iov = kmsg->fast_iov;
@@ -3191,7 +3202,7 @@ static int io_recvmsg(struct io_kiocb *req, struct io_kiocb **nxt,
                        struct io_sr_msg *sr = &req->sr_msg;
 
                        kmsg = &io.msg;
-                       kmsg->msg.msg_name = &addr;
+                       kmsg->msg.msg_name = &io.msg.addr;
 
                        io.msg.iov = io.msg.fast_iov;
                        ret = recvmsg_copy_msghdr(&io.msg.msg, sr->msg,
@@ -3213,7 +3224,7 @@ static int io_recvmsg(struct io_kiocb *req, struct io_kiocb **nxt,
                        if (req->io)
                                return -EAGAIN;
                        if (io_alloc_async_ctx(req)) {
-                               if (kmsg && kmsg->iov != kmsg->fast_iov)
+                               if (kmsg->iov != kmsg->fast_iov)
                                        kfree(kmsg->iov);
                                return -ENOMEM;
                        }
@@ -4698,11 +4709,21 @@ static void __io_queue_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 {
        struct io_kiocb *linked_timeout;
        struct io_kiocb *nxt = NULL;
+       const struct cred *old_creds = NULL;
        int ret;
 
 again:
        linked_timeout = io_prep_linked_timeout(req);
 
+       if (req->work.creds && req->work.creds != current_cred()) {
+               if (old_creds)
+                       revert_creds(old_creds);
+               if (old_creds == req->work.creds)
+                       old_creds = NULL; /* restored original creds */
+               else
+                       old_creds = override_creds(req->work.creds);
+       }
+
        ret = io_issue_sqe(req, sqe, &nxt, true);
 
        /*
@@ -4728,7 +4749,7 @@ punt:
 
 err:
        /* drop submission reference */
-       io_put_req(req);
+       io_put_req_find_next(req, &nxt);
 
        if (linked_timeout) {
                if (!ret)
@@ -4752,6 +4773,8 @@ done_req:
                        goto punt;
                goto again;
        }
+       if (old_creds)
+               revert_creds(old_creds);
 }
 
 static void io_queue_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe)
@@ -4796,7 +4819,6 @@ static inline void io_queue_link_head(struct io_kiocb *req)
 static bool io_submit_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe,
                          struct io_submit_state *state, struct io_kiocb **link)
 {
-       const struct cred *old_creds = NULL;
        struct io_ring_ctx *ctx = req->ctx;
        unsigned int sqe_flags;
        int ret, id;
@@ -4811,14 +4833,12 @@ static bool io_submit_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe,
 
        id = READ_ONCE(sqe->personality);
        if (id) {
-               const struct cred *personality_creds;
-
-               personality_creds = idr_find(&ctx->personality_idr, id);
-               if (unlikely(!personality_creds)) {
+               req->work.creds = idr_find(&ctx->personality_idr, id);
+               if (unlikely(!req->work.creds)) {
                        ret = -EINVAL;
                        goto err_req;
                }
-               old_creds = override_creds(personality_creds);
+               get_cred(req->work.creds);
        }
 
        /* same numerical values with corresponding REQ_F_*, safe to copy */
@@ -4830,8 +4850,6 @@ static bool io_submit_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe,
 err_req:
                io_cqring_add_event(req, ret);
                io_double_put_req(req);
-               if (old_creds)
-                       revert_creds(old_creds);
                return false;
        }
 
@@ -4892,8 +4910,6 @@ err_req:
                }
        }
 
-       if (old_creds)
-               revert_creds(old_creds);
        return true;
 }
 
@@ -5074,9 +5090,8 @@ static int io_sq_thread(void *data)
        const struct cred *old_cred;
        mm_segment_t old_fs;
        DEFINE_WAIT(wait);
-       unsigned inflight;
        unsigned long timeout;
-       int ret;
+       int ret = 0;
 
        complete(&ctx->completions[1]);
 
@@ -5084,39 +5099,19 @@ static int io_sq_thread(void *data)
        set_fs(USER_DS);
        old_cred = override_creds(ctx->creds);
 
-       ret = timeout = inflight = 0;
+       timeout = jiffies + ctx->sq_thread_idle;
        while (!kthread_should_park()) {
                unsigned int to_submit;
 
-               if (inflight) {
+               if (!list_empty(&ctx->poll_list)) {
                        unsigned nr_events = 0;
 
-                       if (ctx->flags & IORING_SETUP_IOPOLL) {
-                               /*
-                                * inflight is the count of the maximum possible
-                                * entries we submitted, but it can be smaller
-                                * if we dropped some of them. If we don't have
-                                * poll entries available, then we know that we
-                                * have nothing left to poll for. Reset the
-                                * inflight count to zero in that case.
-                                */
-                               mutex_lock(&ctx->uring_lock);
-                               if (!list_empty(&ctx->poll_list))
-                                       __io_iopoll_check(ctx, &nr_events, 0);
-                               else
-                                       inflight = 0;
-                               mutex_unlock(&ctx->uring_lock);
-                       } else {
-                               /*
-                                * Normal IO, just pretend everything completed.
-                                * We don't have to poll completions for that.
-                                */
-                               nr_events = inflight;
-                       }
-
-                       inflight -= nr_events;
-                       if (!inflight)
+                       mutex_lock(&ctx->uring_lock);
+                       if (!list_empty(&ctx->poll_list))
+                               io_iopoll_getevents(ctx, &nr_events, 0);
+                       else
                                timeout = jiffies + ctx->sq_thread_idle;
+                       mutex_unlock(&ctx->uring_lock);
                }
 
                to_submit = io_sqring_entries(ctx);
@@ -5127,34 +5122,47 @@ static int io_sq_thread(void *data)
                 */
                if (!to_submit || ret == -EBUSY) {
                        /*
+                        * Drop cur_mm before scheduling, we can't hold it for
+                        * long periods (or over schedule()). Do this before
+                        * adding ourselves to the waitqueue, as the unuse/drop
+                        * may sleep.
+                        */
+                       if (cur_mm) {
+                               unuse_mm(cur_mm);
+                               mmput(cur_mm);
+                               cur_mm = NULL;
+                       }
+
+                       /*
                         * We're polling. If we're within the defined idle
                         * period, then let us spin without work before going
                         * to sleep. The exception is if we got EBUSY doing
                         * more IO, we should wait for the application to
                         * reap events and wake us up.
                         */
-                       if (inflight ||
+                       if (!list_empty(&ctx->poll_list) ||
                            (!time_after(jiffies, timeout) && ret != -EBUSY &&
                            !percpu_ref_is_dying(&ctx->refs))) {
                                cond_resched();
                                continue;
                        }
 
+                       prepare_to_wait(&ctx->sqo_wait, &wait,
+                                               TASK_INTERRUPTIBLE);
+
                        /*
-                        * Drop cur_mm before scheduling, we can't hold it for
-                        * long periods (or over schedule()). Do this before
-                        * adding ourselves to the waitqueue, as the unuse/drop
-                        * may sleep.
+                        * While doing polled IO, before going to sleep, we need
+                        * to check if there are new reqs added to poll_list, it
+                        * is because reqs may have been punted to io worker and
+                        * will be added to poll_list later, hence check the
+                        * poll_list again.
                         */
-                       if (cur_mm) {
-                               unuse_mm(cur_mm);
-                               mmput(cur_mm);
-                               cur_mm = NULL;
+                       if ((ctx->flags & IORING_SETUP_IOPOLL) &&
+                           !list_empty_careful(&ctx->poll_list)) {
+                               finish_wait(&ctx->sqo_wait, &wait);
+                               continue;
                        }
 
-                       prepare_to_wait(&ctx->sqo_wait, &wait,
-                                               TASK_INTERRUPTIBLE);
-
                        /* Tell userspace we may need a wakeup call */
                        ctx->rings->sq_flags |= IORING_SQ_NEED_WAKEUP;
                        /* make sure to read SQ tail after writing flags */
@@ -5182,8 +5190,7 @@ static int io_sq_thread(void *data)
                mutex_lock(&ctx->uring_lock);
                ret = io_submit_sqes(ctx, to_submit, NULL, -1, &cur_mm, true);
                mutex_unlock(&ctx->uring_lock);
-               if (ret > 0)
-                       inflight += ret;
+               timeout = jiffies + ctx->sq_thread_idle;
        }
 
        set_fs(old_fs);
@@ -6327,6 +6334,7 @@ static void io_ring_ctx_free(struct io_ring_ctx *ctx)
        io_sqe_buffer_unregister(ctx);
        io_sqe_files_unregister(ctx);
        io_eventfd_unregister(ctx);
+       idr_destroy(&ctx->personality_idr);
 
 #if defined(CONFIG_UNIX)
        if (ctx->ring_sock) {
@@ -6456,6 +6464,29 @@ static void io_uring_cancel_files(struct io_ring_ctx *ctx,
                if (!cancel_req)
                        break;
 
+               if (cancel_req->flags & REQ_F_OVERFLOW) {
+                       spin_lock_irq(&ctx->completion_lock);
+                       list_del(&cancel_req->list);
+                       cancel_req->flags &= ~REQ_F_OVERFLOW;
+                       if (list_empty(&ctx->cq_overflow_list)) {
+                               clear_bit(0, &ctx->sq_check_overflow);
+                               clear_bit(0, &ctx->cq_check_overflow);
+                       }
+                       spin_unlock_irq(&ctx->completion_lock);
+
+                       WRITE_ONCE(ctx->rings->cq_overflow,
+                               atomic_inc_return(&ctx->cached_cq_overflow));
+
+                       /*
+                        * Put inflight ref and overflow ref. If that's
+                        * all we had, then we're done with this request.
+                        */
+                       if (refcount_sub_and_test(2, &cancel_req->refs)) {
+                               io_put_req(cancel_req);
+                               continue;
+                       }
+               }
+
                io_wq_cancel_work(ctx->io_wq, &cancel_req->work);
                io_put_req(cancel_req);
                schedule();
@@ -6468,6 +6499,13 @@ static int io_uring_flush(struct file *file, void *data)
        struct io_ring_ctx *ctx = file->private_data;
 
        io_uring_cancel_files(ctx, data);
+
+       /*
+        * If the task is going away, cancel work it may have pending
+        */
+       if (fatal_signal_pending(current) || (current->flags & PF_EXITING))
+               io_wq_cancel_pid(ctx->io_wq, task_pid_vnr(current));
+
        return 0;
 }