diff --git a/nbd.c b/nbd.c index 053ad8d00d..964a732d5c 100644 --- a/nbd.c +++ b/nbd.c @@ -611,6 +611,47 @@ static int nbd_do_send_reply(int csock, struct nbd_reply *reply, return rc; } +static int nbd_do_receive_request(int csock, struct nbd_request *request, + uint8_t *data) +{ + int rc; + + if (nbd_receive_request(csock, request) == -1) { + rc = -EIO; + goto out; + } + + if (request->len > NBD_BUFFER_SIZE) { + LOG("len (%u) is larger than max len (%u)", + request->len, NBD_BUFFER_SIZE); + rc = -EINVAL; + goto out; + } + + if ((request->from + request->len) < request->from) { + LOG("integer overflow detected! " + "you're probably being attacked"); + rc = -EINVAL; + goto out; + } + + TRACE("Decoding type"); + + if ((request->type & NBD_CMD_MASK_COMMAND) == NBD_CMD_WRITE) { + TRACE("Reading %u byte(s)", request->len); + + if (read_sync(csock, data, request->len) != request->len) { + LOG("reading from socket failed"); + rc = -EIO; + goto out; + } + } + rc = 0; + +out: + return rc; +} + int nbd_trip(BlockDriverState *bs, int csock, off_t size, uint64_t dev_offset, uint32_t nbdflags, uint8_t *data) @@ -621,22 +662,17 @@ int nbd_trip(BlockDriverState *bs, int csock, off_t size, TRACE("Reading request."); - if (nbd_receive_request(csock, &request) == -1) + ret = nbd_do_receive_request(csock, &request, data); + if (ret == -EIO) { return -1; + } reply.handle = request.handle; reply.error = 0; - if (request.len > NBD_BUFFER_SIZE) { - LOG("len (%u) is larger than max len (%u)", - request.len, NBD_BUFFER_SIZE); - goto invalid_request; - } - - if ((request.from + request.len) < request.from) { - LOG("integer overflow detected! " - "you're probably being attacked"); - goto invalid_request; + if (ret < 0) { + reply.error = -ret; + goto error_reply; } if ((request.from + request.len) > size) { @@ -647,8 +683,6 @@ int nbd_trip(BlockDriverState *bs, int csock, off_t size, goto invalid_request; } - TRACE("Decoding type"); - switch (request.type & NBD_CMD_MASK_COMMAND) { case NBD_CMD_READ: TRACE("Request type is READ"); @@ -668,14 +702,6 @@ int nbd_trip(BlockDriverState *bs, int csock, off_t size, case NBD_CMD_WRITE: TRACE("Request type is WRITE"); - TRACE("Reading %u byte(s)", request.len); - - if (read_sync(csock, data, request.len) != request.len) { - LOG("reading from socket failed"); - errno = EINVAL; - return -1; - } - if (nbdflags & NBD_FLAG_READ_ONLY) { TRACE("Server is read-only, return error"); reply.error = EROFS;