diff --git a/drivers/misc/cxl/cxllib.c b/drivers/misc/cxl/cxllib.c index bea1eb004b49..0bc7c31cf739 100644 --- a/drivers/misc/cxl/cxllib.c +++ b/drivers/misc/cxl/cxllib.c @@ -208,49 +208,74 @@ int cxllib_get_PE_attributes(struct task_struct *task, } EXPORT_SYMBOL_GPL(cxllib_get_PE_attributes); -int cxllib_handle_fault(struct mm_struct *mm, u64 addr, u64 size, u64 flags) +static int get_vma_info(struct mm_struct *mm, u64 addr, + u64 *vma_start, u64 *vma_end, + unsigned long *page_size) { - int rc; - u64 dar; struct vm_area_struct *vma = NULL; - unsigned long page_size; - - if (mm == NULL) - return -EFAULT; + int rc = 0; down_read(&mm->mmap_sem); vma = find_vma(mm, addr); if (!vma) { - pr_err("Can't find vma for addr %016llx\n", addr); rc = -EFAULT; goto out; } - /* get the size of the pages allocated */ - page_size = vma_kernel_pagesize(vma); - - for (dar = (addr & ~(page_size - 1)); dar < (addr + size); dar += page_size) { - if (dar < vma->vm_start || dar >= vma->vm_end) { - vma = find_vma(mm, addr); - if (!vma) { - pr_err("Can't find vma for addr %016llx\n", addr); - rc = -EFAULT; - goto out; - } - /* get the size of the pages allocated */ - page_size = vma_kernel_pagesize(vma); - } - - rc = cxl_handle_mm_fault(mm, flags, dar); - if (rc) { - pr_err("cxl_handle_mm_fault failed %d", rc); - rc = -EFAULT; - goto out; - } - } - rc = 0; + *page_size = vma_kernel_pagesize(vma); + *vma_start = vma->vm_start; + *vma_end = vma->vm_end; out: up_read(&mm->mmap_sem); return rc; } + +int cxllib_handle_fault(struct mm_struct *mm, u64 addr, u64 size, u64 flags) +{ + int rc; + u64 dar, vma_start, vma_end; + unsigned long page_size; + + if (mm == NULL) + return -EFAULT; + + /* + * The buffer we have to process can extend over several pages + * and may also cover several VMAs. + * We iterate over all the pages. The page size could vary + * between VMAs. + */ + rc = get_vma_info(mm, addr, &vma_start, &vma_end, &page_size); + if (rc) + return rc; + + for (dar = (addr & ~(page_size - 1)); dar < (addr + size); + dar += page_size) { + if (dar < vma_start || dar >= vma_end) { + /* + * We don't hold the mm->mmap_sem semaphore + * while iterating, since the semaphore is + * required by one of the lower-level page + * fault processing functions and it could + * create a deadlock. + * + * It means the VMAs can be altered between 2 + * loop iterations and we could theoretically + * miss a page (however unlikely). But that's + * not really a problem, as the driver will + * retry access, get another page fault on the + * missing page and call us again. + */ + rc = get_vma_info(mm, dar, &vma_start, &vma_end, + &page_size); + if (rc) + return rc; + } + + rc = cxl_handle_mm_fault(mm, flags, dar); + if (rc) + return -EFAULT; + } + return 0; +} EXPORT_SYMBOL_GPL(cxllib_handle_fault);