diff --git a/drivers/ntb/ntb_transport.c b/drivers/ntb/ntb_transport.c index 4eb8adb34508..9b791cb33c19 100644 --- a/drivers/ntb/ntb_transport.c +++ b/drivers/ntb/ntb_transport.c @@ -66,6 +66,7 @@ #define NTB_TRANSPORT_VER "4" #define NTB_TRANSPORT_NAME "ntb_transport" #define NTB_TRANSPORT_DESC "Software Queue-Pair Transport over NTB" +#define NTB_TRANSPORT_MIN_SPADS (MW0_SZ_HIGH + 2) MODULE_DESCRIPTION(NTB_TRANSPORT_DESC); MODULE_VERSION(NTB_TRANSPORT_VER); @@ -242,9 +243,6 @@ enum { NUM_MWS, MW0_SZ_HIGH, MW0_SZ_LOW, - MW1_SZ_HIGH, - MW1_SZ_LOW, - MAX_SPAD, }; #define dev_client_dev(__dev) \ @@ -811,7 +809,7 @@ static void ntb_transport_link_cleanup(struct ntb_transport_ctx *nt) { struct ntb_transport_qp *qp; u64 qp_bitmap_alloc; - int i; + unsigned int i, count; qp_bitmap_alloc = nt->qp_bitmap & ~nt->qp_bitmap_free; @@ -831,7 +829,8 @@ static void ntb_transport_link_cleanup(struct ntb_transport_ctx *nt) * goes down, blast them now to give them a sane value the next * time they are accessed */ - for (i = 0; i < MAX_SPAD; i++) + count = ntb_spad_count(nt->ndev); + for (i = 0; i < count; i++) ntb_spad_write(nt->ndev, i, 0); } @@ -1064,17 +1063,12 @@ static int ntb_transport_probe(struct ntb_client *self, struct ntb_dev *ndev) { struct ntb_transport_ctx *nt; struct ntb_transport_mw *mw; - unsigned int mw_count, qp_count; + unsigned int mw_count, qp_count, spad_count, max_mw_count_for_spads; u64 qp_bitmap; int node; int rc, i; mw_count = ntb_mw_count(ndev); - if (ntb_spad_count(ndev) < (NUM_MWS + 1 + mw_count * 2)) { - dev_err(&ndev->dev, "Not enough scratch pad registers for %s", - NTB_TRANSPORT_NAME); - return -EIO; - } if (ntb_db_is_unsafe(ndev)) dev_dbg(&ndev->dev, @@ -1090,8 +1084,18 @@ static int ntb_transport_probe(struct ntb_client *self, struct ntb_dev *ndev) return -ENOMEM; nt->ndev = ndev; + spad_count = ntb_spad_count(ndev); - nt->mw_count = mw_count; + /* Limit the MW's based on the availability of scratchpads */ + + if (spad_count < NTB_TRANSPORT_MIN_SPADS) { + nt->mw_count = 0; + rc = -EINVAL; + goto err; + } + + max_mw_count_for_spads = (spad_count - MW0_SZ_HIGH) / 2; + nt->mw_count = min(mw_count, max_mw_count_for_spads); nt->mw_vec = kzalloc_node(mw_count * sizeof(*nt->mw_vec), GFP_KERNEL, node);