libstd/net: Add peek APIs to UdpSocket and TcpStream

These methods enable socket reads without side-effects. That is,
repeated calls to peek() return identical data. This is accomplished
by providing the POSIX flag MSG_PEEK to the underlying socket read
operations.

This also moves the current implementation of recv_from out of the
platform-independent sys_common and into respective sys/windows and
sys/unix implementations. This allows for more platform-dependent
implementations.
This commit is contained in:
Tyler Julian 2017-01-10 19:11:56 -08:00
parent 9749df52b7
commit a40be0857c
8 changed files with 251 additions and 17 deletions

@ -1 +1 @@
Subproject commit 7d57bdcdbb56540f37afe5a934ce12d33a6ca7fc Subproject commit cb7f66732175e6171587ed69656b7aae7dd2e6ec

View File

@ -275,6 +275,7 @@
#![feature(oom)] #![feature(oom)]
#![feature(optin_builtin_traits)] #![feature(optin_builtin_traits)]
#![feature(panic_unwind)] #![feature(panic_unwind)]
#![feature(peek)]
#![feature(placement_in_syntax)] #![feature(placement_in_syntax)]
#![feature(prelude_import)] #![feature(prelude_import)]
#![feature(pub_restricted)] #![feature(pub_restricted)]

View File

@ -296,6 +296,29 @@ impl TcpStream {
self.0.write_timeout() self.0.write_timeout()
} }
/// Receives data on the socket from the remote adress to which it is
/// connected, without removing that data from the queue. On success,
/// returns the number of bytes peeked.
///
/// Successive calls return the same data. This is accomplished by passing
/// `MSG_PEEK` as a flag to the underlying `recv` system call.
///
/// # Examples
///
/// ```no_run
/// #![feature(peek)]
/// use std::net::TcpStream;
///
/// let stream = TcpStream::connect("127.0.0.1:8000")
/// .expect("couldn't bind to address");
/// let mut buf = [0; 10];
/// let len = stream.peek(&mut buf).expect("peek failed");
/// ```
#[unstable(feature = "peek", issue = "38980")]
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.0.peek(buf)
}
/// Sets the value of the `TCP_NODELAY` option on this socket. /// Sets the value of the `TCP_NODELAY` option on this socket.
/// ///
/// If set, this option disables the Nagle algorithm. This means that /// If set, this option disables the Nagle algorithm. This means that
@ -1405,4 +1428,35 @@ mod tests {
Err(e) => panic!("unexpected error {}", e), Err(e) => panic!("unexpected error {}", e),
} }
} }
#[test]
fn peek() {
each_ip(&mut |addr| {
let (txdone, rxdone) = channel();
let srv = t!(TcpListener::bind(&addr));
let _t = thread::spawn(move|| {
let mut cl = t!(srv.accept()).0;
cl.write(&[1,3,3,7]).unwrap();
t!(rxdone.recv());
});
let mut c = t!(TcpStream::connect(&addr));
let mut b = [0; 10];
for _ in 1..3 {
let len = c.peek(&mut b).unwrap();
assert_eq!(len, 4);
}
let len = c.read(&mut b).unwrap();
assert_eq!(len, 4);
t!(c.set_nonblocking(true));
match c.peek(&mut b) {
Ok(_) => panic!("expected error"),
Err(ref e) if e.kind() == ErrorKind::WouldBlock => {}
Err(e) => panic!("unexpected error {}", e),
}
t!(txdone.send(()));
})
}
} }

View File

@ -83,6 +83,30 @@ impl UdpSocket {
self.0.recv_from(buf) self.0.recv_from(buf)
} }
/// Receives data from the socket, without removing it from the queue.
///
/// Successive calls return the same data. This is accomplished by passing
/// `MSG_PEEK` as a flag to the underlying `recvfrom` system call.
///
/// On success, returns the number of bytes peeked and the address from
/// whence the data came.
///
/// # Examples
///
/// ```no_run
/// #![feature(peek)]
/// use std::net::UdpSocket;
///
/// let socket = UdpSocket::bind("127.0.0.1:34254").expect("couldn't bind to address");
/// let mut buf = [0; 10];
/// let (number_of_bytes, src_addr) = socket.peek_from(&mut buf)
/// .expect("Didn't receive data");
/// ```
#[unstable(feature = "peek", issue = "38980")]
pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.0.peek_from(buf)
}
/// Sends data on the socket to the given address. On success, returns the /// Sends data on the socket to the given address. On success, returns the
/// number of bytes written. /// number of bytes written.
/// ///
@ -579,6 +603,37 @@ impl UdpSocket {
self.0.recv(buf) self.0.recv(buf)
} }
/// Receives data on the socket from the remote adress to which it is
/// connected, without removing that data from the queue. On success,
/// returns the number of bytes peeked.
///
/// Successive calls return the same data. This is accomplished by passing
/// `MSG_PEEK` as a flag to the underlying `recv` system call.
///
/// # Errors
///
/// This method will fail if the socket is not connected. The `connect` method
/// will connect this socket to a remote address.
///
/// # Examples
///
/// ```no_run
/// #![feature(peek)]
/// use std::net::UdpSocket;
///
/// let socket = UdpSocket::bind("127.0.0.1:34254").expect("couldn't bind to address");
/// socket.connect("127.0.0.1:8080").expect("connect function failed");
/// let mut buf = [0; 10];
/// match socket.peek(&mut buf) {
/// Ok(received) => println!("received {} bytes", received),
/// Err(e) => println!("peek function failed: {:?}", e),
/// }
/// ```
#[unstable(feature = "peek", issue = "38980")]
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.0.peek(buf)
}
/// Moves this UDP socket into or out of nonblocking mode. /// Moves this UDP socket into or out of nonblocking mode.
/// ///
/// On Unix this corresponds to calling fcntl, and on Windows this /// On Unix this corresponds to calling fcntl, and on Windows this
@ -869,6 +924,48 @@ mod tests {
assert_eq!(b"hello world", &buf[..]); assert_eq!(b"hello world", &buf[..]);
} }
#[test]
fn connect_send_peek_recv() {
each_ip(&mut |addr, _| {
let socket = t!(UdpSocket::bind(&addr));
t!(socket.connect(addr));
t!(socket.send(b"hello world"));
for _ in 1..3 {
let mut buf = [0; 11];
let size = t!(socket.peek(&mut buf));
assert_eq!(b"hello world", &buf[..]);
assert_eq!(size, 11);
}
let mut buf = [0; 11];
let size = t!(socket.recv(&mut buf));
assert_eq!(b"hello world", &buf[..]);
assert_eq!(size, 11);
})
}
#[test]
fn peek_from() {
each_ip(&mut |addr, _| {
let socket = t!(UdpSocket::bind(&addr));
t!(socket.send_to(b"hello world", &addr));
for _ in 1..3 {
let mut buf = [0; 11];
let (size, _) = t!(socket.peek_from(&mut buf));
assert_eq!(b"hello world", &buf[..]);
assert_eq!(size, 11);
}
let mut buf = [0; 11];
let (size, _) = t!(socket.recv_from(&mut buf));
assert_eq!(b"hello world", &buf[..]);
assert_eq!(size, 11);
})
}
#[test] #[test]
fn ttl() { fn ttl() {
let ttl = 100; let ttl = 100;

View File

@ -10,12 +10,13 @@
use ffi::CStr; use ffi::CStr;
use io; use io;
use libc::{self, c_int, size_t, sockaddr, socklen_t, EAI_SYSTEM}; use libc::{self, c_int, c_void, size_t, sockaddr, socklen_t, EAI_SYSTEM, MSG_PEEK};
use mem;
use net::{SocketAddr, Shutdown}; use net::{SocketAddr, Shutdown};
use str; use str;
use sys::fd::FileDesc; use sys::fd::FileDesc;
use sys_common::{AsInner, FromInner, IntoInner}; use sys_common::{AsInner, FromInner, IntoInner};
use sys_common::net::{getsockopt, setsockopt}; use sys_common::net::{getsockopt, setsockopt, sockaddr_to_addr};
use time::Duration; use time::Duration;
pub use sys::{cvt, cvt_r}; pub use sys::{cvt, cvt_r};
@ -155,8 +156,46 @@ impl Socket {
self.0.duplicate().map(Socket) self.0.duplicate().map(Socket)
} }
fn recv_with_flags(&self, buf: &mut [u8], flags: c_int) -> io::Result<usize> {
let ret = cvt(unsafe {
libc::recv(self.0.raw(),
buf.as_mut_ptr() as *mut c_void,
buf.len(),
flags)
})?;
Ok(ret as usize)
}
pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> { pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf) self.recv_with_flags(buf, 0)
}
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.recv_with_flags(buf, MSG_PEEK)
}
fn recv_from_with_flags(&self, buf: &mut [u8], flags: c_int)
-> io::Result<(usize, SocketAddr)> {
let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() };
let mut addrlen = mem::size_of_val(&storage) as libc::socklen_t;
let n = cvt(unsafe {
libc::recvfrom(self.0.raw(),
buf.as_mut_ptr() as *mut c_void,
buf.len(),
flags,
&mut storage as *mut _ as *mut _,
&mut addrlen)
})?;
Ok((n as usize, sockaddr_to_addr(&storage, addrlen as usize)?))
}
pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.recv_from_with_flags(buf, 0)
}
pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.recv_from_with_flags(buf, MSG_PEEK)
} }
pub fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> { pub fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {

View File

@ -244,6 +244,7 @@ pub const IP_ADD_MEMBERSHIP: c_int = 12;
pub const IP_DROP_MEMBERSHIP: c_int = 13; pub const IP_DROP_MEMBERSHIP: c_int = 13;
pub const IPV6_ADD_MEMBERSHIP: c_int = 12; pub const IPV6_ADD_MEMBERSHIP: c_int = 12;
pub const IPV6_DROP_MEMBERSHIP: c_int = 13; pub const IPV6_DROP_MEMBERSHIP: c_int = 13;
pub const MSG_PEEK: c_int = 0x2;
#[repr(C)] #[repr(C)]
pub struct ip_mreq { pub struct ip_mreq {

View File

@ -147,12 +147,12 @@ impl Socket {
Ok(socket) Ok(socket)
} }
pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> { fn recv_with_flags(&self, buf: &mut [u8], flags: c_int) -> io::Result<usize> {
// On unix when a socket is shut down all further reads return 0, so we // On unix when a socket is shut down all further reads return 0, so we
// do the same on windows to map a shut down socket to returning EOF. // do the same on windows to map a shut down socket to returning EOF.
let len = cmp::min(buf.len(), i32::max_value() as usize) as i32; let len = cmp::min(buf.len(), i32::max_value() as usize) as i32;
unsafe { unsafe {
match c::recv(self.0, buf.as_mut_ptr() as *mut c_void, len, 0) { match c::recv(self.0, buf.as_mut_ptr() as *mut c_void, len, flags) {
-1 if c::WSAGetLastError() == c::WSAESHUTDOWN => Ok(0), -1 if c::WSAGetLastError() == c::WSAESHUTDOWN => Ok(0),
-1 => Err(last_error()), -1 => Err(last_error()),
n => Ok(n as usize) n => Ok(n as usize)
@ -160,6 +160,46 @@ impl Socket {
} }
} }
pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
self.recv_with_flags(buf, 0)
}
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.recv_with_flags(buf, c::MSG_PEEK)
}
fn recv_from_with_flags(&self, buf: &mut [u8], flags: c_int)
-> io::Result<(usize, SocketAddr)> {
let mut storage: c::SOCKADDR_STORAGE_LH = unsafe { mem::zeroed() };
let mut addrlen = mem::size_of_val(&storage) as c::socklen_t;
let len = cmp::min(buf.len(), <wrlen_t>::max_value() as usize) as wrlen_t;
// On unix when a socket is shut down all further reads return 0, so we
// do the same on windows to map a shut down socket to returning EOF.
unsafe {
match c::recvfrom(self.0,
buf.as_mut_ptr() as *mut c_void,
len,
flags,
&mut storage as *mut _ as *mut _,
&mut addrlen) {
-1 if c::WSAGetLastError() == c::WSAESHUTDOWN => {
Ok((0, net::sockaddr_to_addr(&storage, addrlen as usize)?))
},
-1 => Err(last_error()),
n => Ok((n as usize, net::sockaddr_to_addr(&storage, addrlen as usize)?)),
}
}
}
pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.recv_from_with_flags(buf, 0)
}
pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.recv_from_with_flags(buf, c::MSG_PEEK)
}
pub fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> { pub fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
let mut me = self; let mut me = self;
(&mut me).read_to_end(buf) (&mut me).read_to_end(buf)

View File

@ -91,7 +91,7 @@ fn sockname<F>(f: F) -> io::Result<SocketAddr>
} }
} }
fn sockaddr_to_addr(storage: &c::sockaddr_storage, pub fn sockaddr_to_addr(storage: &c::sockaddr_storage,
len: usize) -> io::Result<SocketAddr> { len: usize) -> io::Result<SocketAddr> {
match storage.ss_family as c_int { match storage.ss_family as c_int {
c::AF_INET => { c::AF_INET => {
@ -222,6 +222,10 @@ impl TcpStream {
self.inner.timeout(c::SO_SNDTIMEO) self.inner.timeout(c::SO_SNDTIMEO)
} }
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.peek(buf)
}
pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> { pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.read(buf) self.inner.read(buf)
} }
@ -441,17 +445,11 @@ impl UdpSocket {
} }
pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
let mut storage: c::sockaddr_storage = unsafe { mem::zeroed() }; self.inner.recv_from(buf)
let mut addrlen = mem::size_of_val(&storage) as c::socklen_t; }
let len = cmp::min(buf.len(), <wrlen_t>::max_value() as usize) as wrlen_t;
let n = cvt(unsafe { pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
c::recvfrom(*self.inner.as_inner(), self.inner.peek_from(buf)
buf.as_mut_ptr() as *mut c_void,
len, 0,
&mut storage as *mut _ as *mut _, &mut addrlen)
})?;
Ok((n as usize, sockaddr_to_addr(&storage, addrlen as usize)?))
} }
pub fn send_to(&self, buf: &[u8], dst: &SocketAddr) -> io::Result<usize> { pub fn send_to(&self, buf: &[u8], dst: &SocketAddr) -> io::Result<usize> {
@ -578,6 +576,10 @@ impl UdpSocket {
self.inner.read(buf) self.inner.read(buf)
} }
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.peek(buf)
}
pub fn send(&self, buf: &[u8]) -> io::Result<usize> { pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
let len = cmp::min(buf.len(), <wrlen_t>::max_value() as usize) as wrlen_t; let len = cmp::min(buf.len(), <wrlen_t>::max_value() as usize) as wrlen_t;
let ret = cvt(unsafe { let ret = cvt(unsafe {