diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c index b2240361f1a1..367d8023b54d 100644 --- a/drivers/vhost/net.c +++ b/drivers/vhost/net.c @@ -94,7 +94,7 @@ struct vhost_net_ubuf_ref { struct vhost_virtqueue *vq; }; -#define VHOST_RX_BATCH 64 +#define VHOST_NET_BATCH 64 struct vhost_net_buf { void **queue; int tail; @@ -168,7 +168,7 @@ static int vhost_net_buf_produce(struct vhost_net_virtqueue *nvq) rxq->head = 0; rxq->tail = ptr_ring_consume_batched(nvq->rx_ring, rxq->queue, - VHOST_RX_BATCH); + VHOST_NET_BATCH); return rxq->tail; } @@ -428,17 +428,31 @@ static int vhost_net_enable_vq(struct vhost_net *n, return vhost_poll_start(poll, sock->file); } +static void vhost_net_signal_used(struct vhost_net_virtqueue *nvq) +{ + struct vhost_virtqueue *vq = &nvq->vq; + struct vhost_dev *dev = vq->dev; + + if (!nvq->done_idx) + return; + + vhost_add_used_and_signal_n(dev, vq, vq->heads, nvq->done_idx); + nvq->done_idx = 0; +} + static int vhost_net_tx_get_vq_desc(struct vhost_net *net, - struct vhost_virtqueue *vq, - struct iovec iov[], unsigned int iov_size, + struct vhost_net_virtqueue *nvq, unsigned int *out_num, unsigned int *in_num, bool *busyloop_intr) { + struct vhost_virtqueue *vq = &nvq->vq; unsigned long uninitialized_var(endtime); int r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov), out_num, in_num, NULL, NULL); if (r == vq->num && vq->busyloop_timeout) { + if (!vhost_sock_zcopy(vq->private_data)) + vhost_net_signal_used(nvq); preempt_disable(); endtime = busy_clock() + vq->busyloop_timeout; while (vhost_can_busy_poll(endtime)) { @@ -467,9 +481,62 @@ static bool vhost_exceeds_maxpend(struct vhost_net *net) min_t(unsigned int, VHOST_MAX_PEND, vq->num >> 2); } -/* Expects to be always run from workqueue - which acts as - * read-size critical section for our kind of RCU. */ -static void handle_tx(struct vhost_net *net) +static size_t init_iov_iter(struct vhost_virtqueue *vq, struct iov_iter *iter, + size_t hdr_size, int out) +{ + /* Skip header. TODO: support TSO. */ + size_t len = iov_length(vq->iov, out); + + iov_iter_init(iter, WRITE, vq->iov, out, len); + iov_iter_advance(iter, hdr_size); + + return iov_iter_count(iter); +} + +static bool vhost_exceeds_weight(int pkts, int total_len) +{ + return total_len >= VHOST_NET_WEIGHT || + pkts >= VHOST_NET_PKT_WEIGHT; +} + +static int get_tx_bufs(struct vhost_net *net, + struct vhost_net_virtqueue *nvq, + struct msghdr *msg, + unsigned int *out, unsigned int *in, + size_t *len, bool *busyloop_intr) +{ + struct vhost_virtqueue *vq = &nvq->vq; + int ret; + + ret = vhost_net_tx_get_vq_desc(net, nvq, out, in, busyloop_intr); + + if (ret < 0 || ret == vq->num) + return ret; + + if (*in) { + vq_err(vq, "Unexpected descriptor format for TX: out %d, int %d\n", + *out, *in); + return -EFAULT; + } + + /* Sanity check */ + *len = init_iov_iter(vq, &msg->msg_iter, nvq->vhost_hlen, *out); + if (*len == 0) { + vq_err(vq, "Unexpected header len for TX: %zd expected %zd\n", + *len, nvq->vhost_hlen); + return -EFAULT; + } + + return ret; +} + +static bool tx_can_batch(struct vhost_virtqueue *vq, size_t total_len) +{ + return total_len < VHOST_NET_WEIGHT && + !vhost_vq_avail_empty(vq->dev, vq); +} + +static void handle_tx_copy(struct vhost_net *net, struct socket *sock) { struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX]; struct vhost_virtqueue *vq = &nvq->vq; @@ -484,37 +551,86 @@ static void handle_tx(struct vhost_net *net) }; size_t len, total_len = 0; int err; - size_t hdr_size; - struct socket *sock; - struct vhost_net_ubuf_ref *uninitialized_var(ubufs); - bool zcopy, zcopy_used; int sent_pkts = 0; - mutex_lock(&vq->mutex); - sock = vq->private_data; - if (!sock) - goto out; + for (;;) { + bool busyloop_intr = false; - if (!vq_iotlb_prefetch(vq)) - goto out; + head = get_tx_bufs(net, nvq, &msg, &out, &in, &len, + &busyloop_intr); + /* On error, stop handling until the next kick. */ + if (unlikely(head < 0)) + break; + /* Nothing new? Wait for eventfd to tell us they refilled. */ + if (head == vq->num) { + if (unlikely(busyloop_intr)) { + vhost_poll_queue(&vq->poll); + } else if (unlikely(vhost_enable_notify(&net->dev, + vq))) { + vhost_disable_notify(&net->dev, vq); + continue; + } + break; + } - vhost_disable_notify(&net->dev, vq); - vhost_net_disable_vq(net, vq); + vq->heads[nvq->done_idx].id = cpu_to_vhost32(vq, head); + vq->heads[nvq->done_idx].len = 0; - hdr_size = nvq->vhost_hlen; - zcopy = nvq->ubufs; + total_len += len; + if (tx_can_batch(vq, total_len)) + msg.msg_flags |= MSG_MORE; + else + msg.msg_flags &= ~MSG_MORE; + + /* TODO: Check specific error and bomb out unless ENOBUFS? */ + err = sock->ops->sendmsg(sock, &msg, len); + if (unlikely(err < 0)) { + vhost_discard_vq_desc(vq, 1); + vhost_net_enable_vq(net, vq); + break; + } + if (err != len) + pr_debug("Truncated TX packet: len %d != %zd\n", + err, len); + if (++nvq->done_idx >= VHOST_NET_BATCH) + vhost_net_signal_used(nvq); + if (vhost_exceeds_weight(++sent_pkts, total_len)) { + vhost_poll_queue(&vq->poll); + break; + } + } + + vhost_net_signal_used(nvq); +} + +static void handle_tx_zerocopy(struct vhost_net *net, struct socket *sock) +{ + struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX]; + struct vhost_virtqueue *vq = &nvq->vq; + unsigned out, in; + int head; + struct msghdr msg = { + .msg_name = NULL, + .msg_namelen = 0, + .msg_control = NULL, + .msg_controllen = 0, + .msg_flags = MSG_DONTWAIT, + }; + size_t len, total_len = 0; + int err; + struct vhost_net_ubuf_ref *uninitialized_var(ubufs); + bool zcopy_used; + int sent_pkts = 0; for (;;) { bool busyloop_intr; /* Release DMAs done buffers first */ - if (zcopy) - vhost_zerocopy_signal_used(net, vq); + vhost_zerocopy_signal_used(net, vq); busyloop_intr = false; - head = vhost_net_tx_get_vq_desc(net, vq, vq->iov, - ARRAY_SIZE(vq->iov), - &out, &in, &busyloop_intr); + head = get_tx_bufs(net, nvq, &msg, &out, &in, &len, + &busyloop_intr); /* On error, stop handling until the next kick. */ if (unlikely(head < 0)) break; @@ -528,27 +644,10 @@ static void handle_tx(struct vhost_net *net) } break; } - if (in) { - vq_err(vq, "Unexpected descriptor format for TX: " - "out %d, int %d\n", out, in); - break; - } - /* Skip header. TODO: support TSO. */ - len = iov_length(vq->iov, out); - iov_iter_init(&msg.msg_iter, WRITE, vq->iov, out, len); - iov_iter_advance(&msg.msg_iter, hdr_size); - /* Sanity check */ - if (!msg_data_left(&msg)) { - vq_err(vq, "Unexpected header len for TX: " - "%zd expected %zd\n", - len, hdr_size); - break; - } - len = msg_data_left(&msg); - zcopy_used = zcopy && len >= VHOST_GOODCOPY_LEN - && !vhost_exceeds_maxpend(net) - && vhost_net_tx_select_zcopy(net); + zcopy_used = len >= VHOST_GOODCOPY_LEN + && !vhost_exceeds_maxpend(net) + && vhost_net_tx_select_zcopy(net); /* use msg_control to pass vhost zerocopy ubuf info to skb */ if (zcopy_used) { @@ -570,10 +669,8 @@ static void handle_tx(struct vhost_net *net) msg.msg_control = NULL; ubufs = NULL; } - total_len += len; - if (total_len < VHOST_NET_WEIGHT && - !vhost_vq_avail_empty(&net->dev, vq) && + if (tx_can_batch(vq, total_len) && likely(!vhost_exceeds_maxpend(net))) { msg.msg_flags |= MSG_MORE; } else { @@ -600,12 +697,37 @@ static void handle_tx(struct vhost_net *net) else vhost_zerocopy_signal_used(net, vq); vhost_net_tx_packet(net); - if (unlikely(total_len >= VHOST_NET_WEIGHT) || - unlikely(++sent_pkts >= VHOST_NET_PKT_WEIGHT)) { + if (unlikely(vhost_exceeds_weight(++sent_pkts, total_len))) { vhost_poll_queue(&vq->poll); break; } } +} + +/* Expects to be always run from workqueue - which acts as + * read-size critical section for our kind of RCU. */ +static void handle_tx(struct vhost_net *net) +{ + struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX]; + struct vhost_virtqueue *vq = &nvq->vq; + struct socket *sock; + + mutex_lock(&vq->mutex); + sock = vq->private_data; + if (!sock) + goto out; + + if (!vq_iotlb_prefetch(vq)) + goto out; + + vhost_disable_notify(&net->dev, vq); + vhost_net_disable_vq(net, vq); + + if (vhost_sock_zcopy(sock)) + handle_tx_zerocopy(net, sock); + else + handle_tx_copy(net, sock); + out: mutex_unlock(&vq->mutex); } @@ -641,18 +763,6 @@ static int sk_has_rx_data(struct sock *sk) return skb_queue_empty(&sk->sk_receive_queue); } -static void vhost_rx_signal_used(struct vhost_net_virtqueue *nvq) -{ - struct vhost_virtqueue *vq = &nvq->vq; - struct vhost_dev *dev = vq->dev; - - if (!nvq->done_idx) - return; - - vhost_add_used_and_signal_n(dev, vq, vq->heads, nvq->done_idx); - nvq->done_idx = 0; -} - static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk, bool *busyloop_intr) { @@ -665,7 +775,7 @@ static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk, if (!len && tvq->busyloop_timeout) { /* Flush batched heads first */ - vhost_rx_signal_used(rnvq); + vhost_net_signal_used(rnvq); /* Both tx vq and rx socket were polled here */ mutex_lock_nested(&tvq->mutex, 1); vhost_disable_notify(&net->dev, tvq); @@ -907,13 +1017,12 @@ static void handle_rx(struct vhost_net *net) goto out; } nvq->done_idx += headcount; - if (nvq->done_idx > VHOST_RX_BATCH) - vhost_rx_signal_used(nvq); + if (nvq->done_idx > VHOST_NET_BATCH) + vhost_net_signal_used(nvq); if (unlikely(vq_log)) vhost_log_write(vq, vq_log, log, vhost_len); total_len += vhost_len; - if (unlikely(total_len >= VHOST_NET_WEIGHT) || - unlikely(++recv_pkts >= VHOST_NET_PKT_WEIGHT)) { + if (unlikely(vhost_exceeds_weight(++recv_pkts, total_len))) { vhost_poll_queue(&vq->poll); goto out; } @@ -923,7 +1032,7 @@ static void handle_rx(struct vhost_net *net) else vhost_net_enable_vq(net, vq); out: - vhost_rx_signal_used(nvq); + vhost_net_signal_used(nvq); mutex_unlock(&vq->mutex); } @@ -976,7 +1085,7 @@ static int vhost_net_open(struct inode *inode, struct file *f) return -ENOMEM; } - queue = kmalloc_array(VHOST_RX_BATCH, sizeof(void *), + queue = kmalloc_array(VHOST_NET_BATCH, sizeof(void *), GFP_KERNEL); if (!queue) { kfree(vqs);