diff options
Diffstat (limited to 'src/sync/cv.rs')
-rw-r--r-- | src/sync/cv.rs | 1251 |
1 files changed, 1251 insertions, 0 deletions
diff --git a/src/sync/cv.rs b/src/sync/cv.rs new file mode 100644 index 0000000..714c6d6 --- /dev/null +++ b/src/sync/cv.rs @@ -0,0 +1,1251 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::cell::UnsafeCell; +use std::mem; +use std::sync::atomic::{spin_loop_hint, AtomicUsize, Ordering}; +use std::sync::Arc; + +use crate::sync::mu::{MutexGuard, MutexReadGuard, RawMutex}; +use crate::sync::waiter::{Kind as WaiterKind, Waiter, WaiterAdapter, WaiterList, WaitingFor}; + +const SPINLOCK: usize = 1 << 0; +const HAS_WAITERS: usize = 1 << 1; + +/// A primitive to wait for an event to occur without consuming CPU time. +/// +/// Condition variables are used in combination with a `Mutex` when a thread wants to wait for some +/// condition to become true. The condition must always be verified while holding the `Mutex` lock. +/// It is an error to use a `Condvar` with more than one `Mutex` while there are threads waiting on +/// the `Condvar`. +/// +/// # Examples +/// +/// ```edition2018 +/// use std::sync::Arc; +/// use std::thread; +/// use std::sync::mpsc::channel; +/// +/// use libchromeos::sync::{block_on, Condvar, Mutex}; +/// +/// const N: usize = 13; +/// +/// // Spawn a few threads to increment a shared variable (non-atomically), and +/// // let all threads waiting on the Condvar know once the increments are done. +/// let data = Arc::new(Mutex::new(0)); +/// let cv = Arc::new(Condvar::new()); +/// +/// for _ in 0..N { +/// let (data, cv) = (data.clone(), cv.clone()); +/// thread::spawn(move || { +/// let mut data = block_on(data.lock()); +/// *data += 1; +/// if *data == N { +/// cv.notify_all(); +/// } +/// }); +/// } +/// +/// let mut val = block_on(data.lock()); +/// while *val != N { +/// val = block_on(cv.wait(val)); +/// } +/// ``` +pub struct Condvar { + state: AtomicUsize, + waiters: UnsafeCell<WaiterList>, + mu: UnsafeCell<usize>, +} + +impl Condvar { + /// Creates a new condition variable ready to be waited on and notified. + pub fn new() -> Condvar { + Condvar { + state: AtomicUsize::new(0), + waiters: UnsafeCell::new(WaiterList::new(WaiterAdapter::new())), + mu: UnsafeCell::new(0), + } + } + + /// Block the current thread until this `Condvar` is notified by another thread. + /// + /// This method will atomically unlock the `Mutex` held by `guard` and then block the current + /// thread. Any call to `notify_one` or `notify_all` after the `Mutex` is unlocked may wake up + /// the thread. + /// + /// To allow for more efficient scheduling, this call may return even when the programmer + /// doesn't expect the thread to be woken. Therefore, calls to `wait()` should be used inside a + /// loop that checks the predicate before continuing. + /// + /// Callers that are not in an async context may wish to use the `block_on` method to block the + /// thread until the `Condvar` is notified. + /// + /// # Panics + /// + /// This method will panic if used with more than one `Mutex` at the same time. + /// + /// # Examples + /// + /// ``` + /// # use std::sync::Arc; + /// # use std::thread; + /// + /// # use libchromeos::sync::{block_on, Condvar, Mutex}; + /// + /// # let mu = Arc::new(Mutex::new(false)); + /// # let cv = Arc::new(Condvar::new()); + /// # let (mu2, cv2) = (mu.clone(), cv.clone()); + /// + /// # let t = thread::spawn(move || { + /// # *block_on(mu2.lock()) = true; + /// # cv2.notify_all(); + /// # }); + /// + /// let mut ready = block_on(mu.lock()); + /// while !*ready { + /// ready = block_on(cv.wait(ready)); + /// } + /// + /// # t.join().expect("failed to join thread"); + /// ``` + // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code + // that doesn't compile. + #[allow(clippy::needless_lifetimes)] + pub async fn wait<'g, T>(&self, guard: MutexGuard<'g, T>) -> MutexGuard<'g, T> { + let waiter = Arc::new(Waiter::new( + WaiterKind::Exclusive, + cancel_waiter, + self as *const Condvar as usize, + WaitingFor::Condvar, + )); + + self.add_waiter(waiter.clone(), guard.as_raw_mutex()); + + // Get a reference to the mutex and then drop the lock. + let mu = guard.into_inner(); + + // Wait to be woken up. + waiter.wait().await; + + // Now re-acquire the lock. + mu.lock_from_cv().await + } + + /// Like `wait()` but takes and returns a `MutexReadGuard` instead. + // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code + // that doesn't compile. + #[allow(clippy::needless_lifetimes)] + pub async fn wait_read<'g, T>(&self, guard: MutexReadGuard<'g, T>) -> MutexReadGuard<'g, T> { + let waiter = Arc::new(Waiter::new( + WaiterKind::Shared, + cancel_waiter, + self as *const Condvar as usize, + WaitingFor::Condvar, + )); + + self.add_waiter(waiter.clone(), guard.as_raw_mutex()); + + // Get a reference to the mutex and then drop the lock. + let mu = guard.into_inner(); + + // Wait to be woken up. + waiter.wait().await; + + // Now re-acquire the lock. + mu.read_lock_from_cv().await + } + + fn add_waiter(&self, waiter: Arc<Waiter>, raw_mutex: &RawMutex) { + // Acquire the spin lock. + let mut oldstate = self.state.load(Ordering::Relaxed); + while (oldstate & SPINLOCK) != 0 + || self.state.compare_and_swap( + oldstate, + oldstate | SPINLOCK | HAS_WAITERS, + Ordering::Acquire, + ) != oldstate + { + spin_loop_hint(); + oldstate = self.state.load(Ordering::Relaxed); + } + + // Safe because the spin lock guarantees exclusive access and the reference does not escape + // this function. + let mu = unsafe { &mut *self.mu.get() }; + let muptr = raw_mutex as *const RawMutex as usize; + + match *mu { + 0 => *mu = muptr, + p if p == muptr => {} + _ => panic!("Attempting to use Condvar with more than one Mutex at the same time"), + } + + // Safe because the spin lock guarantees exclusive access. + unsafe { (*self.waiters.get()).push_back(waiter) }; + + // Release the spin lock. Use a direct store here because no other thread can modify + // `self.state` while we hold the spin lock. Keep the `HAS_WAITERS` bit that we set earlier + // because we just added a waiter. + self.state.store(HAS_WAITERS, Ordering::Release); + } + + /// Notify at most one thread currently waiting on the `Condvar`. + /// + /// If there is a thread currently waiting on the `Condvar` it will be woken up from its call to + /// `wait`. + /// + /// Unlike more traditional condition variable interfaces, this method requires a reference to + /// the `Mutex` associated with this `Condvar`. This is because it is inherently racy to call + /// `notify_one` or `notify_all` without first acquiring the `Mutex` lock. Additionally, taking + /// a reference to the `Mutex` here allows us to make some optimizations that can improve + /// performance by reducing unnecessary wakeups. + pub fn notify_one(&self) { + let mut oldstate = self.state.load(Ordering::Relaxed); + if (oldstate & HAS_WAITERS) == 0 { + // No waiters. + return; + } + + while (oldstate & SPINLOCK) != 0 + || self + .state + .compare_and_swap(oldstate, oldstate | SPINLOCK, Ordering::Acquire) + != oldstate + { + spin_loop_hint(); + oldstate = self.state.load(Ordering::Relaxed); + } + + // Safe because the spin lock guarantees exclusive access and the reference does not escape + // this function. + let waiters = unsafe { &mut *self.waiters.get() }; + let (mut wake_list, all_readers) = get_wake_list(waiters); + + // Safe because the spin lock guarantees exclusive access. + let muptr = unsafe { (*self.mu.get()) as *const RawMutex }; + + let newstate = if waiters.is_empty() { + // Also clear the mutex associated with this Condvar since there are no longer any + // waiters. Safe because the spin lock guarantees exclusive access. + unsafe { *self.mu.get() = 0 }; + + // We are releasing the spin lock and there are no more waiters so we can clear all bits + // in `self.state`. + 0 + } else { + // There are still waiters so we need to keep the HAS_WAITERS bit in the state. + HAS_WAITERS + }; + + // Try to transfer waiters before releasing the spin lock. + if !wake_list.is_empty() { + // Safe because there was a waiter in the queue and the thread that owns the waiter also + // owns a reference to the Mutex, guaranteeing that the pointer is valid. + unsafe { (*muptr).transfer_waiters(&mut wake_list, all_readers) }; + } + + // Release the spin lock. + self.state.store(newstate, Ordering::Release); + + // Now wake any waiters still left in the wake list. + for w in wake_list { + w.wake(); + } + } + + /// Notify all threads currently waiting on the `Condvar`. + /// + /// All threads currently waiting on the `Condvar` will be woken up from their call to `wait`. + /// + /// Unlike more traditional condition variable interfaces, this method requires a reference to + /// the `Mutex` associated with this `Condvar`. This is because it is inherently racy to call + /// `notify_one` or `notify_all` without first acquiring the `Mutex` lock. Additionally, taking + /// a reference to the `Mutex` here allows us to make some optimizations that can improve + /// performance by reducing unnecessary wakeups. + pub fn notify_all(&self) { + let mut oldstate = self.state.load(Ordering::Relaxed); + if (oldstate & HAS_WAITERS) == 0 { + // No waiters. + return; + } + + while (oldstate & SPINLOCK) != 0 + || self + .state + .compare_and_swap(oldstate, oldstate | SPINLOCK, Ordering::Acquire) + != oldstate + { + spin_loop_hint(); + oldstate = self.state.load(Ordering::Relaxed); + } + + // Safe because the spin lock guarantees exclusive access to `self.waiters`. + let mut wake_list = unsafe { (*self.waiters.get()).take() }; + + // Safe because the spin lock guarantees exclusive access. + let muptr = unsafe { (*self.mu.get()) as *const RawMutex }; + + // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe + // because we the spin lock guarantees exclusive access. + unsafe { *self.mu.get() = 0 }; + + // Try to transfer waiters before releasing the spin lock. + if !wake_list.is_empty() { + // Safe because there was a waiter in the queue and the thread that owns the waiter also + // owns a reference to the Mutex, guaranteeing that the pointer is valid. + unsafe { (*muptr).transfer_waiters(&mut wake_list, false) }; + } + + // Mark any waiters left as no longer waiting for the Condvar. + for w in &wake_list { + w.set_waiting_for(WaitingFor::None); + } + + // Release the spin lock. We can clear all bits in the state since we took all the waiters. + self.state.store(0, Ordering::Release); + + // Now wake any waiters still left in the wake list. + for w in wake_list { + w.wake(); + } + } + + fn cancel_waiter(&self, waiter: &Waiter, wake_next: bool) -> bool { + let mut oldstate = self.state.load(Ordering::Relaxed); + while oldstate & SPINLOCK != 0 + || self + .state + .compare_exchange_weak( + oldstate, + oldstate | SPINLOCK, + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_err() + { + spin_loop_hint(); + oldstate = self.state.load(Ordering::Relaxed); + } + + // Safe because the spin lock provides exclusive access and the reference does not escape + // this function. + let waiters = unsafe { &mut *self.waiters.get() }; + + let waiting_for = waiter.is_waiting_for(); + if waiting_for == WaitingFor::Mutex { + // The waiter was moved to the mutex's list. Retry the cancel. + let set_on_release = if waiters.is_empty() { + // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe + // because we the spin lock guarantees exclusive access. + unsafe { *self.mu.get() = 0 }; + + 0 + } else { + HAS_WAITERS + }; + + self.state.store(set_on_release, Ordering::Release); + + false + } else { + // Don't drop the old waiter now as we're still holding the spin lock. + let old_waiter = if waiter.is_linked() && waiting_for == WaitingFor::Condvar { + // Safe because we know that the waiter is still linked and is waiting for the Condvar, + // which guarantees that it is still in `self.waiters`. + let mut cursor = unsafe { waiters.cursor_mut_from_ptr(waiter as *const Waiter) }; + cursor.remove() + } else { + None + }; + + let (mut wake_list, all_readers) = if wake_next || waiting_for == WaitingFor::None { + // Either the waiter was already woken or it's been removed from the condvar's waiter + // list and is going to be woken. Either way, we need to wake up another thread. + get_wake_list(waiters) + } else { + (WaiterList::new(WaiterAdapter::new()), false) + }; + + // Safe because the spin lock guarantees exclusive access. + let muptr = unsafe { (*self.mu.get()) as *const RawMutex }; + + // Try to transfer waiters before releasing the spin lock. + if !wake_list.is_empty() { + // Safe because there was a waiter in the queue and the thread that owns the waiter also + // owns a reference to the Mutex, guaranteeing that the pointer is valid. + unsafe { (*muptr).transfer_waiters(&mut wake_list, all_readers) }; + } + + let set_on_release = if waiters.is_empty() { + // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe + // because we the spin lock guarantees exclusive access. + unsafe { *self.mu.get() = 0 }; + + 0 + } else { + HAS_WAITERS + }; + + self.state.store(set_on_release, Ordering::Release); + + // Now wake any waiters still left in the wake list. + for w in wake_list { + w.wake(); + } + + mem::drop(old_waiter); + true + } + } +} + +unsafe impl Send for Condvar {} +unsafe impl Sync for Condvar {} + +impl Default for Condvar { + fn default() -> Self { + Self::new() + } +} + +// Scan `waiters` and return all waiters that should be woken up. If all waiters in the returned +// wait list are readers then the returned bool will be true. +// +// If the first waiter is trying to acquire a shared lock, then all waiters in the list that are +// waiting for a shared lock are also woken up. In addition one writer is woken up, if possible. +// +// If the first waiter is trying to acquire an exclusive lock, then only that waiter is returned and +// the rest of the list is not scanned. +fn get_wake_list(waiters: &mut WaiterList) -> (WaiterList, bool) { + let mut to_wake = WaiterList::new(WaiterAdapter::new()); + let mut cursor = waiters.front_mut(); + + let mut waking_readers = false; + let mut all_readers = true; + while let Some(w) = cursor.get() { + match w.kind() { + WaiterKind::Exclusive if !waking_readers => { + // This is the first waiter and it's a writer. No need to check the other waiters. + // Also mark the waiter as having been removed from the Condvar's waiter list. + let waiter = cursor.remove().unwrap(); + waiter.set_waiting_for(WaitingFor::None); + to_wake.push_back(waiter); + all_readers = false; + break; + } + + WaiterKind::Shared => { + // This is a reader and the first waiter in the list was not a writer so wake up all + // the readers in the wait list. + let waiter = cursor.remove().unwrap(); + waiter.set_waiting_for(WaitingFor::None); + to_wake.push_back(waiter); + waking_readers = true; + } + + WaiterKind::Exclusive => { + debug_assert!(waking_readers); + if all_readers { + // We are waking readers but we need to ensure that at least one writer is woken + // up. Since we haven't yet woken up a writer, wake up this one. + let waiter = cursor.remove().unwrap(); + waiter.set_waiting_for(WaitingFor::None); + to_wake.push_back(waiter); + all_readers = false; + } else { + // We are waking readers and have already woken one writer. Skip this one. + cursor.move_next(); + } + } + } + } + + (to_wake, all_readers) +} + +fn cancel_waiter(cv: usize, waiter: &Waiter, wake_next: bool) -> bool { + let condvar = cv as *const Condvar; + + // Safe because the thread that owns the waiter being canceled must also own a reference to the + // Condvar, which guarantees that this pointer is valid. + unsafe { (*condvar).cancel_waiter(waiter, wake_next) } +} + +#[cfg(test)] +mod test { + use super::*; + + use std::future::Future; + use std::mem; + use std::ptr; + use std::rc::Rc; + use std::sync::mpsc::{channel, Sender}; + use std::sync::Arc; + use std::task::{Context, Poll}; + use std::thread::{self, JoinHandle}; + use std::time::Duration; + + use futures::channel::oneshot; + use futures::task::{waker_ref, ArcWake}; + use futures::{select, FutureExt}; + use futures_executor::{LocalPool, LocalSpawner, ThreadPool}; + use futures_util::task::LocalSpawnExt; + + use crate::sync::{block_on, Mutex}; + + // Dummy waker used when we want to manually drive futures. + struct TestWaker; + impl ArcWake for TestWaker { + fn wake_by_ref(_arc_self: &Arc<Self>) {} + } + + #[test] + fn smoke() { + let cv = Condvar::new(); + cv.notify_one(); + cv.notify_all(); + } + + #[test] + fn notify_one() { + let mu = Arc::new(Mutex::new(())); + let cv = Arc::new(Condvar::new()); + + let mu2 = mu.clone(); + let cv2 = cv.clone(); + + let guard = block_on(mu.lock()); + thread::spawn(move || { + let _g = block_on(mu2.lock()); + cv2.notify_one(); + }); + + let guard = block_on(cv.wait(guard)); + mem::drop(guard); + } + + #[test] + fn multi_mutex() { + const NUM_THREADS: usize = 5; + + let mu = Arc::new(Mutex::new(false)); + let cv = Arc::new(Condvar::new()); + + let mut threads = Vec::with_capacity(NUM_THREADS); + for _ in 0..NUM_THREADS { + let mu = mu.clone(); + let cv = cv.clone(); + + threads.push(thread::spawn(move || { + let mut ready = block_on(mu.lock()); + while !*ready { + ready = block_on(cv.wait(ready)); + } + })); + } + + let mut g = block_on(mu.lock()); + *g = true; + mem::drop(g); + cv.notify_all(); + + threads + .into_iter() + .map(JoinHandle::join) + .collect::<thread::Result<()>>() + .expect("Failed to join threads"); + + // Now use the Condvar with a different mutex. + let alt_mu = Arc::new(Mutex::new(None)); + let alt_mu2 = alt_mu.clone(); + let cv2 = cv.clone(); + let handle = thread::spawn(move || { + let mut g = block_on(alt_mu2.lock()); + while g.is_none() { + g = block_on(cv2.wait(g)); + } + }); + + let mut alt_g = block_on(alt_mu.lock()); + *alt_g = Some(()); + mem::drop(alt_g); + cv.notify_all(); + + handle + .join() + .expect("Failed to join thread alternate mutex"); + } + + #[test] + fn notify_one_single_thread_async() { + async fn notify(mu: Rc<Mutex<()>>, cv: Rc<Condvar>) { + let _g = mu.lock().await; + cv.notify_one(); + } + + async fn wait(mu: Rc<Mutex<()>>, cv: Rc<Condvar>, spawner: LocalSpawner) { + let mu2 = Rc::clone(&mu); + let cv2 = Rc::clone(&cv); + + let g = mu.lock().await; + // Has to be spawned _after_ acquiring the lock to prevent a race + // where the notify happens before the waiter has acquired the lock. + spawner + .spawn_local(notify(mu2, cv2)) + .expect("Failed to spawn `notify` task"); + let _g = cv.wait(g).await; + } + + let mut ex = LocalPool::new(); + let spawner = ex.spawner(); + + let mu = Rc::new(Mutex::new(())); + let cv = Rc::new(Condvar::new()); + + spawner + .spawn_local(wait(mu, cv, spawner.clone())) + .expect("Failed to spawn `wait` task"); + + ex.run(); + } + + #[test] + fn notify_one_multi_thread_async() { + async fn notify(mu: Arc<Mutex<()>>, cv: Arc<Condvar>) { + let _g = mu.lock().await; + cv.notify_one(); + } + + async fn wait(mu: Arc<Mutex<()>>, cv: Arc<Condvar>, tx: Sender<()>, pool: ThreadPool) { + let mu2 = Arc::clone(&mu); + let cv2 = Arc::clone(&cv); + + let g = mu.lock().await; + // Has to be spawned _after_ acquiring the lock to prevent a race + // where the notify happens before the waiter has acquired the lock. + pool.spawn_ok(notify(mu2, cv2)); + let _g = cv.wait(g).await; + + tx.send(()).expect("Failed to send completion notification"); + } + + let ex = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(())); + let cv = Arc::new(Condvar::new()); + + let (tx, rx) = channel(); + ex.spawn_ok(wait(mu, cv, tx, ex.clone())); + + rx.recv_timeout(Duration::from_secs(5)) + .expect("Failed to receive completion notification"); + } + + #[test] + fn notify_one_with_cancel() { + const TASKS: usize = 17; + const OBSERVERS: usize = 7; + const ITERATIONS: usize = 103; + + async fn observe(mu: &Arc<Mutex<usize>>, cv: &Arc<Condvar>) { + let mut count = mu.read_lock().await; + while *count == 0 { + count = cv.wait_read(count).await; + } + let _ = unsafe { ptr::read_volatile(&*count as *const usize) }; + } + + async fn decrement(mu: &Arc<Mutex<usize>>, cv: &Arc<Condvar>) { + let mut count = mu.lock().await; + while *count == 0 { + count = cv.wait(count).await; + } + *count -= 1; + } + + async fn increment(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>, done: Sender<()>) { + for _ in 0..TASKS * OBSERVERS * ITERATIONS { + *mu.lock().await += 1; + cv.notify_one(); + } + + done.send(()).expect("Failed to send completion message"); + } + + async fn observe_either( + mu: Arc<Mutex<usize>>, + cv: Arc<Condvar>, + alt_mu: Arc<Mutex<usize>>, + alt_cv: Arc<Condvar>, + done: Sender<()>, + ) { + for _ in 0..ITERATIONS { + select! { + () = observe(&mu, &cv).fuse() => {}, + () = observe(&alt_mu, &alt_cv).fuse() => {}, + } + } + + done.send(()).expect("Failed to send completion message"); + } + + async fn decrement_either( + mu: Arc<Mutex<usize>>, + cv: Arc<Condvar>, + alt_mu: Arc<Mutex<usize>>, + alt_cv: Arc<Condvar>, + done: Sender<()>, + ) { + for _ in 0..ITERATIONS { + select! { + () = decrement(&mu, &cv).fuse() => {}, + () = decrement(&alt_mu, &alt_cv).fuse() => {}, + } + } + + done.send(()).expect("Failed to send completion message"); + } + + let ex = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(0usize)); + let alt_mu = Arc::new(Mutex::new(0usize)); + + let cv = Arc::new(Condvar::new()); + let alt_cv = Arc::new(Condvar::new()); + + let (tx, rx) = channel(); + for _ in 0..TASKS { + ex.spawn_ok(decrement_either( + Arc::clone(&mu), + Arc::clone(&cv), + Arc::clone(&alt_mu), + Arc::clone(&alt_cv), + tx.clone(), + )); + } + + for _ in 0..OBSERVERS { + ex.spawn_ok(observe_either( + Arc::clone(&mu), + Arc::clone(&cv), + Arc::clone(&alt_mu), + Arc::clone(&alt_cv), + tx.clone(), + )); + } + + ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone())); + ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx)); + + for _ in 0..TASKS + OBSERVERS + 2 { + if let Err(e) = rx.recv_timeout(Duration::from_secs(10)) { + panic!("Error while waiting for threads to complete: {}", e); + } + } + + assert_eq!( + *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()), + (TASKS * OBSERVERS * ITERATIONS * 2) - (TASKS * ITERATIONS) + ); + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn notify_all_with_cancel() { + const TASKS: usize = 17; + const ITERATIONS: usize = 103; + + async fn decrement(mu: &Arc<Mutex<usize>>, cv: &Arc<Condvar>) { + let mut count = mu.lock().await; + while *count == 0 { + count = cv.wait(count).await; + } + *count -= 1; + } + + async fn increment(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>, done: Sender<()>) { + for _ in 0..TASKS * ITERATIONS { + *mu.lock().await += 1; + cv.notify_all(); + } + + done.send(()).expect("Failed to send completion message"); + } + + async fn decrement_either( + mu: Arc<Mutex<usize>>, + cv: Arc<Condvar>, + alt_mu: Arc<Mutex<usize>>, + alt_cv: Arc<Condvar>, + done: Sender<()>, + ) { + for _ in 0..ITERATIONS { + select! { + () = decrement(&mu, &cv).fuse() => {}, + () = decrement(&alt_mu, &alt_cv).fuse() => {}, + } + } + + done.send(()).expect("Failed to send completion message"); + } + + let ex = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(0usize)); + let alt_mu = Arc::new(Mutex::new(0usize)); + + let cv = Arc::new(Condvar::new()); + let alt_cv = Arc::new(Condvar::new()); + + let (tx, rx) = channel(); + for _ in 0..TASKS { + ex.spawn_ok(decrement_either( + Arc::clone(&mu), + Arc::clone(&cv), + Arc::clone(&alt_mu), + Arc::clone(&alt_cv), + tx.clone(), + )); + } + + ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone())); + ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx)); + + for _ in 0..TASKS + 2 { + if let Err(e) = rx.recv_timeout(Duration::from_secs(10)) { + panic!("Error while waiting for threads to complete: {}", e); + } + } + + assert_eq!( + *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()), + TASKS * ITERATIONS + ); + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0); + } + #[test] + fn notify_all() { + const THREADS: usize = 13; + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + let (tx, rx) = channel(); + + let mut threads = Vec::with_capacity(THREADS); + for _ in 0..THREADS { + let mu2 = mu.clone(); + let cv2 = cv.clone(); + let tx2 = tx.clone(); + + threads.push(thread::spawn(move || { + let mut count = block_on(mu2.lock()); + *count += 1; + if *count == THREADS { + tx2.send(()).unwrap(); + } + + while *count != 0 { + count = block_on(cv2.wait(count)); + } + })); + } + + mem::drop(tx); + + // Wait till all threads have started. + rx.recv_timeout(Duration::from_secs(5)).unwrap(); + + let mut count = block_on(mu.lock()); + *count = 0; + mem::drop(count); + cv.notify_all(); + + for t in threads { + t.join().unwrap(); + } + } + + #[test] + fn notify_all_single_thread_async() { + const TASKS: usize = 13; + + async fn reset(mu: Rc<Mutex<usize>>, cv: Rc<Condvar>) { + let mut count = mu.lock().await; + *count = 0; + cv.notify_all(); + } + + async fn watcher(mu: Rc<Mutex<usize>>, cv: Rc<Condvar>, spawner: LocalSpawner) { + let mut count = mu.lock().await; + *count += 1; + if *count == TASKS { + spawner + .spawn_local(reset(mu.clone(), cv.clone())) + .expect("Failed to spawn reset task"); + } + + while *count != 0 { + count = cv.wait(count).await; + } + } + + let mut ex = LocalPool::new(); + let spawner = ex.spawner(); + + let mu = Rc::new(Mutex::new(0)); + let cv = Rc::new(Condvar::new()); + + for _ in 0..TASKS { + spawner + .spawn_local(watcher(mu.clone(), cv.clone(), spawner.clone())) + .expect("Failed to spawn watcher task"); + } + + ex.run(); + } + + #[test] + fn notify_all_multi_thread_async() { + const TASKS: usize = 13; + + async fn reset(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.lock().await; + *count = 0; + cv.notify_all(); + } + + async fn watcher( + mu: Arc<Mutex<usize>>, + cv: Arc<Condvar>, + pool: ThreadPool, + tx: Sender<()>, + ) { + let mut count = mu.lock().await; + *count += 1; + if *count == TASKS { + pool.spawn_ok(reset(mu.clone(), cv.clone())); + } + + while *count != 0 { + count = cv.wait(count).await; + } + + tx.send(()).expect("Failed to send completion notification"); + } + + let pool = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let (tx, rx) = channel(); + for _ in 0..TASKS { + pool.spawn_ok(watcher(mu.clone(), cv.clone(), pool.clone(), tx.clone())); + } + + for _ in 0..TASKS { + rx.recv_timeout(Duration::from_secs(5)) + .expect("Failed to receive completion notification"); + } + } + + #[test] + fn wake_all_readers() { + async fn read(mu: Arc<Mutex<bool>>, cv: Arc<Condvar>) { + let mut ready = mu.read_lock().await; + while !*ready { + ready = cv.wait_read(ready).await; + } + } + + let mu = Arc::new(Mutex::new(false)); + let cv = Arc::new(Condvar::new()); + let mut readers = [ + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + ]; + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + // First have all the readers wait on the Condvar. + for r in &mut readers { + if let Poll::Ready(()) = r.as_mut().poll(&mut cx) { + panic!("reader unexpectedly ready"); + } + } + + assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); + + // Now make the condition true and notify the condvar. Even though we will call notify_one, + // all the readers should be woken up. + *block_on(mu.lock()) = true; + cv.notify_one(); + + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + + // All readers should now be able to complete. + for r in &mut readers { + if let Poll::Pending = r.as_mut().poll(&mut cx) { + panic!("reader unable to complete"); + } + } + } + + #[test] + fn cancel_before_notify() { + async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.lock().await; + + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut fut1 = Box::pin(dec(mu.clone(), cv.clone())); + let mut fut2 = Box::pin(dec(mu.clone(), cv.clone())); + + if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); + + *block_on(mu.lock()) = 2; + // Drop fut1 before notifying the cv. + mem::drop(fut1); + cv.notify_one(); + + // fut2 should now be ready to complete. + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + + if let Poll::Pending = fut2.as_mut().poll(&mut cx) { + panic!("future unable to complete"); + } + + assert_eq!(*block_on(mu.lock()), 1); + } + + #[test] + fn cancel_after_notify() { + async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.lock().await; + + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut fut1 = Box::pin(dec(mu.clone(), cv.clone())); + let mut fut2 = Box::pin(dec(mu.clone(), cv.clone())); + + if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); + + *block_on(mu.lock()) = 2; + cv.notify_one(); + + // fut1 should now be ready to complete. Drop it before polling. This should wake up fut2. + mem::drop(fut1); + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + + if let Poll::Pending = fut2.as_mut().poll(&mut cx) { + panic!("future unable to complete"); + } + + assert_eq!(*block_on(mu.lock()), 1); + } + + #[test] + fn cancel_after_transfer() { + async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.lock().await; + + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut fut1 = Box::pin(dec(mu.clone(), cv.clone())); + let mut fut2 = Box::pin(dec(mu.clone(), cv.clone())); + + if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); + + let mut count = block_on(mu.lock()); + *count = 2; + + // Notify the cv while holding the lock. Only transfer one waiter. + cv.notify_one(); + assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); + + // Drop the lock and then the future. This should not cause fut2 to become runnable as it + // should still be in the Condvar's wait queue. + mem::drop(count); + mem::drop(fut1); + + if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + + // Now wake up fut2. Since the lock isn't held, it should wake up immediately. + cv.notify_one(); + if let Poll::Pending = fut2.as_mut().poll(&mut cx) { + panic!("future unable to complete"); + } + + assert_eq!(*block_on(mu.lock()), 1); + } + + #[test] + fn cancel_after_transfer_and_wake() { + async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.lock().await; + + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut fut1 = Box::pin(dec(mu.clone(), cv.clone())); + let mut fut2 = Box::pin(dec(mu.clone(), cv.clone())); + + if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); + + let mut count = block_on(mu.lock()); + *count = 2; + + // Notify the cv while holding the lock. This should transfer both waiters to the mutex's + // wait queue. + cv.notify_all(); + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + + mem::drop(count); + + mem::drop(fut1); + + if let Poll::Pending = fut2.as_mut().poll(&mut cx) { + panic!("future unable to complete"); + } + + assert_eq!(*block_on(mu.lock()), 1); + } + + #[test] + fn timed_wait() { + async fn wait_deadline( + mu: Arc<Mutex<usize>>, + cv: Arc<Condvar>, + timeout: oneshot::Receiver<()>, + ) { + let mut count = mu.lock().await; + + if *count == 0 { + let mut rx = timeout.fuse(); + + while *count == 0 { + select! { + res = rx => { + if let Err(e) = res { + panic!("Error while receiving timeout notification: {}", e); + } + + return; + }, + c = cv.wait(count).fuse() => count = c, + } + } + } + + *count += 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let (tx, rx) = oneshot::channel(); + let mut wait = Box::pin(wait_deadline(mu.clone(), cv.clone(), rx)); + + if let Poll::Ready(()) = wait.as_mut().poll(&mut cx) { + panic!("wait_deadline unexpectedly ready"); + } + + assert_eq!(cv.state.load(Ordering::Relaxed), HAS_WAITERS); + + // Signal the channel, which should cancel the wait. + tx.send(()).expect("Failed to send wakeup"); + + // Wait for the timer to run out. + if let Poll::Pending = wait.as_mut().poll(&mut cx) { + panic!("wait_deadline unable to complete in time"); + } + + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + assert_eq!(*block_on(mu.lock()), 0); + } +} |