From 3c4d157c8419d73574587a22a8095ad32d860af5 Mon Sep 17 00:00:00 2001 From: Jethro Beekman Date: Fri, 30 Aug 2019 20:35:27 -0700 Subject: [PATCH] Fix unlock ordering in SGX synchronization primitives --- src/libstd/sys/sgx/condvar.rs | 3 +-- src/libstd/sys/sgx/mutex.rs | 4 ++-- src/libstd/sys/sgx/rwlock.rs | 34 +++++++++++++++++++-------------- src/libstd/sys/sgx/waitqueue.rs | 11 +++++++++-- 4 files changed, 32 insertions(+), 20 deletions(-) diff --git a/src/libstd/sys/sgx/condvar.rs b/src/libstd/sys/sgx/condvar.rs index 000bb19f269..cc1c04a83e7 100644 --- a/src/libstd/sys/sgx/condvar.rs +++ b/src/libstd/sys/sgx/condvar.rs @@ -27,8 +27,7 @@ impl Condvar { pub unsafe fn wait(&self, mutex: &Mutex) { let guard = self.inner.lock(); - mutex.unlock(); - WaitQueue::wait(guard); + WaitQueue::wait(guard, || mutex.unlock()); mutex.lock() } diff --git a/src/libstd/sys/sgx/mutex.rs b/src/libstd/sys/sgx/mutex.rs index f325fb1dd58..662da8b3f66 100644 --- a/src/libstd/sys/sgx/mutex.rs +++ b/src/libstd/sys/sgx/mutex.rs @@ -22,7 +22,7 @@ impl Mutex { let mut guard = self.inner.lock(); if *guard.lock_var() { // Another thread has the lock, wait - WaitQueue::wait(guard) + WaitQueue::wait(guard, ||{}) // Another thread has passed the lock to us } else { // We are just now obtaining the lock @@ -83,7 +83,7 @@ impl ReentrantMutex { match guard.lock_var().owner { Some(tcs) if tcs != thread::current() => { // Another thread has the lock, wait - WaitQueue::wait(guard); + WaitQueue::wait(guard, ||{}); // Another thread has passed the lock to us }, _ => { diff --git a/src/libstd/sys/sgx/rwlock.rs b/src/libstd/sys/sgx/rwlock.rs index 30c47e44eef..e2f94b1d928 100644 --- a/src/libstd/sys/sgx/rwlock.rs +++ b/src/libstd/sys/sgx/rwlock.rs @@ -31,7 +31,7 @@ impl RWLock { if *wguard.lock_var() || !wguard.queue_empty() { // Another thread has or is waiting for the write lock, wait drop(wguard); - WaitQueue::wait(rguard); + WaitQueue::wait(rguard, ||{}); // Another thread has passed the lock to us } else { // No waiting writers, acquire the read lock @@ -62,7 +62,7 @@ impl RWLock { if *wguard.lock_var() || rguard.lock_var().is_some() { // Another thread has the lock, wait drop(rguard); - WaitQueue::wait(wguard); + WaitQueue::wait(wguard, ||{}); // Another thread has passed the lock to us } else { // We are just now obtaining the lock @@ -97,6 +97,7 @@ impl RWLock { if let Ok(mut wguard) = WaitQueue::notify_one(wguard) { // A writer was waiting, pass the lock *wguard.lock_var_mut() = true; + wguard.drop_after(rguard); } else { // No writers were waiting, the lock is released rtassert!(rguard.queue_empty()); @@ -117,21 +118,26 @@ impl RWLock { rguard: SpinMutexGuard<'_, WaitVariable>>, wguard: SpinMutexGuard<'_, WaitVariable>, ) { - if let Err(mut wguard) = WaitQueue::notify_one(wguard) { - // No writers waiting, release the write lock - *wguard.lock_var_mut() = false; - if let Ok(mut rguard) = WaitQueue::notify_all(rguard) { - // One or more readers were waiting, pass the lock to them - if let NotifiedTcs::All { count } = rguard.notified_tcs() { - *rguard.lock_var_mut() = Some(count) + match WaitQueue::notify_one(wguard) { + Err(mut wguard) => { + // No writers waiting, release the write lock + *wguard.lock_var_mut() = false; + if let Ok(mut rguard) = WaitQueue::notify_all(rguard) { + // One or more readers were waiting, pass the lock to them + if let NotifiedTcs::All { count } = rguard.notified_tcs() { + *rguard.lock_var_mut() = Some(count) + } else { + unreachable!() // called notify_all + } + rguard.drop_after(wguard); } else { - unreachable!() // called notify_all + // No readers waiting, the lock is released } - } else { - // No readers waiting, the lock is released + }, + Ok(wguard) => { + // There was a thread waiting for write, just pass the lock + wguard.drop_after(rguard); } - } else { - // There was a thread waiting for write, just pass the lock } } diff --git a/src/libstd/sys/sgx/waitqueue.rs b/src/libstd/sys/sgx/waitqueue.rs index d542f9b4101..3cb40e509b6 100644 --- a/src/libstd/sys/sgx/waitqueue.rs +++ b/src/libstd/sys/sgx/waitqueue.rs @@ -98,6 +98,12 @@ impl<'a, T> WaitGuard<'a, T> { pub fn notified_tcs(&self) -> NotifiedTcs { self.notified_tcs } + + /// Drop this `WaitGuard`, after dropping another `guard`. + pub fn drop_after(self, guard: U) { + drop(guard); + drop(self); + } } impl<'a, T> Deref for WaitGuard<'a, T> { @@ -140,7 +146,7 @@ impl WaitQueue { /// until a wakeup event. /// /// This function does not return until this thread has been awoken. - pub fn wait(mut guard: SpinMutexGuard<'_, WaitVariable>) { + pub fn wait(mut guard: SpinMutexGuard<'_, WaitVariable>, before_wait: F) { // very unsafe: check requirements of UnsafeList::push unsafe { let mut entry = UnsafeListEntry::new(SpinMutex::new(WaitEntry { @@ -149,6 +155,7 @@ impl WaitQueue { })); let entry = guard.queue.inner.push(&mut entry); drop(guard); + before_wait(); while !entry.lock().wake { // don't panic, this would invalidate `entry` during unwinding let eventset = rtunwrap!(Ok, usercalls::wait(EV_UNPARK, WAIT_INDEFINITE)); @@ -545,7 +552,7 @@ mod tests { assert!(WaitQueue::notify_one(wq2.lock()).is_ok()); }); - WaitQueue::wait(locked); + WaitQueue::wait(locked, ||{}); t1.join().unwrap(); }