diff --git a/hw/usb-msd.c b/hw/usb-msd.c index 48e0b343ca..90e57fbf6b 100644 --- a/hw/usb-msd.c +++ b/hw/usb-msd.c @@ -43,8 +43,6 @@ typedef struct { enum USBMSDMode mode; uint32_t scsi_len; uint8_t *scsi_buf; - uint32_t usb_len; - uint8_t *usb_buf; uint32_t data_len; uint32_t residue; uint32_t tag; @@ -176,20 +174,14 @@ static const USBDesc desc = { .str = desc_strings, }; -static void usb_msd_copy_data(MSDState *s) +static void usb_msd_copy_data(MSDState *s, USBPacket *p) { uint32_t len; - len = s->usb_len; + len = p->iov.size - p->result; if (len > s->scsi_len) len = s->scsi_len; - if (s->mode == USB_MSDM_DATAIN) { - memcpy(s->usb_buf, s->scsi_buf, len); - } else { - memcpy(s->scsi_buf, s->usb_buf, len); - } - s->usb_len -= len; + usb_packet_copy(p, s->scsi_buf, len); s->scsi_len -= len; - s->usb_buf += len; s->scsi_buf += len; s->data_len -= len; if (s->scsi_len == 0 || s->data_len == 0) { @@ -221,9 +213,9 @@ static void usb_msd_transfer_data(SCSIRequest *req, uint32_t len) s->scsi_len = len; s->scsi_buf = scsi_req_get_buf(req); if (p) { - usb_msd_copy_data(s); - if (s->packet && s->usb_len == 0) { - p->result = p->iov.size; + usb_msd_copy_data(s, p); + p = s->packet; + if (p && p->result == p->iov.size) { /* Set s->packet to NULL before calling usb_packet_complete because another request may be issued before usb_packet_complete returns. */ @@ -250,16 +242,13 @@ static void usb_msd_command_complete(SCSIRequest *req, uint32_t status) s->mode = USB_MSDM_CBW; } else { if (s->data_len) { - s->data_len -= s->usb_len; - if (s->mode == USB_MSDM_DATAIN) { - memset(s->usb_buf, 0, s->usb_len); - } - s->usb_len = 0; + int len = (p->iov.size - p->result); + usb_packet_skip(p, len); + s->data_len -= len; } if (s->data_len == 0) { s->mode = USB_MSDM_CSW; } - p->result = p->iov.size; } s->packet = NULL; usb_packet_complete(&s->dev, p); @@ -345,10 +334,7 @@ static int usb_msd_handle_data(USBDevice *dev, USBPacket *p) int ret = 0; struct usb_msd_cbw cbw; uint8_t devep = p->devep; - uint8_t *data = p->iov.iov[0].iov_base; - int len = p->iov.iov[0].iov_len; - assert(p->iov.niov == 1); /* temporary */ switch (p->pid) { case USB_TOKEN_OUT: if (devep != 2) @@ -356,11 +342,11 @@ static int usb_msd_handle_data(USBDevice *dev, USBPacket *p) switch (s->mode) { case USB_MSDM_CBW: - if (len != 31) { + if (p->iov.size != 31) { fprintf(stderr, "usb-msd: Bad CBW size"); goto fail; } - memcpy(&cbw, data, 31); + usb_packet_copy(p, &cbw, 31); if (le32_to_cpu(cbw.sig) != 0x43425355) { fprintf(stderr, "usb-msd: Bad signature %08x\n", le32_to_cpu(cbw.sig)); @@ -391,36 +377,39 @@ static int usb_msd_handle_data(USBDevice *dev, USBPacket *p) if (s->mode != USB_MSDM_CSW && s->residue == 0) { scsi_req_continue(s->req); } - ret = len; + ret = p->result; break; case USB_MSDM_DATAOUT: - DPRINTF("Data out %d/%d\n", len, s->data_len); - if (len > s->data_len) + DPRINTF("Data out %zd/%d\n", p->iov.size, s->data_len); + if (p->iov.size > s->data_len) { goto fail; + } - s->usb_buf = data; - s->usb_len = len; if (s->scsi_len) { - usb_msd_copy_data(s); + usb_msd_copy_data(s, p); } - if (s->residue && s->usb_len) { - s->data_len -= s->usb_len; - if (s->data_len == 0) - s->mode = USB_MSDM_CSW; - s->usb_len = 0; + if (s->residue) { + int len = p->iov.size - p->result; + if (len) { + usb_packet_skip(p, len); + s->data_len -= len; + if (s->data_len == 0) { + s->mode = USB_MSDM_CSW; + } + } } - if (s->usb_len) { + if (p->result < p->iov.size) { DPRINTF("Deferring packet %p\n", p); s->packet = p; ret = USB_RET_ASYNC; } else { - ret = len; + ret = p->result; } break; default: - DPRINTF("Unexpected write (len %d)\n", len); + DPRINTF("Unexpected write (len %zd)\n", p->iov.size); goto fail; } break; @@ -431,18 +420,20 @@ static int usb_msd_handle_data(USBDevice *dev, USBPacket *p) switch (s->mode) { case USB_MSDM_DATAOUT: - if (s->data_len != 0 || len < 13) + if (s->data_len != 0 || p->iov.size < 13) { goto fail; + } /* Waiting for SCSI write to complete. */ s->packet = p; ret = USB_RET_ASYNC; break; case USB_MSDM_CSW: - DPRINTF("Command status %d tag 0x%x, len %d\n", - s->result, s->tag, len); - if (len < 13) + DPRINTF("Command status %d tag 0x%x, len %zd\n", + s->result, s->tag, p->iov.size); + if (p->iov.size < 13) { goto fail; + } usb_msd_send_status(s, p); s->mode = USB_MSDM_CBW; @@ -450,32 +441,32 @@ static int usb_msd_handle_data(USBDevice *dev, USBPacket *p) break; case USB_MSDM_DATAIN: - DPRINTF("Data in %d/%d, scsi_len %d\n", len, s->data_len, s->scsi_len); - if (len > s->data_len) - len = s->data_len; - s->usb_buf = data; - s->usb_len = len; + DPRINTF("Data in %zd/%d, scsi_len %d\n", + p->iov.size, s->data_len, s->scsi_len); if (s->scsi_len) { - usb_msd_copy_data(s); + usb_msd_copy_data(s, p); } - if (s->residue && s->usb_len) { - s->data_len -= s->usb_len; - memset(s->usb_buf, 0, s->usb_len); - if (s->data_len == 0) - s->mode = USB_MSDM_CSW; - s->usb_len = 0; + if (s->residue) { + int len = p->iov.size - p->result; + if (len) { + usb_packet_skip(p, len); + s->data_len -= len; + if (s->data_len == 0) { + s->mode = USB_MSDM_CSW; + } + } } - if (s->usb_len) { + if (p->result < p->iov.size) { DPRINTF("Deferring packet %p\n", p); s->packet = p; ret = USB_RET_ASYNC; } else { - ret = len; + ret = p->result; } break; default: - DPRINTF("Unexpected read (len %d)\n", len); + DPRINTF("Unexpected read (len %zd)\n", p->iov.size); goto fail; } break;