libvhost-user: handle NOFD flag in call/kick/err better

The code here is odd, for example will it print out invalid
file descriptor numbers that were never sent in the message.

Clean that up a bit so it's actually possible to implement
a device that uses polling.

Signed-off-by: Johannes Berg <johannes.berg@intel.com>
Message-Id: <20200123081708.7817-5-johannes@sipsolutions.net>
Reviewed-by: Michael S. Tsirkin <mst@redhat.com>
Signed-off-by: Michael S. Tsirkin <mst@redhat.com>
This commit is contained in:
Johannes Berg 2020-01-23 09:17:06 +01:00 committed by Michael S. Tsirkin
parent a00fdc9c9d
commit d5f99fc578
1 changed files with 16 additions and 8 deletions

View File

@ -948,6 +948,7 @@ static bool
vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg)
{
int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
if (index >= dev->max_queues) {
vmsg_close_fds(vmsg);
@ -955,8 +956,12 @@ vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg)
return false;
}
if (vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK ||
vmsg->fd_num != 1) {
if (nofd) {
vmsg_close_fds(vmsg);
return true;
}
if (vmsg->fd_num != 1) {
vmsg_close_fds(vmsg);
vu_panic(dev, "Invalid fds in request: %d", vmsg->request);
return false;
@ -1053,6 +1058,7 @@ static bool
vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg)
{
int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
@ -1066,8 +1072,8 @@ vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg)
dev->vq[index].kick_fd = -1;
}
dev->vq[index].kick_fd = vmsg->fds[0];
DPRINT("Got kick_fd: %d for vq: %d\n", vmsg->fds[0], index);
dev->vq[index].kick_fd = nofd ? -1 : vmsg->fds[0];
DPRINT("Got kick_fd: %d for vq: %d\n", dev->vq[index].kick_fd, index);
dev->vq[index].started = true;
if (dev->iface->queue_set_started) {
@ -1147,6 +1153,7 @@ static bool
vu_set_vring_call_exec(VuDev *dev, VhostUserMsg *vmsg)
{
int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
@ -1159,14 +1166,14 @@ vu_set_vring_call_exec(VuDev *dev, VhostUserMsg *vmsg)
dev->vq[index].call_fd = -1;
}
dev->vq[index].call_fd = vmsg->fds[0];
dev->vq[index].call_fd = nofd ? -1 : vmsg->fds[0];
/* in case of I/O hang after reconnecting */
if (eventfd_write(vmsg->fds[0], 1)) {
if (dev->vq[index].call_fd != -1 && eventfd_write(vmsg->fds[0], 1)) {
return -1;
}
DPRINT("Got call_fd: %d for vq: %d\n", vmsg->fds[0], index);
DPRINT("Got call_fd: %d for vq: %d\n", dev->vq[index].call_fd, index);
return false;
}
@ -1175,6 +1182,7 @@ static bool
vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg)
{
int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
@ -1187,7 +1195,7 @@ vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg)
dev->vq[index].err_fd = -1;
}
dev->vq[index].err_fd = vmsg->fds[0];
dev->vq[index].err_fd = nofd ? -1 : vmsg->fds[0];
return false;
}