OSDN Git Service

io_uring: pick up link work on submit reference drop
[tomoyo/tomoyo-test1.git] / fs / io_uring.c
index de650df..f79ca49 100644 (file)
@@ -1483,10 +1483,10 @@ static void io_free_req(struct io_kiocb *req)
 __attribute__((nonnull))
 static void io_put_req_find_next(struct io_kiocb *req, struct io_kiocb **nxtptr)
 {
-       io_req_find_next(req, nxtptr);
-
-       if (refcount_dec_and_test(&req->refs))
+       if (refcount_dec_and_test(&req->refs)) {
+               io_req_find_next(req, nxtptr);
                __io_free_req(req);
+       }
 }
 
 static void io_put_req(struct io_kiocb *req)
@@ -1821,6 +1821,10 @@ static void io_iopoll_req_issued(struct io_kiocb *req)
                list_add(&req->list, &ctx->poll_list);
        else
                list_add_tail(&req->list, &ctx->poll_list);
+
+       if ((ctx->flags & IORING_SETUP_SQPOLL) &&
+           wq_has_sleeper(&ctx->sqo_wait))
+               wake_up(&ctx->sqo_wait);
 }
 
 static void io_file_put(struct io_submit_state *state)
@@ -4705,11 +4709,21 @@ static void __io_queue_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 {
        struct io_kiocb *linked_timeout;
        struct io_kiocb *nxt = NULL;
+       const struct cred *old_creds = NULL;
        int ret;
 
 again:
        linked_timeout = io_prep_linked_timeout(req);
 
+       if (req->work.creds && req->work.creds != current_cred()) {
+               if (old_creds)
+                       revert_creds(old_creds);
+               if (old_creds == req->work.creds)
+                       old_creds = NULL; /* restored original creds */
+               else
+                       old_creds = override_creds(req->work.creds);
+       }
+
        ret = io_issue_sqe(req, sqe, &nxt, true);
 
        /*
@@ -4735,7 +4749,7 @@ punt:
 
 err:
        /* drop submission reference */
-       io_put_req(req);
+       io_put_req_find_next(req, &nxt);
 
        if (linked_timeout) {
                if (!ret)
@@ -4759,6 +4773,8 @@ done_req:
                        goto punt;
                goto again;
        }
+       if (old_creds)
+               revert_creds(old_creds);
 }
 
 static void io_queue_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe)
@@ -4803,7 +4819,6 @@ static inline void io_queue_link_head(struct io_kiocb *req)
 static bool io_submit_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe,
                          struct io_submit_state *state, struct io_kiocb **link)
 {
-       const struct cred *old_creds = NULL;
        struct io_ring_ctx *ctx = req->ctx;
        unsigned int sqe_flags;
        int ret, id;
@@ -4818,14 +4833,12 @@ static bool io_submit_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe,
 
        id = READ_ONCE(sqe->personality);
        if (id) {
-               const struct cred *personality_creds;
-
-               personality_creds = idr_find(&ctx->personality_idr, id);
-               if (unlikely(!personality_creds)) {
+               req->work.creds = idr_find(&ctx->personality_idr, id);
+               if (unlikely(!req->work.creds)) {
                        ret = -EINVAL;
                        goto err_req;
                }
-               old_creds = override_creds(personality_creds);
+               get_cred(req->work.creds);
        }
 
        /* same numerical values with corresponding REQ_F_*, safe to copy */
@@ -4837,8 +4850,6 @@ static bool io_submit_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe,
 err_req:
                io_cqring_add_event(req, ret);
                io_double_put_req(req);
-               if (old_creds)
-                       revert_creds(old_creds);
                return false;
        }
 
@@ -4899,8 +4910,6 @@ err_req:
                }
        }
 
-       if (old_creds)
-               revert_creds(old_creds);
        return true;
 }
 
@@ -5081,9 +5090,8 @@ static int io_sq_thread(void *data)
        const struct cred *old_cred;
        mm_segment_t old_fs;
        DEFINE_WAIT(wait);
-       unsigned inflight;
        unsigned long timeout;
-       int ret;
+       int ret = 0;
 
        complete(&ctx->completions[1]);
 
@@ -5091,39 +5099,19 @@ static int io_sq_thread(void *data)
        set_fs(USER_DS);
        old_cred = override_creds(ctx->creds);
 
-       ret = timeout = inflight = 0;
+       timeout = jiffies + ctx->sq_thread_idle;
        while (!kthread_should_park()) {
                unsigned int to_submit;
 
-               if (inflight) {
+               if (!list_empty(&ctx->poll_list)) {
                        unsigned nr_events = 0;
 
-                       if (ctx->flags & IORING_SETUP_IOPOLL) {
-                               /*
-                                * inflight is the count of the maximum possible
-                                * entries we submitted, but it can be smaller
-                                * if we dropped some of them. If we don't have
-                                * poll entries available, then we know that we
-                                * have nothing left to poll for. Reset the
-                                * inflight count to zero in that case.
-                                */
-                               mutex_lock(&ctx->uring_lock);
-                               if (!list_empty(&ctx->poll_list))
-                                       io_iopoll_getevents(ctx, &nr_events, 0);
-                               else
-                                       inflight = 0;
-                               mutex_unlock(&ctx->uring_lock);
-                       } else {
-                               /*
-                                * Normal IO, just pretend everything completed.
-                                * We don't have to poll completions for that.
-                                */
-                               nr_events = inflight;
-                       }
-
-                       inflight -= nr_events;
-                       if (!inflight)
+                       mutex_lock(&ctx->uring_lock);
+                       if (!list_empty(&ctx->poll_list))
+                               io_iopoll_getevents(ctx, &nr_events, 0);
+                       else
                                timeout = jiffies + ctx->sq_thread_idle;
+                       mutex_unlock(&ctx->uring_lock);
                }
 
                to_submit = io_sqring_entries(ctx);
@@ -5152,7 +5140,7 @@ static int io_sq_thread(void *data)
                         * more IO, we should wait for the application to
                         * reap events and wake us up.
                         */
-                       if (inflight ||
+                       if (!list_empty(&ctx->poll_list) ||
                            (!time_after(jiffies, timeout) && ret != -EBUSY &&
                            !percpu_ref_is_dying(&ctx->refs))) {
                                cond_resched();
@@ -5162,6 +5150,19 @@ static int io_sq_thread(void *data)
                        prepare_to_wait(&ctx->sqo_wait, &wait,
                                                TASK_INTERRUPTIBLE);
 
+                       /*
+                        * While doing polled IO, before going to sleep, we need
+                        * to check if there are new reqs added to poll_list, it
+                        * is because reqs may have been punted to io worker and
+                        * will be added to poll_list later, hence check the
+                        * poll_list again.
+                        */
+                       if ((ctx->flags & IORING_SETUP_IOPOLL) &&
+                           !list_empty_careful(&ctx->poll_list)) {
+                               finish_wait(&ctx->sqo_wait, &wait);
+                               continue;
+                       }
+
                        /* Tell userspace we may need a wakeup call */
                        ctx->rings->sq_flags |= IORING_SQ_NEED_WAKEUP;
                        /* make sure to read SQ tail after writing flags */
@@ -5189,8 +5190,7 @@ static int io_sq_thread(void *data)
                mutex_lock(&ctx->uring_lock);
                ret = io_submit_sqes(ctx, to_submit, NULL, -1, &cur_mm, true);
                mutex_unlock(&ctx->uring_lock);
-               if (ret > 0)
-                       inflight += ret;
+               timeout = jiffies + ctx->sq_thread_idle;
        }
 
        set_fs(old_fs);
@@ -6334,6 +6334,7 @@ static void io_ring_ctx_free(struct io_ring_ctx *ctx)
        io_sqe_buffer_unregister(ctx);
        io_sqe_files_unregister(ctx);
        io_eventfd_unregister(ctx);
+       idr_destroy(&ctx->personality_idr);
 
 #if defined(CONFIG_UNIX)
        if (ctx->ring_sock) {