diff --git a/drivers/infiniband/core/uverbs_cmd.c b/drivers/infiniband/core/uverbs_cmd.c index b81307b625a6..8b6df7cec0bf 100644 --- a/drivers/infiniband/core/uverbs_cmd.c +++ b/drivers/infiniband/core/uverbs_cmd.c @@ -155,7 +155,7 @@ static struct ib_uobject *__idr_get_uobj(struct idr *idr, int id, } static struct ib_uobject *idr_read_uobj(struct idr *idr, int id, - struct ib_ucontext *context) + struct ib_ucontext *context, int nested) { struct ib_uobject *uobj; @@ -163,7 +163,10 @@ static struct ib_uobject *idr_read_uobj(struct idr *idr, int id, if (!uobj) return NULL; - down_read(&uobj->mutex); + if (nested) + down_read_nested(&uobj->mutex, SINGLE_DEPTH_NESTING); + else + down_read(&uobj->mutex); if (!uobj->live) { put_uobj_read(uobj); return NULL; @@ -190,17 +193,18 @@ static struct ib_uobject *idr_write_uobj(struct idr *idr, int id, return uobj; } -static void *idr_read_obj(struct idr *idr, int id, struct ib_ucontext *context) +static void *idr_read_obj(struct idr *idr, int id, struct ib_ucontext *context, + int nested) { struct ib_uobject *uobj; - uobj = idr_read_uobj(idr, id, context); + uobj = idr_read_uobj(idr, id, context, nested); return uobj ? uobj->object : NULL; } static struct ib_pd *idr_read_pd(int pd_handle, struct ib_ucontext *context) { - return idr_read_obj(&ib_uverbs_pd_idr, pd_handle, context); + return idr_read_obj(&ib_uverbs_pd_idr, pd_handle, context, 0); } static void put_pd_read(struct ib_pd *pd) @@ -208,9 +212,9 @@ static void put_pd_read(struct ib_pd *pd) put_uobj_read(pd->uobject); } -static struct ib_cq *idr_read_cq(int cq_handle, struct ib_ucontext *context) +static struct ib_cq *idr_read_cq(int cq_handle, struct ib_ucontext *context, int nested) { - return idr_read_obj(&ib_uverbs_cq_idr, cq_handle, context); + return idr_read_obj(&ib_uverbs_cq_idr, cq_handle, context, nested); } static void put_cq_read(struct ib_cq *cq) @@ -220,7 +224,7 @@ static void put_cq_read(struct ib_cq *cq) static struct ib_ah *idr_read_ah(int ah_handle, struct ib_ucontext *context) { - return idr_read_obj(&ib_uverbs_ah_idr, ah_handle, context); + return idr_read_obj(&ib_uverbs_ah_idr, ah_handle, context, 0); } static void put_ah_read(struct ib_ah *ah) @@ -230,7 +234,7 @@ static void put_ah_read(struct ib_ah *ah) static struct ib_qp *idr_read_qp(int qp_handle, struct ib_ucontext *context) { - return idr_read_obj(&ib_uverbs_qp_idr, qp_handle, context); + return idr_read_obj(&ib_uverbs_qp_idr, qp_handle, context, 0); } static void put_qp_read(struct ib_qp *qp) @@ -240,7 +244,7 @@ static void put_qp_read(struct ib_qp *qp) static struct ib_srq *idr_read_srq(int srq_handle, struct ib_ucontext *context) { - return idr_read_obj(&ib_uverbs_srq_idr, srq_handle, context); + return idr_read_obj(&ib_uverbs_srq_idr, srq_handle, context, 0); } static void put_srq_read(struct ib_srq *srq) @@ -867,7 +871,7 @@ ssize_t ib_uverbs_resize_cq(struct ib_uverbs_file *file, (unsigned long) cmd.response + sizeof resp, in_len - sizeof cmd, out_len - sizeof resp); - cq = idr_read_cq(cmd.cq_handle, file->ucontext); + cq = idr_read_cq(cmd.cq_handle, file->ucontext, 0); if (!cq) return -EINVAL; @@ -914,7 +918,7 @@ ssize_t ib_uverbs_poll_cq(struct ib_uverbs_file *file, goto out_wc; } - cq = idr_read_cq(cmd.cq_handle, file->ucontext); + cq = idr_read_cq(cmd.cq_handle, file->ucontext, 0); if (!cq) { ret = -EINVAL; goto out; @@ -962,7 +966,7 @@ ssize_t ib_uverbs_req_notify_cq(struct ib_uverbs_file *file, if (copy_from_user(&cmd, buf, sizeof cmd)) return -EFAULT; - cq = idr_read_cq(cmd.cq_handle, file->ucontext); + cq = idr_read_cq(cmd.cq_handle, file->ucontext, 0); if (!cq) return -EINVAL; @@ -1060,9 +1064,9 @@ ssize_t ib_uverbs_create_qp(struct ib_uverbs_file *file, srq = cmd.is_srq ? idr_read_srq(cmd.srq_handle, file->ucontext) : NULL; pd = idr_read_pd(cmd.pd_handle, file->ucontext); - scq = idr_read_cq(cmd.send_cq_handle, file->ucontext); + scq = idr_read_cq(cmd.send_cq_handle, file->ucontext, 0); rcq = cmd.recv_cq_handle == cmd.send_cq_handle ? - scq : idr_read_cq(cmd.recv_cq_handle, file->ucontext); + scq : idr_read_cq(cmd.recv_cq_handle, file->ucontext, 1); if (!pd || !scq || !rcq || (cmd.is_srq && !srq)) { ret = -EINVAL;