diff --git a/drivers/dax/dax.c b/drivers/dax/dax.c index 922ec461dcaa..b90bb301bda0 100644 --- a/drivers/dax/dax.c +++ b/drivers/dax/dax.c @@ -493,6 +493,51 @@ static int __dax_dev_pmd_fault(struct dax_dev *dax_dev, struct vm_fault *vmf) vmf->flags & FAULT_FLAG_WRITE); } +#ifdef CONFIG_HAVE_ARCH_TRANSPARENT_HUGEPAGE_PUD +static int __dax_dev_pud_fault(struct dax_dev *dax_dev, struct vm_fault *vmf) +{ + unsigned long pud_addr = vmf->address & PUD_MASK; + struct device *dev = &dax_dev->dev; + struct dax_region *dax_region; + phys_addr_t phys; + pgoff_t pgoff; + pfn_t pfn; + + if (check_vma(dax_dev, vmf->vma, __func__)) + return VM_FAULT_SIGBUS; + + dax_region = dax_dev->region; + if (dax_region->align > PUD_SIZE) { + dev_dbg(dev, "%s: alignment > fault size\n", __func__); + return VM_FAULT_SIGBUS; + } + + /* dax pud mappings require pfn_t_devmap() */ + if ((dax_region->pfn_flags & (PFN_DEV|PFN_MAP)) != (PFN_DEV|PFN_MAP)) { + dev_dbg(dev, "%s: alignment > fault size\n", __func__); + return VM_FAULT_SIGBUS; + } + + pgoff = linear_page_index(vmf->vma, pud_addr); + phys = pgoff_to_phys(dax_dev, pgoff, PUD_SIZE); + if (phys == -1) { + dev_dbg(dev, "%s: phys_to_pgoff(%#lx) failed\n", __func__, + pgoff); + return VM_FAULT_SIGBUS; + } + + pfn = phys_to_pfn_t(phys, dax_region->pfn_flags); + + return vmf_insert_pfn_pud(vmf->vma, vmf->address, vmf->pud, pfn, + vmf->flags & FAULT_FLAG_WRITE); +} +#else +static int __dax_dev_pud_fault(struct dax_dev *dax_dev, struct vm_fault *vmf) +{ + return VM_FAULT_FALLBACK; +} +#endif /* !CONFIG_HAVE_ARCH_TRANSPARENT_HUGEPAGE_PUD */ + static int dax_dev_fault(struct vm_fault *vmf) { int rc; @@ -512,6 +557,9 @@ static int dax_dev_fault(struct vm_fault *vmf) case FAULT_FLAG_SIZE_PMD: rc = __dax_dev_pmd_fault(dax_dev, vmf); break; + case FAULT_FLAG_SIZE_PUD: + rc = __dax_dev_pud_fault(dax_dev, vmf); + break; default: return VM_FAULT_FALLBACK; }