Fix unlock ordering in SGX synchronization primitives

This commit is contained in:
Jethro Beekman 2019-08-30 20:35:27 -07:00
parent 72b2abfd65
commit 3c4d157c84
4 changed files with 32 additions and 20 deletions

View File

@ -27,8 +27,7 @@ impl Condvar {
pub unsafe fn wait(&self, mutex: &Mutex) {
let guard = self.inner.lock();
mutex.unlock();
WaitQueue::wait(guard);
WaitQueue::wait(guard, || mutex.unlock());
mutex.lock()
}

View File

@ -22,7 +22,7 @@ impl Mutex {
let mut guard = self.inner.lock();
if *guard.lock_var() {
// Another thread has the lock, wait
WaitQueue::wait(guard)
WaitQueue::wait(guard, ||{})
// Another thread has passed the lock to us
} else {
// We are just now obtaining the lock
@ -83,7 +83,7 @@ impl ReentrantMutex {
match guard.lock_var().owner {
Some(tcs) if tcs != thread::current() => {
// Another thread has the lock, wait
WaitQueue::wait(guard);
WaitQueue::wait(guard, ||{});
// Another thread has passed the lock to us
},
_ => {

View File

@ -31,7 +31,7 @@ impl RWLock {
if *wguard.lock_var() || !wguard.queue_empty() {
// Another thread has or is waiting for the write lock, wait
drop(wguard);
WaitQueue::wait(rguard);
WaitQueue::wait(rguard, ||{});
// Another thread has passed the lock to us
} else {
// No waiting writers, acquire the read lock
@ -62,7 +62,7 @@ impl RWLock {
if *wguard.lock_var() || rguard.lock_var().is_some() {
// Another thread has the lock, wait
drop(rguard);
WaitQueue::wait(wguard);
WaitQueue::wait(wguard, ||{});
// Another thread has passed the lock to us
} else {
// We are just now obtaining the lock
@ -97,6 +97,7 @@ impl RWLock {
if let Ok(mut wguard) = WaitQueue::notify_one(wguard) {
// A writer was waiting, pass the lock
*wguard.lock_var_mut() = true;
wguard.drop_after(rguard);
} else {
// No writers were waiting, the lock is released
rtassert!(rguard.queue_empty());
@ -117,21 +118,26 @@ impl RWLock {
rguard: SpinMutexGuard<'_, WaitVariable<Option<NonZeroUsize>>>,
wguard: SpinMutexGuard<'_, WaitVariable<bool>>,
) {
if let Err(mut wguard) = WaitQueue::notify_one(wguard) {
// No writers waiting, release the write lock
*wguard.lock_var_mut() = false;
if let Ok(mut rguard) = WaitQueue::notify_all(rguard) {
// One or more readers were waiting, pass the lock to them
if let NotifiedTcs::All { count } = rguard.notified_tcs() {
*rguard.lock_var_mut() = Some(count)
match WaitQueue::notify_one(wguard) {
Err(mut wguard) => {
// No writers waiting, release the write lock
*wguard.lock_var_mut() = false;
if let Ok(mut rguard) = WaitQueue::notify_all(rguard) {
// One or more readers were waiting, pass the lock to them
if let NotifiedTcs::All { count } = rguard.notified_tcs() {
*rguard.lock_var_mut() = Some(count)
} else {
unreachable!() // called notify_all
}
rguard.drop_after(wguard);
} else {
unreachable!() // called notify_all
// No readers waiting, the lock is released
}
} else {
// No readers waiting, the lock is released
},
Ok(wguard) => {
// There was a thread waiting for write, just pass the lock
wguard.drop_after(rguard);
}
} else {
// There was a thread waiting for write, just pass the lock
}
}

View File

@ -98,6 +98,12 @@ impl<'a, T> WaitGuard<'a, T> {
pub fn notified_tcs(&self) -> NotifiedTcs {
self.notified_tcs
}
/// Drop this `WaitGuard`, after dropping another `guard`.
pub fn drop_after<U>(self, guard: U) {
drop(guard);
drop(self);
}
}
impl<'a, T> Deref for WaitGuard<'a, T> {
@ -140,7 +146,7 @@ impl WaitQueue {
/// until a wakeup event.
///
/// This function does not return until this thread has been awoken.
pub fn wait<T>(mut guard: SpinMutexGuard<'_, WaitVariable<T>>) {
pub fn wait<T, F: FnOnce()>(mut guard: SpinMutexGuard<'_, WaitVariable<T>>, before_wait: F) {
// very unsafe: check requirements of UnsafeList::push
unsafe {
let mut entry = UnsafeListEntry::new(SpinMutex::new(WaitEntry {
@ -149,6 +155,7 @@ impl WaitQueue {
}));
let entry = guard.queue.inner.push(&mut entry);
drop(guard);
before_wait();
while !entry.lock().wake {
// don't panic, this would invalidate `entry` during unwinding
let eventset = rtunwrap!(Ok, usercalls::wait(EV_UNPARK, WAIT_INDEFINITE));
@ -545,7 +552,7 @@ mod tests {
assert!(WaitQueue::notify_one(wq2.lock()).is_ok());
});
WaitQueue::wait(locked);
WaitQueue::wait(locked, ||{});
t1.join().unwrap();
}