From b94b15852cd9b14160cce7f85f241691a72c18af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Emilio=20Cobos=20=C3=81lvarez?= Date: Fri, 20 May 2016 02:43:18 +0200 Subject: [PATCH] std: sync: Implement recv_timeout() --- src/libstd/sync/mpsc/blocking.rs | 14 +- src/libstd/sync/mpsc/mod.rs | 302 +++++++++++++++++++++++++++++-- src/libstd/sync/mpsc/oneshot.rs | 15 +- src/libstd/sync/mpsc/shared.rs | 17 +- src/libstd/sync/mpsc/stream.rs | 15 +- src/libstd/sync/mpsc/sync.rs | 75 ++++++-- 6 files changed, 396 insertions(+), 42 deletions(-) diff --git a/src/libstd/sync/mpsc/blocking.rs b/src/libstd/sync/mpsc/blocking.rs index 0e5a9859116..4a70de0e7d8 100644 --- a/src/libstd/sync/mpsc/blocking.rs +++ b/src/libstd/sync/mpsc/blocking.rs @@ -16,6 +16,7 @@ use sync::Arc; use marker::{Sync, Send}; use mem; use clone::Clone; +use time::Instant; struct Inner { thread: Thread, @@ -74,7 +75,6 @@ impl SignalToken { pub unsafe fn cast_from_usize(signal_ptr: usize) -> SignalToken { SignalToken { inner: mem::transmute(signal_ptr) } } - } impl WaitToken { @@ -83,4 +83,16 @@ impl WaitToken { thread::park() } } + + /// Returns true if we wake up normally, false otherwise. + pub fn wait_max_until(self, end: Instant) -> bool { + while !self.inner.woken.load(Ordering::SeqCst) { + let now = Instant::now(); + if now >= end { + return false; + } + thread::park_timeout(end - now) + } + true + } } diff --git a/src/libstd/sync/mpsc/mod.rs b/src/libstd/sync/mpsc/mod.rs index 63b659d8db3..34bc210b3c8 100644 --- a/src/libstd/sync/mpsc/mod.rs +++ b/src/libstd/sync/mpsc/mod.rs @@ -134,9 +134,9 @@ // senders. Under the hood, however, there are actually three flavors of // channels in play. // -// * Flavor::Oneshots - these channels are highly optimized for the one-send use case. -// They contain as few atomics as possible and involve one and -// exactly one allocation. +// * Flavor::Oneshots - these channels are highly optimized for the one-send use +// case. They contain as few atomics as possible and +// involve one and exactly one allocation. // * Streams - these channels are optimized for the non-shared use case. They // use a different concurrent queue that is more tailored for this // use case. The initial allocation of this flavor of channel is not @@ -148,9 +148,9 @@ // // ## Concurrent queues // -// The basic idea of Rust's Sender/Receiver types is that send() never blocks, but -// recv() obviously blocks. This means that under the hood there must be some -// shared and concurrent queue holding all of the actual data. +// The basic idea of Rust's Sender/Receiver types is that send() never blocks, +// but recv() obviously blocks. This means that under the hood there must be +// some shared and concurrent queue holding all of the actual data. // // With two flavors of channels, two flavors of queues are also used. We have // chosen to use queues from a well-known author that are abbreviated as SPSC @@ -271,6 +271,7 @@ use fmt; use mem; use cell::UnsafeCell; use marker::Reflect; +use time::{Duration, Instant}; #[unstable(feature = "mpsc_select", issue = "27800")] pub use self::select::{Select, Handle}; @@ -379,6 +380,19 @@ pub enum TryRecvError { Disconnected, } +/// This enumeration is the list of possible errors that `recv_timeout` could +/// not return data when called. +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +#[unstable(feature = "mpsc_recv_timeout", issue = "34029")] +pub enum RecvTimeoutError { + /// This channel is currently empty, but the sender(s) have not yet + /// disconnected, so data may yet become available. + Timeout, + /// This channel's sending half has become disconnected, and there will + /// never be any more data received on this channel + Disconnected, +} + /// This enumeration is the list of the possible error outcomes for the /// `SyncSender::try_send` method. #[stable(feature = "rust1", since = "1.0.0")] @@ -838,30 +852,30 @@ impl Receiver { loop { let new_port = match *unsafe { self.inner() } { Flavor::Oneshot(ref p) => { - match unsafe { (*p.get()).recv() } { + match unsafe { (*p.get()).recv(None) } { Ok(t) => return Ok(t), - Err(oneshot::Empty) => return unreachable!(), Err(oneshot::Disconnected) => return Err(RecvError), Err(oneshot::Upgraded(rx)) => rx, + Err(oneshot::Empty) => unreachable!(), } } Flavor::Stream(ref p) => { - match unsafe { (*p.get()).recv() } { + match unsafe { (*p.get()).recv(None) } { Ok(t) => return Ok(t), - Err(stream::Empty) => return unreachable!(), Err(stream::Disconnected) => return Err(RecvError), Err(stream::Upgraded(rx)) => rx, + Err(stream::Empty) => unreachable!(), } } Flavor::Shared(ref p) => { - match unsafe { (*p.get()).recv() } { + match unsafe { (*p.get()).recv(None) } { Ok(t) => return Ok(t), - Err(shared::Empty) => return unreachable!(), Err(shared::Disconnected) => return Err(RecvError), + Err(shared::Empty) => unreachable!(), } } Flavor::Sync(ref p) => return unsafe { - (*p.get()).recv().map_err(|()| RecvError) + (*p.get()).recv(None).map_err(|_| RecvError) } }; unsafe { @@ -870,6 +884,98 @@ impl Receiver { } } + /// Attempts to wait for a value on this receiver, returning an error if the + /// corresponding channel has hung up, or if it waits more than `timeout`. + /// + /// This function will always block the current thread if there is no data + /// available and it's possible for more data to be sent. Once a message is + /// sent to the corresponding `Sender`, then this receiver will wake up and + /// return that message. + /// + /// If the corresponding `Sender` has disconnected, or it disconnects while + /// this call is blocking, this call will wake up and return `Err` to + /// indicate that no more messages can ever be received on this channel. + /// However, since channels are buffered, messages sent before the disconnect + /// will still be properly received. + /// + /// # Examples + /// + /// ```no_run + /// #![feature(mpsc_recv_timeout)] + /// + /// use std::sync::mpsc::{self, RecvTimeoutError}; + /// use std::time::Duration; + /// + /// let (send, recv) = mpsc::channel::<()>(); + /// + /// let timeout = Duration::from_millis(100); + /// assert_eq!(Err(RecvTimeoutError::Timeout), recv.recv_timeout(timeout)); + /// ``` + #[unstable(feature = "mpsc_recv_timeout", issue = "34029")] + pub fn recv_timeout(&self, timeout: Duration) -> Result { + // Do an optimistic try_recv to avoid the performance impact of + // Instant::now() in the full-channel case. + match self.try_recv() { + Ok(result) + => Ok(result), + Err(TryRecvError::Disconnected) + => Err(RecvTimeoutError::Disconnected), + Err(TryRecvError::Empty) + => self.recv_max_until(Instant::now() + timeout) + } + } + + fn recv_max_until(&self, deadline: Instant) -> Result { + use self::RecvTimeoutError::*; + + loop { + let port_or_empty = match *unsafe { self.inner() } { + Flavor::Oneshot(ref p) => { + match unsafe { (*p.get()).recv(Some(deadline)) } { + Ok(t) => return Ok(t), + Err(oneshot::Disconnected) => return Err(Disconnected), + Err(oneshot::Upgraded(rx)) => Some(rx), + Err(oneshot::Empty) => None, + } + } + Flavor::Stream(ref p) => { + match unsafe { (*p.get()).recv(Some(deadline)) } { + Ok(t) => return Ok(t), + Err(stream::Disconnected) => return Err(Disconnected), + Err(stream::Upgraded(rx)) => Some(rx), + Err(stream::Empty) => None, + } + } + Flavor::Shared(ref p) => { + match unsafe { (*p.get()).recv(Some(deadline)) } { + Ok(t) => return Ok(t), + Err(shared::Disconnected) => return Err(Disconnected), + Err(shared::Empty) => None, + } + } + Flavor::Sync(ref p) => { + match unsafe { (*p.get()).recv(Some(deadline)) } { + Ok(t) => return Ok(t), + Err(sync::Disconnected) => return Err(Disconnected), + Err(sync::Empty) => None, + } + } + }; + + if let Some(new_port) = port_or_empty { + unsafe { + mem::swap(self.inner_mut(), new_port.inner_mut()); + } + } + + // If we're already passed the deadline, and we're here without + // data, return a timeout, else try again. + if Instant::now() >= deadline { + return Err(Timeout); + } + } + } + /// Returns an iterator that will block waiting for messages, but never /// `panic!`. It will return `None` when the channel has hung up. #[stable(feature = "rust1", since = "1.0.0")] @@ -1141,6 +1247,7 @@ mod tests { use env; use super::*; use thread; + use time::{Duration, Instant}; pub fn stress_factor() -> usize { match env::var("RUST_TEST_STRESS") { @@ -1539,6 +1646,87 @@ mod tests { } } + #[test] + fn oneshot_single_thread_recv_timeout() { + let (tx, rx) = channel(); + tx.send(()).unwrap(); + assert_eq!(rx.recv_timeout(Duration::from_millis(1)), Ok(())); + assert_eq!(rx.recv_timeout(Duration::from_millis(1)), Err(RecvTimeoutError::Timeout)); + tx.send(()).unwrap(); + assert_eq!(rx.recv_timeout(Duration::from_millis(1)), Ok(())); + } + + #[test] + fn stress_recv_timeout_two_threads() { + let (tx, rx) = channel(); + let stress = stress_factor() + 100; + let timeout = Duration::from_millis(100); + + thread::spawn(move || { + for i in 0..stress { + if i % 2 == 0 { + thread::sleep(timeout * 2); + } + tx.send(1usize).unwrap(); + } + }); + + let mut recv_count = 0; + loop { + match rx.recv_timeout(timeout) { + Ok(n) => { + assert_eq!(n, 1usize); + recv_count += 1; + } + Err(RecvTimeoutError::Timeout) => continue, + Err(RecvTimeoutError::Disconnected) => break, + } + } + + assert_eq!(recv_count, stress); + } + + #[test] + fn recv_timeout_upgrade() { + let (tx, rx) = channel::<()>(); + let timeout = Duration::from_millis(1); + let _tx_clone = tx.clone(); + + let start = Instant::now(); + assert_eq!(rx.recv_timeout(timeout), Err(RecvTimeoutError::Timeout)); + assert!(Instant::now() >= start + timeout); + } + + #[test] + fn stress_recv_timeout_shared() { + let (tx, rx) = channel(); + let stress = stress_factor() + 100; + + for i in 0..stress { + let tx = tx.clone(); + thread::spawn(move || { + thread::sleep(Duration::from_millis(i as u64 * 10)); + tx.send(1usize).unwrap(); + }); + } + + drop(tx); + + let mut recv_count = 0; + loop { + match rx.recv_timeout(Duration::from_millis(10)) { + Ok(n) => { + assert_eq!(n, 1usize); + recv_count += 1; + } + Err(RecvTimeoutError::Timeout) => continue, + Err(RecvTimeoutError::Disconnected) => break, + } + } + + assert_eq!(recv_count, stress); + } + #[test] fn recv_a_lot() { // Regression test that we don't run out of stack in scheduler context @@ -1547,6 +1735,24 @@ mod tests { for _ in 0..10000 { rx.recv().unwrap(); } } + #[test] + fn shared_recv_timeout() { + let (tx, rx) = channel(); + let total = 5; + for _ in 0..total { + let tx = tx.clone(); + thread::spawn(move|| { + tx.send(()).unwrap(); + }); + } + + for _ in 0..total { rx.recv().unwrap(); } + + assert_eq!(rx.recv_timeout(Duration::from_millis(1)), Err(RecvTimeoutError::Timeout)); + tx.send(()).unwrap(); + assert_eq!(rx.recv_timeout(Duration::from_millis(1)), Ok(())); + } + #[test] fn shared_chan_stress() { let (tx, rx) = channel(); @@ -1689,6 +1895,7 @@ mod sync_tests { use env; use thread; use super::*; + use time::Duration; pub fn stress_factor() -> usize { match env::var("RUST_TEST_STRESS") { @@ -1720,6 +1927,14 @@ mod sync_tests { assert_eq!(rx.recv().unwrap(), 1); } + #[test] + fn recv_timeout() { + let (tx, rx) = sync_channel::(1); + assert_eq!(rx.recv_timeout(Duration::from_millis(1)), Err(RecvTimeoutError::Timeout)); + tx.send(1).unwrap(); + assert_eq!(rx.recv_timeout(Duration::from_millis(1)), Ok(1)); + } + #[test] fn smoke_threads() { let (tx, rx) = sync_channel::(0); @@ -1801,6 +2016,67 @@ mod sync_tests { } } + #[test] + fn stress_recv_timeout_two_threads() { + let (tx, rx) = sync_channel::(0); + + thread::spawn(move|| { + for _ in 0..10000 { tx.send(1).unwrap(); } + }); + + let mut recv_count = 0; + loop { + match rx.recv_timeout(Duration::from_millis(1)) { + Ok(v) => { + assert_eq!(v, 1); + recv_count += 1; + }, + Err(RecvTimeoutError::Timeout) => continue, + Err(RecvTimeoutError::Disconnected) => break, + } + } + + assert_eq!(recv_count, 10000); + } + + #[test] + fn stress_recv_timeout_shared() { + const AMT: u32 = 1000; + const NTHREADS: u32 = 8; + let (tx, rx) = sync_channel::(0); + let (dtx, drx) = sync_channel::<()>(0); + + thread::spawn(move|| { + let mut recv_count = 0; + loop { + match rx.recv_timeout(Duration::from_millis(10)) { + Ok(v) => { + assert_eq!(v, 1); + recv_count += 1; + }, + Err(RecvTimeoutError::Timeout) => continue, + Err(RecvTimeoutError::Disconnected) => break, + } + } + + assert_eq!(recv_count, AMT * NTHREADS); + assert!(rx.try_recv().is_err()); + + dtx.send(()).unwrap(); + }); + + for _ in 0..NTHREADS { + let tx = tx.clone(); + thread::spawn(move|| { + for _ in 0..AMT { tx.send(1).unwrap(); } + }); + } + + drop(tx); + + drx.recv().unwrap(); + } + #[test] fn stress_shared() { const AMT: u32 = 1000; diff --git a/src/libstd/sync/mpsc/oneshot.rs b/src/libstd/sync/mpsc/oneshot.rs index cb930280964..7a35ea6bbaa 100644 --- a/src/libstd/sync/mpsc/oneshot.rs +++ b/src/libstd/sync/mpsc/oneshot.rs @@ -41,6 +41,7 @@ use sync::mpsc::Receiver; use sync::mpsc::blocking::{self, SignalToken}; use core::mem; use sync::atomic::{AtomicUsize, Ordering}; +use time::Instant; // Various states you can find a port in. const EMPTY: usize = 0; // initial state: no data, no blocked receiver @@ -136,7 +137,7 @@ impl Packet { } } - pub fn recv(&mut self) -> Result> { + pub fn recv(&mut self, deadline: Option) -> Result> { // Attempt to not block the thread (it's a little expensive). If it looks // like we're not empty, then immediately go through to `try_recv`. if self.state.load(Ordering::SeqCst) == EMPTY { @@ -145,8 +146,16 @@ impl Packet { // race with senders to enter the blocking state if self.state.compare_and_swap(EMPTY, ptr, Ordering::SeqCst) == EMPTY { - wait_token.wait(); - debug_assert!(self.state.load(Ordering::SeqCst) != EMPTY); + if let Some(deadline) = deadline { + let timed_out = !wait_token.wait_max_until(deadline); + // Try to reset the state + if timed_out { + try!(self.abort_selection().map_err(Upgraded)); + } + } else { + wait_token.wait(); + debug_assert!(self.state.load(Ordering::SeqCst) != EMPTY); + } } else { // drop the signal token, since we never blocked drop(unsafe { SignalToken::cast_from_usize(ptr) }); diff --git a/src/libstd/sync/mpsc/shared.rs b/src/libstd/sync/mpsc/shared.rs index a3779931c7b..baa4db7e5c0 100644 --- a/src/libstd/sync/mpsc/shared.rs +++ b/src/libstd/sync/mpsc/shared.rs @@ -30,6 +30,7 @@ use sync::mpsc::select::StartResult::*; use sync::mpsc::select::StartResult; use sync::{Mutex, MutexGuard}; use thread; +use time::Instant; const DISCONNECTED: isize = isize::MIN; const FUDGE: isize = 1024; @@ -66,7 +67,7 @@ impl Packet { // Creation of a packet *must* be followed by a call to postinit_lock // and later by inherit_blocker pub fn new() -> Packet { - let p = Packet { + Packet { queue: mpsc::Queue::new(), cnt: AtomicIsize::new(0), steals: 0, @@ -75,8 +76,7 @@ impl Packet { port_dropped: AtomicBool::new(false), sender_drain: AtomicIsize::new(0), select_lock: Mutex::new(()), - }; - return p; + } } // This function should be used after newly created Packet @@ -216,7 +216,7 @@ impl Packet { Ok(()) } - pub fn recv(&mut self) -> Result { + pub fn recv(&mut self, deadline: Option) -> Result { // This code is essentially the exact same as that found in the stream // case (see stream.rs) match self.try_recv() { @@ -226,7 +226,14 @@ impl Packet { let (wait_token, signal_token) = blocking::tokens(); if self.decrement(signal_token) == Installed { - wait_token.wait() + if let Some(deadline) = deadline { + let timed_out = !wait_token.wait_max_until(deadline); + if timed_out { + self.abort_selection(false); + } + } else { + wait_token.wait(); + } } match self.try_recv() { diff --git a/src/libstd/sync/mpsc/stream.rs b/src/libstd/sync/mpsc/stream.rs index e8012ca470b..aa1254c8641 100644 --- a/src/libstd/sync/mpsc/stream.rs +++ b/src/libstd/sync/mpsc/stream.rs @@ -25,6 +25,7 @@ use self::Message::*; use core::cmp; use core::isize; use thread; +use time::Instant; use sync::atomic::{AtomicIsize, AtomicUsize, Ordering, AtomicBool}; use sync::mpsc::Receiver; @@ -172,7 +173,7 @@ impl Packet { Err(unsafe { SignalToken::cast_from_usize(ptr) }) } - pub fn recv(&mut self) -> Result> { + pub fn recv(&mut self, deadline: Option) -> Result> { // Optimistic preflight check (scheduling is expensive). match self.try_recv() { Err(Empty) => {} @@ -183,7 +184,15 @@ impl Packet { // initiate the blocking protocol. let (wait_token, signal_token) = blocking::tokens(); if self.decrement(signal_token).is_ok() { - wait_token.wait() + if let Some(deadline) = deadline { + let timed_out = !wait_token.wait_max_until(deadline); + if timed_out { + try!(self.abort_selection(/* was_upgrade = */ false) + .map_err(Upgraded)); + } + } else { + wait_token.wait(); + } } match self.try_recv() { @@ -332,7 +341,7 @@ impl Packet { // the internal state. match self.queue.peek() { Some(&mut GoUp(..)) => { - match self.recv() { + match self.recv(None) { Err(Upgraded(port)) => Err(port), _ => unreachable!(), } diff --git a/src/libstd/sync/mpsc/sync.rs b/src/libstd/sync/mpsc/sync.rs index b98fc2859af..f021689acad 100644 --- a/src/libstd/sync/mpsc/sync.rs +++ b/src/libstd/sync/mpsc/sync.rs @@ -44,6 +44,7 @@ use sync::atomic::{Ordering, AtomicUsize}; use sync::mpsc::blocking::{self, WaitToken, SignalToken}; use sync::mpsc::select::StartResult::{self, Installed, Abort}; use sync::{Mutex, MutexGuard}; +use time::Instant; pub struct Packet { /// Only field outside of the mutex. Just done for kicks, but mainly because @@ -126,6 +127,38 @@ fn wait<'a, 'b, T>(lock: &'a Mutex>, lock.lock().unwrap() // relock } +/// Same as wait, but waiting at most until `deadline`. +fn wait_timeout_receiver<'a, 'b, T>(lock: &'a Mutex>, + deadline: Instant, + mut guard: MutexGuard<'b, State>, + success: &mut bool) + -> MutexGuard<'a, State> +{ + let (wait_token, signal_token) = blocking::tokens(); + match mem::replace(&mut guard.blocker, BlockedReceiver(signal_token)) { + NoneBlocked => {} + _ => unreachable!(), + } + drop(guard); // unlock + *success = wait_token.wait_max_until(deadline); // block + let mut new_guard = lock.lock().unwrap(); // relock + if !*success { + abort_selection(&mut new_guard); + } + new_guard +} + +fn abort_selection<'a, T>(guard: &mut MutexGuard<'a , State>) -> bool { + match mem::replace(&mut guard.blocker, NoneBlocked) { + NoneBlocked => true, + BlockedSender(token) => { + guard.blocker = BlockedSender(token); + true + } + BlockedReceiver(token) => { drop(token); false } + } +} + /// Wakes up a thread, dropping the lock at the correct time fn wakeup(token: SignalToken, guard: MutexGuard>) { // We need to be careful to wake up the waiting thread *outside* of the mutex @@ -238,22 +271,37 @@ impl Packet { // // When reading this, remember that there can only ever be one receiver at // time. - pub fn recv(&self) -> Result { + pub fn recv(&self, deadline: Option) -> Result { let mut guard = self.lock.lock().unwrap(); - // Wait for the buffer to have something in it. No need for a while loop - // because we're the only receiver. - let mut waited = false; + let mut woke_up_after_waiting = false; + // Wait for the buffer to have something in it. No need for a + // while loop because we're the only receiver. if !guard.disconnected && guard.buf.size() == 0 { - guard = wait(&self.lock, guard, BlockedReceiver); - waited = true; + if let Some(deadline) = deadline { + guard = wait_timeout_receiver(&self.lock, + deadline, + guard, + &mut woke_up_after_waiting); + } else { + guard = wait(&self.lock, guard, BlockedReceiver); + woke_up_after_waiting = true; + } + } + + // NB: Channel could be disconnected while waiting, so the order of + // these conditionals is important. + if guard.disconnected && guard.buf.size() == 0 { + return Err(Disconnected); } - if guard.disconnected && guard.buf.size() == 0 { return Err(()) } // Pick up the data, wake up our neighbors, and carry on - assert!(guard.buf.size() > 0); + assert!(guard.buf.size() > 0 || (deadline.is_some() && !woke_up_after_waiting)); + + if guard.buf.size() == 0 { return Err(Empty); } + let ret = guard.buf.dequeue(); - self.wakeup_senders(waited, guard); + self.wakeup_senders(woke_up_after_waiting, guard); Ok(ret) } @@ -392,14 +440,7 @@ impl Packet { // The return value indicates whether there's data on this port. pub fn abort_selection(&self) -> bool { let mut guard = self.lock.lock().unwrap(); - match mem::replace(&mut guard.blocker, NoneBlocked) { - NoneBlocked => true, - BlockedSender(token) => { - guard.blocker = BlockedSender(token); - true - } - BlockedReceiver(token) => { drop(token); false } - } + abort_selection(&mut guard) } }