diff --git a/fs/io_uring.c b/fs/io_uring.c index ae91632b8bf9..7020c6a72231 100644 --- a/fs/io_uring.c +++ b/fs/io_uring.c @@ -1066,12 +1066,18 @@ static void io_init_identity(struct io_identity *id) */ static inline void io_req_init_async(struct io_kiocb *req) { + struct io_uring_task *tctx = current->io_uring; + if (req->flags & REQ_F_WORK_INITIALIZED) return; memset(&req->work, 0, sizeof(req->work)); req->flags |= REQ_F_WORK_INITIALIZED; - req->work.identity = ¤t->io_uring->identity; + + /* Grab a ref if this isn't our static identity */ + req->work.identity = tctx->identity; + if (tctx->identity != &tctx->__identity) + refcount_inc(&req->work.identity->count); } static inline bool io_async_submit(struct io_ring_ctx *ctx) @@ -1179,7 +1185,7 @@ static void __io_commit_cqring(struct io_ring_ctx *ctx) static void io_put_identity(struct io_uring_task *tctx, struct io_kiocb *req) { - if (req->work.identity == &tctx->identity) + if (req->work.identity == &tctx->__identity) return; if (refcount_dec_and_test(&req->work.identity->count)) kfree(req->work.identity); @@ -1254,11 +1260,12 @@ static bool io_identity_cow(struct io_kiocb *req) refcount_inc(&id->count); /* drop old identity, assign new one. one ref for req, one for tctx */ - if (req->work.identity != &tctx->identity && + if (req->work.identity != tctx->identity && refcount_sub_and_test(2, &req->work.identity->count)) kfree(req->work.identity); req->work.identity = id; + tctx->identity = id; return true; } @@ -7691,7 +7698,8 @@ static int io_uring_alloc_task_context(struct task_struct *task) tctx->in_idle = 0; atomic_long_set(&tctx->req_issue, 0); atomic_long_set(&tctx->req_complete, 0); - io_init_identity(&tctx->identity); + io_init_identity(&tctx->__identity); + tctx->identity = &tctx->__identity; task->io_uring = tctx; return 0; } @@ -7701,6 +7709,9 @@ void __io_uring_free(struct task_struct *tsk) struct io_uring_task *tctx = tsk->io_uring; WARN_ON_ONCE(!xa_empty(&tctx->xa)); + WARN_ON_ONCE(refcount_read(&tctx->identity->count) != 1); + if (tctx->identity != &tctx->__identity) + kfree(tctx->identity); kfree(tctx); tsk->io_uring = NULL; } diff --git a/include/linux/io_uring.h b/include/linux/io_uring.h index bd3346194bca..607d14f61132 100644 --- a/include/linux/io_uring.h +++ b/include/linux/io_uring.h @@ -24,7 +24,8 @@ struct io_uring_task { struct wait_queue_head wait; struct file *last; atomic_long_t req_issue; - struct io_identity identity; + struct io_identity __identity; + struct io_identity *identity; /* completion side */ bool in_idle ____cacheline_aligned_in_smp;