diff --git a/drivers/vfio/pci/vfio_pci_intrs.c b/drivers/vfio/pci/vfio_pci_intrs.c index e8d695b3f54e..b134befbfd0e 100644 --- a/drivers/vfio/pci/vfio_pci_intrs.c +++ b/drivers/vfio/pci/vfio_pci_intrs.c @@ -763,46 +763,60 @@ static int vfio_pci_set_msi_trigger(struct vfio_pci_device *vdev, return 0; } -static int vfio_pci_set_err_trigger(struct vfio_pci_device *vdev, - unsigned index, unsigned start, - unsigned count, uint32_t flags, void *data) +static int vfio_pci_set_ctx_trigger_single(struct eventfd_ctx **ctx, + uint32_t flags, void *data) { int32_t fd = *(int32_t *)data; - if ((index != VFIO_PCI_ERR_IRQ_INDEX) || - !(flags & VFIO_IRQ_SET_DATA_TYPE_MASK)) + if (!(flags & VFIO_IRQ_SET_DATA_TYPE_MASK)) return -EINVAL; /* DATA_NONE/DATA_BOOL enables loopback testing */ if (flags & VFIO_IRQ_SET_DATA_NONE) { - if (vdev->err_trigger) - eventfd_signal(vdev->err_trigger, 1); + if (*ctx) + eventfd_signal(*ctx, 1); return 0; } else if (flags & VFIO_IRQ_SET_DATA_BOOL) { uint8_t trigger = *(uint8_t *)data; - if (trigger && vdev->err_trigger) - eventfd_signal(vdev->err_trigger, 1); + if (trigger && *ctx) + eventfd_signal(*ctx, 1); return 0; } /* Handle SET_DATA_EVENTFD */ if (fd == -1) { - if (vdev->err_trigger) - eventfd_ctx_put(vdev->err_trigger); - vdev->err_trigger = NULL; + if (*ctx) + eventfd_ctx_put(*ctx); + *ctx = NULL; return 0; } else if (fd >= 0) { struct eventfd_ctx *efdctx; efdctx = eventfd_ctx_fdget(fd); if (IS_ERR(efdctx)) return PTR_ERR(efdctx); - if (vdev->err_trigger) - eventfd_ctx_put(vdev->err_trigger); - vdev->err_trigger = efdctx; + if (*ctx) + eventfd_ctx_put(*ctx); + *ctx = efdctx; return 0; } else return -EINVAL; } + +static int vfio_pci_set_err_trigger(struct vfio_pci_device *vdev, + unsigned index, unsigned start, + unsigned count, uint32_t flags, void *data) +{ + if (index != VFIO_PCI_ERR_IRQ_INDEX) + return -EINVAL; + + /* + * We should sanitize start & count, but that wasn't caught + * originally, so this IRQ index must forever ignore them :-( + */ + + return vfio_pci_set_ctx_trigger_single(&vdev->err_trigger, flags, data); +} + int vfio_pci_set_irqs_ioctl(struct vfio_pci_device *vdev, uint32_t flags, unsigned index, unsigned start, unsigned count, void *data)