diff options
Diffstat (limited to 'src/sync')
29 files changed, 1684 insertions, 3940 deletions
diff --git a/src/sync/barrier.rs b/src/sync/barrier.rs index 6286334..fddb3a5 100644 --- a/src/sync/barrier.rs +++ b/src/sync/barrier.rs @@ -96,7 +96,7 @@ impl Barrier { // wake everyone, increment the generation, and return state .waker - .broadcast(state.generation) + .send(state.generation) .expect("there is at least one receiver"); state.arrived = 0; state.generation += 1; @@ -110,9 +110,11 @@ impl Barrier { let mut wait = self.wait.clone(); loop { + let _ = wait.changed().await; + // note that the first time through the loop, this _will_ yield a generation // immediately, since we cloned a receiver that has never seen any values. - if wait.recv().await.expect("sender hasn't been closed") >= generation { + if *wait.borrow() >= generation { break; } } diff --git a/src/sync/batch_semaphore.rs b/src/sync/batch_semaphore.rs index 070cd20..0b50e4f 100644 --- a/src/sync/batch_semaphore.rs +++ b/src/sync/batch_semaphore.rs @@ -1,3 +1,4 @@ +#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))] //! # Implementation Details //! //! The semaphore is implemented using an intrusive linked list of waiters. An @@ -36,7 +37,7 @@ pub(crate) struct Semaphore { } struct Waitlist { - queue: LinkedList<Waiter>, + queue: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>, closed: bool, } @@ -96,10 +97,13 @@ impl Semaphore { /// Note that this reserves three bits of flags in the permit counter, but /// we only actually use one of them. However, the previous semaphore /// implementation used three bits, so we will continue to reserve them to - /// avoid a breaking change if additional flags need to be aadded in the + /// avoid a breaking change if additional flags need to be added in the /// future. pub(crate) const MAX_PERMITS: usize = std::usize::MAX >> 3; const CLOSED: usize = 1; + // The least-significant bit in the number of permits is reserved to use + // as a flag indicating that the semaphore has been closed. Consequently + // PERMIT_SHIFT is used to leave that bit for that purpose. const PERMIT_SHIFT: usize = 1; /// Creates a new semaphore with the initial number of permits @@ -120,6 +124,27 @@ impl Semaphore { } } + /// Creates a new semaphore with the initial number of permits + /// + /// Maximum number of permits on 32-bit platforms is `1<<29`. + /// + /// If the specified number of permits exceeds the maximum permit amount + /// Then the value will get clamped to the maximum number of permits. + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] + pub(crate) const fn const_new(mut permits: usize) -> Self { + // NOTE: assertions and by extension panics are still being worked on: https://github.com/rust-lang/rust/issues/74925 + // currently we just clamp the permit count when it exceeds the max + permits &= Self::MAX_PERMITS; + + Self { + permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT), + waiters: Mutex::const_new(Waitlist { + queue: LinkedList::new(), + closed: false, + }), + } + } + /// Returns the current number of available permits pub(crate) fn available_permits(&self) -> usize { self.permits.load(Acquire) >> Self::PERMIT_SHIFT @@ -134,16 +159,15 @@ impl Semaphore { } // Assign permits to the wait queue - self.add_permits_locked(added, self.waiters.lock().unwrap()); + self.add_permits_locked(added, self.waiters.lock()); } /// Closes the semaphore. This prevents the semaphore from issuing new /// permits and notifies all pending waiters. // This will be used once the bounded MPSC is updated to use the new // semaphore implementation. - #[allow(dead_code)] pub(crate) fn close(&self) { - let mut waiters = self.waiters.lock().unwrap(); + let mut waiters = self.waiters.lock(); // If the semaphore's permits counter has enough permits for an // unqueued waiter to acquire all the permits it needs immediately, // it won't touch the wait list. Therefore, we have to set a bit on @@ -161,6 +185,11 @@ impl Semaphore { } } + /// Returns true if the semaphore is closed + pub(crate) fn is_closed(&self) -> bool { + self.permits.load(Acquire) & Self::CLOSED == Self::CLOSED + } + pub(crate) fn try_acquire(&self, num_permits: u32) -> Result<(), TryAcquireError> { assert!( num_permits as usize <= Self::MAX_PERMITS, @@ -170,8 +199,8 @@ impl Semaphore { let num_permits = (num_permits as usize) << Self::PERMIT_SHIFT; let mut curr = self.permits.load(Acquire); loop { - // Has the semaphore closed?git - if curr & Self::CLOSED > 0 { + // Has the semaphore closed? + if curr & Self::CLOSED == Self::CLOSED { return Err(TryAcquireError::Closed); } @@ -203,7 +232,7 @@ impl Semaphore { let mut lock = Some(waiters); let mut is_empty = false; while rem > 0 { - let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock().unwrap()); + let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock()); 'inner: for slot in &mut wakers[..] { // Was the waiter assigned enough permits to wake it? match waiters.queue.last() { @@ -296,7 +325,7 @@ impl Semaphore { // counter. Otherwise, if we subtract the permits and then // acquire the lock, we might miss additional permits being // added while waiting for the lock. - lock = Some(self.waiters.lock().unwrap()); + lock = Some(self.waiters.lock()); } match self.permits.compare_exchange(curr, next, AcqRel, Acquire) { @@ -306,7 +335,7 @@ impl Semaphore { if !queued { return Ready(Ok(())); } else if lock.is_none() { - break self.waiters.lock().unwrap(); + break self.waiters.lock(); } } break lock.expect("lock must be acquired before waiting"); @@ -357,7 +386,7 @@ impl Semaphore { impl fmt::Debug for Semaphore { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Semaphore") - .field("permits", &self.permits.load(Relaxed)) + .field("permits", &self.available_permits()) .finish() } } @@ -456,14 +485,7 @@ impl Drop for Acquire<'_> { // This is where we ensure safety. The future is being dropped, // which means we must ensure that the waiter entry is no longer stored // in the linked list. - let mut waiters = match self.semaphore.waiters.lock() { - Ok(lock) => lock, - // Removing the node from the linked list is necessary to ensure - // safety. Even if the lock was poisoned, we need to make sure it is - // removed from the linked list before dropping it --- otherwise, - // the list will contain a dangling pointer to this node. - Err(e) => e.into_inner(), - }; + let mut waiters = self.semaphore.waiters.lock(); // remove the entry from the list let node = NonNull::from(&mut self.node); @@ -506,20 +528,14 @@ impl TryAcquireError { /// Returns `true` if the error was caused by a closed semaphore. #[allow(dead_code)] // may be used later! pub(crate) fn is_closed(&self) -> bool { - match self { - TryAcquireError::Closed => true, - _ => false, - } + matches!(self, TryAcquireError::Closed) } /// Returns `true` if the error was caused by calling `try_acquire` on a /// semaphore with no available permits. #[allow(dead_code)] // may be used later! pub(crate) fn is_no_permits(&self) -> bool { - match self { - TryAcquireError::NoPermits => true, - _ => false, - } + matches!(self, TryAcquireError::NoPermits) } } diff --git a/src/sync/broadcast.rs b/src/sync/broadcast.rs index 0c8716f..ee9aba0 100644 --- a/src/sync/broadcast.rs +++ b/src/sync/broadcast.rs @@ -21,7 +21,7 @@ //! ## Lagging //! //! As sent messages must be retained until **all** [`Receiver`] handles receive -//! a clone, broadcast channels are suspectible to the "slow receiver" problem. +//! a clone, broadcast channels are susceptible to the "slow receiver" problem. //! In this case, all but one receiver are able to receive values at the rate //! they are sent. Because one receiver is stalled, the channel starts to fill //! up. @@ -55,8 +55,8 @@ //! [`Sender::subscribe`]: crate::sync::broadcast::Sender::subscribe //! [`Receiver`]: crate::sync::broadcast::Receiver //! [`channel`]: crate::sync::broadcast::channel -//! [`RecvError::Lagged`]: crate::sync::broadcast::RecvError::Lagged -//! [`RecvError::Closed`]: crate::sync::broadcast::RecvError::Closed +//! [`RecvError::Lagged`]: crate::sync::broadcast::error::RecvError::Lagged +//! [`RecvError::Closed`]: crate::sync::broadcast::error::RecvError::Closed //! [`recv`]: crate::sync::broadcast::Receiver::recv //! //! # Examples @@ -107,6 +107,7 @@ //! assert_eq!(20, rx.recv().await.unwrap()); //! assert_eq!(30, rx.recv().await.unwrap()); //! } +//! ``` use crate::loom::cell::UnsafeCell; use crate::loom::sync::atomic::AtomicUsize; @@ -194,58 +195,99 @@ pub struct Receiver<T> { /// Next position to read from next: u64, - - /// Used to support the deprecated `poll_recv` fn - waiter: Option<Pin<Box<UnsafeCell<Waiter>>>>, } -/// Error returned by [`Sender::send`][Sender::send]. -/// -/// A **send** operation can only fail if there are no active receivers, -/// implying that the message could never be received. The error contains the -/// message being sent as a payload so it can be recovered. -#[derive(Debug)] -pub struct SendError<T>(pub T); +pub mod error { + //! Broadcast error types -/// An error returned from the [`recv`] function on a [`Receiver`]. -/// -/// [`recv`]: crate::sync::broadcast::Receiver::recv -/// [`Receiver`]: crate::sync::broadcast::Receiver -#[derive(Debug, PartialEq)] -pub enum RecvError { - /// There are no more active senders implying no further messages will ever - /// be sent. - Closed, + use std::fmt; - /// The receiver lagged too far behind. Attempting to receive again will - /// return the oldest message still retained by the channel. + /// Error returned by from the [`send`] function on a [`Sender`]. /// - /// Includes the number of skipped messages. - Lagged(u64), -} + /// A **send** operation can only fail if there are no active receivers, + /// implying that the message could never be received. The error contains the + /// message being sent as a payload so it can be recovered. + /// + /// [`send`]: crate::sync::broadcast::Sender::send + /// [`Sender`]: crate::sync::broadcast::Sender + #[derive(Debug)] + pub struct SendError<T>(pub T); -/// An error returned from the [`try_recv`] function on a [`Receiver`]. -/// -/// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv -/// [`Receiver`]: crate::sync::broadcast::Receiver -#[derive(Debug, PartialEq)] -pub enum TryRecvError { - /// The channel is currently empty. There are still active - /// [`Sender`][Sender] handles, so data may yet become available. - Empty, - - /// There are no more active senders implying no further messages will ever - /// be sent. - Closed, - - /// The receiver lagged too far behind and has been forcibly disconnected. - /// Attempting to receive again will return the oldest message still - /// retained by the channel. - /// - /// Includes the number of skipped messages. - Lagged(u64), + impl<T> fmt::Display for SendError<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "channel closed") + } + } + + impl<T: fmt::Debug> std::error::Error for SendError<T> {} + + /// An error returned from the [`recv`] function on a [`Receiver`]. + /// + /// [`recv`]: crate::sync::broadcast::Receiver::recv + /// [`Receiver`]: crate::sync::broadcast::Receiver + #[derive(Debug, PartialEq)] + pub enum RecvError { + /// There are no more active senders implying no further messages will ever + /// be sent. + Closed, + + /// The receiver lagged too far behind. Attempting to receive again will + /// return the oldest message still retained by the channel. + /// + /// Includes the number of skipped messages. + Lagged(u64), + } + + impl fmt::Display for RecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + RecvError::Closed => write!(f, "channel closed"), + RecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt), + } + } + } + + impl std::error::Error for RecvError {} + + /// An error returned from the [`try_recv`] function on a [`Receiver`]. + /// + /// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv + /// [`Receiver`]: crate::sync::broadcast::Receiver + #[derive(Debug, PartialEq)] + pub enum TryRecvError { + /// The channel is currently empty. There are still active + /// [`Sender`] handles, so data may yet become available. + /// + /// [`Sender`]: crate::sync::broadcast::Sender + Empty, + + /// There are no more active senders implying no further messages will ever + /// be sent. + Closed, + + /// The receiver lagged too far behind and has been forcibly disconnected. + /// Attempting to receive again will return the oldest message still + /// retained by the channel. + /// + /// Includes the number of skipped messages. + Lagged(u64), + } + + impl fmt::Display for TryRecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TryRecvError::Empty => write!(f, "channel empty"), + TryRecvError::Closed => write!(f, "channel closed"), + TryRecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt), + } + } + } + + impl std::error::Error for TryRecvError {} } +use self::error::*; + /// Data shared between senders and receivers struct Shared<T> { /// slots in the channel @@ -273,7 +315,7 @@ struct Tail { closed: bool, /// Receivers waiting for a value - waiters: LinkedList<Waiter>, + waiters: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>, } /// Slot in the buffer @@ -373,8 +415,8 @@ const MAX_RECEIVERS: usize = usize::MAX >> 2; /// [`Sender::subscribe`]: crate::sync::broadcast::Sender::subscribe /// [`Receiver`]: crate::sync::broadcast::Receiver /// [`recv`]: crate::sync::broadcast::Receiver::recv -/// [`SendError`]: crate::sync::broadcast::SendError -/// [`RecvError`]: crate::sync::broadcast::RecvError +/// [`SendError`]: crate::sync::broadcast::error::SendError +/// [`RecvError`]: crate::sync::broadcast::error::RecvError /// /// # Examples /// @@ -400,7 +442,7 @@ const MAX_RECEIVERS: usize = usize::MAX >> 2; /// tx.send(20).unwrap(); /// } /// ``` -pub fn channel<T>(mut capacity: usize) -> (Sender<T>, Receiver<T>) { +pub fn channel<T: Clone>(mut capacity: usize) -> (Sender<T>, Receiver<T>) { assert!(capacity > 0, "capacity is empty"); assert!(capacity <= usize::MAX >> 1, "requested capacity too large"); @@ -433,7 +475,6 @@ pub fn channel<T>(mut capacity: usize) -> (Sender<T>, Receiver<T>) { let rx = Receiver { shared: shared.clone(), next: 0, - waiter: None, }; let tx = Sender { shared }; @@ -528,23 +569,7 @@ impl<T> Sender<T> { /// ``` pub fn subscribe(&self) -> Receiver<T> { let shared = self.shared.clone(); - - let mut tail = shared.tail.lock().unwrap(); - - if tail.rx_cnt == MAX_RECEIVERS { - panic!("max receivers"); - } - - tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow"); - let next = tail.pos; - - drop(tail); - - Receiver { - shared, - next, - waiter: None, - } + new_receiver(shared) } /// Returns the number of active receivers @@ -584,12 +609,12 @@ impl<T> Sender<T> { /// } /// ``` pub fn receiver_count(&self) -> usize { - let tail = self.shared.tail.lock().unwrap(); + let tail = self.shared.tail.lock(); tail.rx_cnt } fn send2(&self, value: Option<T>) -> Result<usize, SendError<Option<T>>> { - let mut tail = self.shared.tail.lock().unwrap(); + let mut tail = self.shared.tail.lock(); if tail.rx_cnt == 0 { return Err(SendError(value)); @@ -634,6 +659,22 @@ impl<T> Sender<T> { } } +fn new_receiver<T>(shared: Arc<Shared<T>>) -> Receiver<T> { + let mut tail = shared.tail.lock(); + + if tail.rx_cnt == MAX_RECEIVERS { + panic!("max receivers"); + } + + tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow"); + + let next = tail.pos; + + drop(tail); + + Receiver { shared, next } +} + impl Tail { fn notify_rx(&mut self) { while let Some(mut waiter) = self.waiters.pop_back() { @@ -695,7 +736,7 @@ impl<T> Receiver<T> { // the slot lock. drop(slot); - let mut tail = self.shared.tail.lock().unwrap(); + let mut tail = self.shared.tail.lock(); // Acquire slot lock again slot = self.shared.buffer[idx].read().unwrap(); @@ -784,106 +825,7 @@ impl<T> Receiver<T> { } } -impl<T> Receiver<T> -where - T: Clone, -{ - /// Attempts to return a pending value on this receiver without awaiting. - /// - /// This is useful for a flavor of "optimistic check" before deciding to - /// await on a receiver. - /// - /// Compared with [`recv`], this function has three failure cases instead of one - /// (one for closed, one for an empty buffer, one for a lagging receiver). - /// - /// `Err(TryRecvError::Closed)` is returned when all `Sender` halves have - /// dropped, indicating that no further values can be sent on the channel. - /// - /// If the [`Receiver`] handle falls behind, once the channel is full, newly - /// sent values will overwrite old values. At this point, a call to [`recv`] - /// will return with `Err(TryRecvError::Lagged)` and the [`Receiver`]'s - /// internal cursor is updated to point to the oldest value still held by - /// the channel. A subsequent call to [`try_recv`] will return this value - /// **unless** it has been since overwritten. If there are no values to - /// receive, `Err(TryRecvError::Empty)` is returned. - /// - /// [`recv`]: crate::sync::broadcast::Receiver::recv - /// [`Receiver`]: crate::sync::broadcast::Receiver - /// - /// # Examples - /// - /// ``` - /// use tokio::sync::broadcast; - /// - /// #[tokio::main] - /// async fn main() { - /// let (tx, mut rx) = broadcast::channel(16); - /// - /// assert!(rx.try_recv().is_err()); - /// - /// tx.send(10).unwrap(); - /// - /// let value = rx.try_recv().unwrap(); - /// assert_eq!(10, value); - /// } - /// ``` - pub fn try_recv(&mut self) -> Result<T, TryRecvError> { - let guard = self.recv_ref(None)?; - guard.clone_value().ok_or(TryRecvError::Closed) - } - - #[doc(hidden)] - #[deprecated(since = "0.2.21", note = "use async fn recv()")] - pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> { - use Poll::{Pending, Ready}; - - // The borrow checker prohibits calling `self.poll_ref` while passing in - // a mutable ref to a field (as it should). To work around this, - // `waiter` is first *removed* from `self` then `poll_recv` is called. - // - // However, for safety, we must ensure that `waiter` is **not** dropped. - // It could be contained in the intrusive linked list. The `Receiver` - // drop implementation handles cleanup. - // - // The guard pattern is used to ensure that, on return, even due to - // panic, the waiter node is replaced on `self`. - - struct Guard<'a, T> { - waiter: Option<Pin<Box<UnsafeCell<Waiter>>>>, - receiver: &'a mut Receiver<T>, - } - - impl<'a, T> Drop for Guard<'a, T> { - fn drop(&mut self) { - self.receiver.waiter = self.waiter.take(); - } - } - - let waiter = self.waiter.take().or_else(|| { - Some(Box::pin(UnsafeCell::new(Waiter { - queued: false, - waker: None, - pointers: linked_list::Pointers::new(), - _p: PhantomPinned, - }))) - }); - - let guard = Guard { - waiter, - receiver: self, - }; - let res = guard - .receiver - .recv_ref(Some((&guard.waiter.as_ref().unwrap(), cx.waker()))); - - match res { - Ok(guard) => Ready(guard.clone_value().ok_or(RecvError::Closed)), - Err(TryRecvError::Closed) => Ready(Err(RecvError::Closed)), - Err(TryRecvError::Lagged(n)) => Ready(Err(RecvError::Lagged(n))), - Err(TryRecvError::Empty) => Pending, - } - } - +impl<T: Clone> Receiver<T> { /// Receives the next value for this receiver. /// /// Each [`Receiver`] handle will receive a clone of all values sent @@ -948,54 +890,103 @@ where /// assert_eq!(20, rx.recv().await.unwrap()); /// assert_eq!(30, rx.recv().await.unwrap()); /// } + /// ``` pub async fn recv(&mut self) -> Result<T, RecvError> { let fut = Recv::<_, T>::new(Borrow(self)); fut.await } -} -#[cfg(feature = "stream")] -#[doc(hidden)] -#[deprecated(since = "0.2.21", note = "use `into_stream()`")] -impl<T> crate::stream::Stream for Receiver<T> -where - T: Clone, -{ - type Item = Result<T, RecvError>; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll<Option<Result<T, RecvError>>> { - #[allow(deprecated)] - self.poll_recv(cx).map(|v| match v { - Ok(v) => Some(Ok(v)), - lag @ Err(RecvError::Lagged(_)) => Some(lag), - Err(RecvError::Closed) => None, - }) + /// Attempts to return a pending value on this receiver without awaiting. + /// + /// This is useful for a flavor of "optimistic check" before deciding to + /// await on a receiver. + /// + /// Compared with [`recv`], this function has three failure cases instead of two + /// (one for closed, one for an empty buffer, one for a lagging receiver). + /// + /// `Err(TryRecvError::Closed)` is returned when all `Sender` halves have + /// dropped, indicating that no further values can be sent on the channel. + /// + /// If the [`Receiver`] handle falls behind, once the channel is full, newly + /// sent values will overwrite old values. At this point, a call to [`recv`] + /// will return with `Err(TryRecvError::Lagged)` and the [`Receiver`]'s + /// internal cursor is updated to point to the oldest value still held by + /// the channel. A subsequent call to [`try_recv`] will return this value + /// **unless** it has been since overwritten. If there are no values to + /// receive, `Err(TryRecvError::Empty)` is returned. + /// + /// [`recv`]: crate::sync::broadcast::Receiver::recv + /// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv + /// [`Receiver`]: crate::sync::broadcast::Receiver + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = broadcast::channel(16); + /// + /// assert!(rx.try_recv().is_err()); + /// + /// tx.send(10).unwrap(); + /// + /// let value = rx.try_recv().unwrap(); + /// assert_eq!(10, value); + /// } + /// ``` + pub fn try_recv(&mut self) -> Result<T, TryRecvError> { + let guard = self.recv_ref(None)?; + guard.clone_value().ok_or(TryRecvError::Closed) + } + + /// Convert the receiver into a `Stream`. + /// + /// The conversion allows using `Receiver` with APIs that require stream + /// values. + /// + /// # Examples + /// + /// ``` + /// use tokio::stream::StreamExt; + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = broadcast::channel(128); + /// + /// tokio::spawn(async move { + /// for i in 0..10_i32 { + /// tx.send(i).unwrap(); + /// } + /// }); + /// + /// // Streams must be pinned to iterate. + /// tokio::pin! { + /// let stream = rx + /// .into_stream() + /// .filter(Result::is_ok) + /// .map(Result::unwrap) + /// .filter(|v| v % 2 == 0) + /// .map(|v| v + 1); + /// } + /// + /// while let Some(i) = stream.next().await { + /// println!("{}", i); + /// } + /// } + /// ``` + #[cfg(feature = "stream")] + #[cfg_attr(docsrs, doc(cfg(feature = "stream")))] + pub fn into_stream(self) -> impl Stream<Item = Result<T, RecvError>> { + Recv::new(Borrow(self)) } } impl<T> Drop for Receiver<T> { fn drop(&mut self) { - let mut tail = self.shared.tail.lock().unwrap(); - - if let Some(waiter) = &self.waiter { - // safety: tail lock is held - let queued = waiter.with(|ptr| unsafe { (*ptr).queued }); - - if queued { - // Remove the node - // - // safety: tail lock is held and the wait node is verified to be in - // the list. - unsafe { - waiter.with_mut(|ptr| { - tail.waiters.remove((&mut *ptr).into()); - }); - } - } - } + let mut tail = self.shared.tail.lock(); tail.rx_cnt -= 1; let until = tail.pos; @@ -1070,48 +1061,6 @@ where cfg_stream! { use futures_core::Stream; - impl<T: Clone> Receiver<T> { - /// Convert the receiver into a `Stream`. - /// - /// The conversion allows using `Receiver` with APIs that require stream - /// values. - /// - /// # Examples - /// - /// ``` - /// use tokio::stream::StreamExt; - /// use tokio::sync::broadcast; - /// - /// #[tokio::main] - /// async fn main() { - /// let (tx, rx) = broadcast::channel(128); - /// - /// tokio::spawn(async move { - /// for i in 0..10_i32 { - /// tx.send(i).unwrap(); - /// } - /// }); - /// - /// // Streams must be pinned to iterate. - /// tokio::pin! { - /// let stream = rx - /// .into_stream() - /// .filter(Result::is_ok) - /// .map(Result::unwrap) - /// .filter(|v| v % 2 == 0) - /// .map(|v| v + 1); - /// } - /// - /// while let Some(i) = stream.next().await { - /// println!("{}", i); - /// } - /// } - /// ``` - pub fn into_stream(self) -> impl Stream<Item = Result<T, RecvError>> { - Recv::new(Borrow(self)) - } - } - impl<R, T: Clone> Stream for Recv<R, T> where R: AsMut<Receiver<T>>, @@ -1141,7 +1090,7 @@ where fn drop(&mut self) { // Acquire the tail lock. This is required for safety before accessing // the waiter node. - let mut tail = self.receiver.as_mut().shared.tail.lock().unwrap(); + let mut tail = self.receiver.as_mut().shared.tail.lock(); // safety: tail lock is held let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued }); @@ -1211,27 +1160,4 @@ impl<'a, T> Drop for RecvGuard<'a, T> { } } -impl fmt::Display for RecvError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - RecvError::Closed => write!(f, "channel closed"), - RecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt), - } - } -} - -impl std::error::Error for RecvError {} - -impl fmt::Display for TryRecvError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - TryRecvError::Empty => write!(f, "channel empty"), - TryRecvError::Closed => write!(f, "channel closed"), - TryRecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt), - } - } -} - -impl std::error::Error for TryRecvError {} - fn is_unpin<T: Unpin>() {} diff --git a/src/sync/cancellation_token.rs b/src/sync/cancellation_token.rs deleted file mode 100644 index d60d8e0..0000000 --- a/src/sync/cancellation_token.rs +++ /dev/null @@ -1,861 +0,0 @@ -//! An asynchronously awaitable `CancellationToken`. -//! The token allows to signal a cancellation request to one or more tasks. - -use crate::loom::sync::atomic::AtomicUsize; -use crate::loom::sync::Mutex; -use crate::util::intrusive_double_linked_list::{LinkedList, ListNode}; - -use core::future::Future; -use core::pin::Pin; -use core::ptr::NonNull; -use core::sync::atomic::Ordering; -use core::task::{Context, Poll, Waker}; - -/// A token which can be used to signal a cancellation request to one or more -/// tasks. -/// -/// Tasks can call [`CancellationToken::cancelled()`] in order to -/// obtain a Future which will be resolved when cancellation is requested. -/// -/// Cancellation can be requested through the [`CancellationToken::cancel`] method. -/// -/// # Examples -/// -/// ```ignore -/// use tokio::select; -/// use tokio::scope::CancellationToken; -/// -/// #[tokio::main] -/// async fn main() { -/// let token = CancellationToken::new(); -/// let cloned_token = token.clone(); -/// -/// let join_handle = tokio::spawn(async move { -/// // Wait for either cancellation or a very long time -/// select! { -/// _ = cloned_token.cancelled() => { -/// // The token was cancelled -/// 5 -/// } -/// _ = tokio::time::delay_for(std::time::Duration::from_secs(9999)) => { -/// 99 -/// } -/// } -/// }); -/// -/// tokio::spawn(async move { -/// tokio::time::delay_for(std::time::Duration::from_millis(10)).await; -/// token.cancel(); -/// }); -/// -/// assert_eq!(5, join_handle.await.unwrap()); -/// } -/// ``` -pub struct CancellationToken { - inner: NonNull<CancellationTokenState>, -} - -// Safety: The CancellationToken is thread-safe and can be moved between threads, -// since all methods are internally synchronized. -unsafe impl Send for CancellationToken {} -unsafe impl Sync for CancellationToken {} - -/// A Future that is resolved once the corresponding [`CancellationToken`] -/// was cancelled -#[must_use = "futures do nothing unless polled"] -pub struct WaitForCancellationFuture<'a> { - /// The CancellationToken that is associated with this WaitForCancellationFuture - cancellation_token: Option<&'a CancellationToken>, - /// Node for waiting at the cancellation_token - wait_node: ListNode<WaitQueueEntry>, - /// Whether this future was registered at the token yet as a waiter - is_registered: bool, -} - -// Safety: Futures can be sent between threads as long as the underlying -// cancellation_token is thread-safe (Sync), -// which allows to poll/register/unregister from a different thread. -unsafe impl<'a> Send for WaitForCancellationFuture<'a> {} - -// ===== impl CancellationToken ===== - -impl core::fmt::Debug for CancellationToken { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_struct("CancellationToken") - .field("is_cancelled", &self.is_cancelled()) - .finish() - } -} - -impl Clone for CancellationToken { - fn clone(&self) -> Self { - // Safety: The state inside a `CancellationToken` is always valid, since - // is reference counted - let inner = self.state(); - - // Tokens are cloned by increasing their refcount - let current_state = inner.snapshot(); - inner.increment_refcount(current_state); - - CancellationToken { inner: self.inner } - } -} - -impl Drop for CancellationToken { - fn drop(&mut self) { - let token_state_pointer = self.inner; - - // Safety: The state inside a `CancellationToken` is always valid, since - // is reference counted - let inner = unsafe { &mut *self.inner.as_ptr() }; - - let mut current_state = inner.snapshot(); - - // We need to safe the parent, since the state might be released by the - // next call - let parent = inner.parent; - - // Drop our own refcount - current_state = inner.decrement_refcount(current_state); - - // If this was the last reference, unregister from the parent - if current_state.refcount == 0 { - if let Some(mut parent) = parent { - // Safety: Since we still retain a reference on the parent, it must be valid. - let parent = unsafe { parent.as_mut() }; - parent.unregister_child(token_state_pointer, current_state); - } - } - } -} - -impl CancellationToken { - /// Creates a new CancellationToken in the non-cancelled state. - pub fn new() -> CancellationToken { - let state = Box::new(CancellationTokenState::new( - None, - StateSnapshot { - cancel_state: CancellationState::NotCancelled, - has_parent_ref: false, - refcount: 1, - }, - )); - - // Safety: We just created the Box. The pointer is guaranteed to be - // not null - CancellationToken { - inner: unsafe { NonNull::new_unchecked(Box::into_raw(state)) }, - } - } - - /// Returns a reference to the utilized `CancellationTokenState`. - fn state(&self) -> &CancellationTokenState { - // Safety: The state inside a `CancellationToken` is always valid, since - // is reference counted - unsafe { &*self.inner.as_ptr() } - } - - /// Creates a `CancellationToken` which will get cancelled whenever the - /// current token gets cancelled. - /// - /// If the current token is already cancelled, the child token will get - /// returned in cancelled state. - /// - /// # Examples - /// - /// ```ignore - /// use tokio::select; - /// use tokio::scope::CancellationToken; - /// - /// #[tokio::main] - /// async fn main() { - /// let token = CancellationToken::new(); - /// let child_token = token.child_token(); - /// - /// let join_handle = tokio::spawn(async move { - /// // Wait for either cancellation or a very long time - /// select! { - /// _ = child_token.cancelled() => { - /// // The token was cancelled - /// 5 - /// } - /// _ = tokio::time::delay_for(std::time::Duration::from_secs(9999)) => { - /// 99 - /// } - /// } - /// }); - /// - /// tokio::spawn(async move { - /// tokio::time::delay_for(std::time::Duration::from_millis(10)).await; - /// token.cancel(); - /// }); - /// - /// assert_eq!(5, join_handle.await.unwrap()); - /// } - /// ``` - pub fn child_token(&self) -> CancellationToken { - let inner = self.state(); - - // Increment the refcount of this token. It will be referenced by the - // child, independent of whether the child is immediately cancelled or - // not. - let _current_state = inner.increment_refcount(inner.snapshot()); - - let mut unpacked_child_state = StateSnapshot { - has_parent_ref: true, - refcount: 1, - cancel_state: CancellationState::NotCancelled, - }; - let mut child_token_state = Box::new(CancellationTokenState::new( - Some(self.inner), - unpacked_child_state, - )); - - { - let mut guard = inner.synchronized.lock().unwrap(); - if guard.is_cancelled { - // This task was already cancelled. In this case we should not - // insert the child into the list, since it would never get removed - // from the list. - (*child_token_state.synchronized.lock().unwrap()).is_cancelled = true; - unpacked_child_state.cancel_state = CancellationState::Cancelled; - // Since it's not in the list, the parent doesn't need to retain - // a reference to it. - unpacked_child_state.has_parent_ref = false; - child_token_state - .state - .store(unpacked_child_state.pack(), Ordering::SeqCst); - } else { - if let Some(mut first_child) = guard.first_child { - child_token_state.from_parent.next_peer = Some(first_child); - // Safety: We manipulate other child task inside the Mutex - // and retain a parent reference on it. The child token can't - // get invalidated while the Mutex is held. - unsafe { - first_child.as_mut().from_parent.prev_peer = - Some((&mut *child_token_state).into()) - }; - } - guard.first_child = Some((&mut *child_token_state).into()); - } - }; - - let child_token_ptr = Box::into_raw(child_token_state); - // Safety: We just created the pointer from a `Box` - CancellationToken { - inner: unsafe { NonNull::new_unchecked(child_token_ptr) }, - } - } - - /// Cancel the [`CancellationToken`] and all child tokens which had been - /// derived from it. - /// - /// This will wake up all tasks which are waiting for cancellation. - pub fn cancel(&self) { - self.state().cancel(); - } - - /// Returns `true` if the `CancellationToken` had been cancelled - pub fn is_cancelled(&self) -> bool { - self.state().is_cancelled() - } - - /// Returns a `Future` that gets fulfilled when cancellation is requested. - pub fn cancelled(&self) -> WaitForCancellationFuture<'_> { - WaitForCancellationFuture { - cancellation_token: Some(self), - wait_node: ListNode::new(WaitQueueEntry::new()), - is_registered: false, - } - } - - unsafe fn register( - &self, - wait_node: &mut ListNode<WaitQueueEntry>, - cx: &mut Context<'_>, - ) -> Poll<()> { - self.state().register(wait_node, cx) - } - - fn check_for_cancellation( - &self, - wait_node: &mut ListNode<WaitQueueEntry>, - cx: &mut Context<'_>, - ) -> Poll<()> { - self.state().check_for_cancellation(wait_node, cx) - } - - fn unregister(&self, wait_node: &mut ListNode<WaitQueueEntry>) { - self.state().unregister(wait_node) - } -} - -// ===== impl WaitForCancellationFuture ===== - -impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_struct("WaitForCancellationFuture").finish() - } -} - -impl<'a> Future for WaitForCancellationFuture<'a> { - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - // Safety: We do not move anything out of `WaitForCancellationFuture` - let mut_self: &mut WaitForCancellationFuture<'_> = unsafe { Pin::get_unchecked_mut(self) }; - - let cancellation_token = mut_self - .cancellation_token - .expect("polled WaitForCancellationFuture after completion"); - - let poll_res = if !mut_self.is_registered { - // Safety: The `ListNode` is pinned through the Future, - // and we will unregister it in `WaitForCancellationFuture::drop` - // before the Future is dropped and the memory reference is invalidated. - unsafe { cancellation_token.register(&mut mut_self.wait_node, cx) } - } else { - cancellation_token.check_for_cancellation(&mut mut_self.wait_node, cx) - }; - - if let Poll::Ready(()) = poll_res { - // The cancellation_token was signalled - mut_self.cancellation_token = None; - // A signalled Token means the Waker won't be enqueued anymore - mut_self.is_registered = false; - mut_self.wait_node.task = None; - } else { - // This `Future` and its stored `Waker` stay registered at the - // `CancellationToken` - mut_self.is_registered = true; - } - - poll_res - } -} - -impl<'a> Drop for WaitForCancellationFuture<'a> { - fn drop(&mut self) { - // If this WaitForCancellationFuture has been polled and it was added to the - // wait queue at the cancellation_token, it must be removed before dropping. - // Otherwise the cancellation_token would access invalid memory. - if let Some(token) = self.cancellation_token { - if self.is_registered { - token.unregister(&mut self.wait_node); - } - } - } -} - -/// Tracks how the future had interacted with the [`CancellationToken`] -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -enum PollState { - /// The task has never interacted with the [`CancellationToken`]. - New, - /// The task was added to the wait queue at the [`CancellationToken`]. - Waiting, - /// The task has been polled to completion. - Done, -} - -/// Tracks the WaitForCancellationFuture waiting state. -/// Access to this struct is synchronized through the mutex in the CancellationToken. -struct WaitQueueEntry { - /// The task handle of the waiting task - task: Option<Waker>, - // Current polling state. This state is only updated inside the Mutex of - // the CancellationToken. - state: PollState, -} - -impl WaitQueueEntry { - /// Creates a new WaitQueueEntry - fn new() -> WaitQueueEntry { - WaitQueueEntry { - task: None, - state: PollState::New, - } - } -} - -struct SynchronizedState { - waiters: LinkedList<WaitQueueEntry>, - first_child: Option<NonNull<CancellationTokenState>>, - is_cancelled: bool, -} - -impl SynchronizedState { - fn new() -> Self { - Self { - waiters: LinkedList::new(), - first_child: None, - is_cancelled: false, - } - } -} - -/// Information embedded in child tokens which is synchronized through the Mutex -/// in their parent. -struct SynchronizedThroughParent { - next_peer: Option<NonNull<CancellationTokenState>>, - prev_peer: Option<NonNull<CancellationTokenState>>, -} - -/// Possible states of a `CancellationToken` -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -enum CancellationState { - NotCancelled = 0, - Cancelling = 1, - Cancelled = 2, -} - -impl CancellationState { - fn pack(self) -> usize { - self as usize - } - - fn unpack(value: usize) -> Self { - match value { - 0 => CancellationState::NotCancelled, - 1 => CancellationState::Cancelling, - 2 => CancellationState::Cancelled, - _ => unreachable!("Invalid value"), - } - } -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -struct StateSnapshot { - /// The amount of references to this particular CancellationToken. - /// `CancellationToken` structs hold these references to a `CancellationTokenState`. - /// Also the state is referenced by the state of each child. - refcount: usize, - /// Whether the state is still referenced by it's parent and can therefore - /// not be freed. - has_parent_ref: bool, - /// Whether the token is cancelled - cancel_state: CancellationState, -} - -impl StateSnapshot { - /// Packs the snapshot into a `usize` - fn pack(self) -> usize { - self.refcount << 3 | if self.has_parent_ref { 4 } else { 0 } | self.cancel_state.pack() - } - - /// Unpacks the snapshot from a `usize` - fn unpack(value: usize) -> Self { - let refcount = value >> 3; - let has_parent_ref = value & 4 != 0; - let cancel_state = CancellationState::unpack(value & 0x03); - - StateSnapshot { - refcount, - has_parent_ref, - cancel_state, - } - } - - /// Whether this `CancellationTokenState` is still referenced by any - /// `CancellationToken`. - fn has_refs(&self) -> bool { - self.refcount != 0 || self.has_parent_ref - } -} - -/// The maximum permitted amount of references to a CancellationToken. This -/// is derived from the intent to never use more than 32bit in the `Snapshot`. -const MAX_REFS: u32 = (std::u32::MAX - 7) >> 3; - -/// Internal state of the `CancellationToken` pair above -struct CancellationTokenState { - state: AtomicUsize, - parent: Option<NonNull<CancellationTokenState>>, - from_parent: SynchronizedThroughParent, - synchronized: Mutex<SynchronizedState>, -} - -impl CancellationTokenState { - fn new( - parent: Option<NonNull<CancellationTokenState>>, - state: StateSnapshot, - ) -> CancellationTokenState { - CancellationTokenState { - parent, - from_parent: SynchronizedThroughParent { - prev_peer: None, - next_peer: None, - }, - state: AtomicUsize::new(state.pack()), - synchronized: Mutex::new(SynchronizedState::new()), - } - } - - /// Returns a snapshot of the current atomic state of the token - fn snapshot(&self) -> StateSnapshot { - StateSnapshot::unpack(self.state.load(Ordering::SeqCst)) - } - - fn atomic_update_state<F>(&self, mut current_state: StateSnapshot, func: F) -> StateSnapshot - where - F: Fn(StateSnapshot) -> StateSnapshot, - { - let mut current_packed_state = current_state.pack(); - loop { - let next_state = func(current_state); - match self.state.compare_exchange( - current_packed_state, - next_state.pack(), - Ordering::SeqCst, - Ordering::SeqCst, - ) { - Ok(_) => { - return next_state; - } - Err(actual) => { - current_packed_state = actual; - current_state = StateSnapshot::unpack(actual); - } - } - } - } - - fn increment_refcount(&self, current_state: StateSnapshot) -> StateSnapshot { - self.atomic_update_state(current_state, |mut state: StateSnapshot| { - if state.refcount >= MAX_REFS as usize { - eprintln!("[ERROR] Maximum reference count for CancellationToken was exceeded"); - std::process::abort(); - } - state.refcount += 1; - state - }) - } - - fn decrement_refcount(&self, current_state: StateSnapshot) -> StateSnapshot { - let current_state = self.atomic_update_state(current_state, |mut state: StateSnapshot| { - state.refcount -= 1; - state - }); - - // Drop the State if it is not referenced anymore - if !current_state.has_refs() { - // Safety: `CancellationTokenState` is always stored in refcounted - // Boxes - let _ = unsafe { Box::from_raw(self as *const Self as *mut Self) }; - } - - current_state - } - - fn remove_parent_ref(&self, current_state: StateSnapshot) -> StateSnapshot { - let current_state = self.atomic_update_state(current_state, |mut state: StateSnapshot| { - state.has_parent_ref = false; - state - }); - - // Drop the State if it is not referenced anymore - if !current_state.has_refs() { - // Safety: `CancellationTokenState` is always stored in refcounted - // Boxes - let _ = unsafe { Box::from_raw(self as *const Self as *mut Self) }; - } - - current_state - } - - /// Unregisters a child from the parent token. - /// The child tokens state is not exactly known at this point in time. - /// If the parent token is cancelled, the child token gets removed from the - /// parents list, and might therefore already have been freed. If the parent - /// token is not cancelled, the child token is still valid. - fn unregister_child( - &mut self, - mut child_state: NonNull<CancellationTokenState>, - current_child_state: StateSnapshot, - ) { - let removed_child = { - // Remove the child toke from the parents linked list - let mut guard = self.synchronized.lock().unwrap(); - if !guard.is_cancelled { - // Safety: Since the token was not cancelled, the child must - // still be in the list and valid. - let mut child_state = unsafe { child_state.as_mut() }; - debug_assert!(child_state.snapshot().has_parent_ref); - - if guard.first_child == Some(child_state.into()) { - guard.first_child = child_state.from_parent.next_peer; - } - // Safety: If peers wouldn't be valid anymore, they would try - // to remove themselves from the list. This would require locking - // the Mutex that we currently own. - unsafe { - if let Some(mut prev_peer) = child_state.from_parent.prev_peer { - prev_peer.as_mut().from_parent.next_peer = - child_state.from_parent.next_peer; - } - if let Some(mut next_peer) = child_state.from_parent.next_peer { - next_peer.as_mut().from_parent.prev_peer = - child_state.from_parent.prev_peer; - } - } - child_state.from_parent.prev_peer = None; - child_state.from_parent.next_peer = None; - - // The child is no longer referenced by the parent, since we were able - // to remove its reference from the parents list. - true - } else { - // Do not touch the linked list anymore. If the parent is cancelled - // it will move all childs outside of the Mutex and manipulate - // the pointers there. Manipulating the pointers here too could - // lead to races. Therefore leave them just as as and let the - // parent deal with it. The parent will make sure to retain a - // reference to this state as long as it manipulates the list - // pointers. Therefore the pointers are not dangling. - false - } - }; - - if removed_child { - // If the token removed itself from the parents list, it can reset - // the the parent ref status. If it is isn't able to do so, because the - // parent removed it from the list, there is no need to do this. - // The parent ref acts as as another reference count. Therefore - // removing this reference can free the object. - // Safety: The token was in the list. This means the parent wasn't - // cancelled before, and the token must still be alive. - unsafe { child_state.as_mut().remove_parent_ref(current_child_state) }; - } - - // Decrement the refcount on the parent and free it if necessary - self.decrement_refcount(self.snapshot()); - } - - fn cancel(&self) { - // Move the state of the CancellationToken from `NotCancelled` to `Cancelling` - let mut current_state = self.snapshot(); - - let state_after_cancellation = loop { - if current_state.cancel_state != CancellationState::NotCancelled { - // Another task already initiated the cancellation - return; - } - - let mut next_state = current_state; - next_state.cancel_state = CancellationState::Cancelling; - match self.state.compare_exchange( - current_state.pack(), - next_state.pack(), - Ordering::SeqCst, - Ordering::SeqCst, - ) { - Ok(_) => break next_state, - Err(actual) => current_state = StateSnapshot::unpack(actual), - } - }; - - // This task cancelled the token - - // Take the task list out of the Token - // We do not want to cancel child token inside this lock. If one of the - // child tasks would have additional child tokens, we would recursively - // take locks. - - // Doing this action has an impact if the child token is dropped concurrently: - // It will try to deregister itself from the parent task, but can not find - // itself in the task list anymore. Therefore it needs to assume the parent - // has extracted the list and will process it. It may not modify the list. - // This is OK from a memory safety perspective, since the parent still - // retains a reference to the child task until it finished iterating over - // it. - - let mut first_child = { - let mut guard = self.synchronized.lock().unwrap(); - // Save the cancellation also inside the Mutex - // This allows child tokens which want to detach themselves to detect - // that this is no longer required since the parent cleared the list. - guard.is_cancelled = true; - - // Wakeup all waiters - // This happens inside the lock to make cancellation reliable - // If we would access waiters outside of the lock, the pointers - // may no longer be valid. - // Typically this shouldn't be an issue, since waking a task should - // only move it from the blocked into the ready state and not have - // further side effects. - - // Use a reverse iterator, so that the oldest waiter gets - // scheduled first - guard.waiters.reverse_drain(|waiter| { - // We are not allowed to move the `Waker` out of the list node. - // The `Future` relies on the fact that the old `Waker` stays there - // as long as the `Future` has not completed in order to perform - // the `will_wake()` check. - // Therefore `wake_by_ref` is used instead of `wake()` - if let Some(handle) = &mut waiter.task { - handle.wake_by_ref(); - } - // Mark the waiter to have been removed from the list. - waiter.state = PollState::Done; - }); - - guard.first_child.take() - }; - - while let Some(mut child) = first_child { - // Safety: We know this is a valid pointer since it is in our child pointer - // list. It can't have been freed in between, since we retain a a reference - // to each child. - let mut_child = unsafe { child.as_mut() }; - - // Get the next child and clean up list pointers - first_child = mut_child.from_parent.next_peer; - mut_child.from_parent.prev_peer = None; - mut_child.from_parent.next_peer = None; - - // Cancel the child task - mut_child.cancel(); - - // Drop the parent reference. This `CancellationToken` is not interested - // in interacting with the child anymore. - // This is ONLY allowed once we promised not to touch the state anymore - // after this interaction. - mut_child.remove_parent_ref(mut_child.snapshot()); - } - - // The cancellation has completed - // At this point in time tasks which registered a wait node can be sure - // that this wait node already had been dequeued from the list without - // needing to inspect the list. - self.atomic_update_state(state_after_cancellation, |mut state| { - state.cancel_state = CancellationState::Cancelled; - state - }); - } - - /// Returns `true` if the `CancellationToken` had been cancelled - fn is_cancelled(&self) -> bool { - let current_state = self.snapshot(); - current_state.cancel_state != CancellationState::NotCancelled - } - - /// Registers a waiting task at the `CancellationToken`. - /// Safety: This method is only safe as long as the waiting waiting task - /// will properly unregister the wait node before it gets moved. - unsafe fn register( - &self, - wait_node: &mut ListNode<WaitQueueEntry>, - cx: &mut Context<'_>, - ) -> Poll<()> { - debug_assert_eq!(PollState::New, wait_node.state); - let current_state = self.snapshot(); - - // Perform an optimistic cancellation check before. This is not strictly - // necessary since we also check for cancellation in the Mutex, but - // reduces the necessary work to be performed for tasks which already - // had been cancelled. - if current_state.cancel_state != CancellationState::NotCancelled { - return Poll::Ready(()); - } - - // So far the token is not cancelled. However it could be cancelld before - // we get the chance to store the `Waker`. Therfore we need to check - // for cancellation again inside the mutex. - let mut guard = self.synchronized.lock().unwrap(); - if guard.is_cancelled { - // Cancellation was signalled - wait_node.state = PollState::Done; - Poll::Ready(()) - } else { - // Added the task to the wait queue - wait_node.task = Some(cx.waker().clone()); - wait_node.state = PollState::Waiting; - guard.waiters.add_front(wait_node); - Poll::Pending - } - } - - fn check_for_cancellation( - &self, - wait_node: &mut ListNode<WaitQueueEntry>, - cx: &mut Context<'_>, - ) -> Poll<()> { - debug_assert!( - wait_node.task.is_some(), - "Method can only be called after task had been registered" - ); - - let current_state = self.snapshot(); - - if current_state.cancel_state != CancellationState::NotCancelled { - // If the cancellation had been fully completed we know that our `Waker` - // is no longer registered at the `CancellationToken`. - // Otherwise the cancel call may or may not yet have iterated - // through the waiters list and removed the wait nodes. - // If it hasn't yet, we need to remove it. Otherwise an attempt to - // reuse the `wait_node´ might get freed due to the `WaitForCancellationFuture` - // getting dropped before the cancellation had interacted with it. - if current_state.cancel_state != CancellationState::Cancelled { - self.unregister(wait_node); - } - Poll::Ready(()) - } else { - // Check if we need to swap the `Waker`. This will make the check more - // expensive, since the `Waker` is synchronized through the Mutex. - // If we don't need to perform a `Waker` update, an atomic check for - // cancellation is sufficient. - let need_waker_update = wait_node - .task - .as_ref() - .map(|waker| waker.will_wake(cx.waker())) - .unwrap_or(true); - - if need_waker_update { - let guard = self.synchronized.lock().unwrap(); - if guard.is_cancelled { - // Cancellation was signalled. Since this cancellation signal - // is set inside the Mutex, the old waiter must already have - // been removed from the waiting list - debug_assert_eq!(PollState::Done, wait_node.state); - wait_node.task = None; - Poll::Ready(()) - } else { - // The WaitForCancellationFuture is already in the queue. - // The CancellationToken can't have been cancelled, - // since this would change the is_cancelled flag inside the mutex. - // Therefore we just have to update the Waker. A follow-up - // cancellation will always use the new waker. - wait_node.task = Some(cx.waker().clone()); - Poll::Pending - } - } else { - // Do nothing. If the token gets cancelled, this task will get - // woken again and can fetch the cancellation. - Poll::Pending - } - } - } - - fn unregister(&self, wait_node: &mut ListNode<WaitQueueEntry>) { - debug_assert!( - wait_node.task.is_some(), - "waiter can not be active without task" - ); - - let mut guard = self.synchronized.lock().unwrap(); - // WaitForCancellationFuture only needs to get removed if it has been added to - // the wait queue of the CancellationToken. - // This has happened in the PollState::Waiting case. - if let PollState::Waiting = wait_node.state { - // Safety: Due to the state, we know that the node must be part - // of the waiter list - if !unsafe { guard.waiters.remove(wait_node) } { - // Panic if the address isn't found. This can only happen if the contract was - // violated, e.g. the WaitQueueEntry got moved after the initial poll. - panic!("Future could not be removed from wait queue"); - } - wait_node.state = PollState::Done; - } - wait_node.task = None; - } -} diff --git a/src/sync/mod.rs b/src/sync/mod.rs index 3d96106..57ae277 100644 --- a/src/sync/mod.rs +++ b/src/sync/mod.rs @@ -20,7 +20,7 @@ //! few flavors of channels provided by Tokio. Each channel flavor supports //! different message passing patterns. When a channel supports multiple //! producers, many separate tasks may **send** messages. When a channel -//! supports muliple consumers, many different separate tasks may **receive** +//! supports multiple consumers, many different separate tasks may **receive** //! messages. //! //! Tokio provides many different channel flavors as different message passing @@ -106,7 +106,7 @@ //! //! #[tokio::main] //! async fn main() { -//! let (mut tx, mut rx) = mpsc::channel(100); +//! let (tx, mut rx) = mpsc::channel(100); //! //! tokio::spawn(async move { //! for i in 0..10 { @@ -150,7 +150,7 @@ //! for _ in 0..10 { //! // Each task needs its own `tx` handle. This is done by cloning the //! // original handle. -//! let mut tx = tx.clone(); +//! let tx = tx.clone(); //! //! tokio::spawn(async move { //! tx.send(&b"data to write"[..]).await.unwrap(); @@ -213,7 +213,7 @@ //! //! // Spawn tasks that will send the increment command. //! for _ in 0..10 { -//! let mut cmd_tx = cmd_tx.clone(); +//! let cmd_tx = cmd_tx.clone(); //! //! join_handles.push(tokio::spawn(async move { //! let (resp_tx, resp_rx) = oneshot::channel(); @@ -322,7 +322,7 @@ //! tokio::spawn(async move { //! loop { //! // Wait 10 seconds between checks -//! time::delay_for(Duration::from_secs(10)).await; +//! time::sleep(Duration::from_secs(10)).await; //! //! // Load the configuration file //! let new_config = Config::load_from_file().await.unwrap(); @@ -330,7 +330,7 @@ //! // If the configuration changed, send the new config value //! // on the watch channel. //! if new_config != config { -//! tx.broadcast(new_config.clone()).unwrap(); +//! tx.send(new_config.clone()).unwrap(); //! config = new_config; //! } //! } @@ -355,17 +355,15 @@ //! let op = my_async_operation(); //! tokio::pin!(op); //! -//! // Receive the **initial** configuration value. As this is the -//! // first time the config is received from the watch, it will -//! // always complete immediatedly. -//! let mut conf = rx.recv().await.unwrap(); +//! // Get the initial config value +//! let mut conf = rx.borrow().clone(); //! //! let mut op_start = Instant::now(); -//! let mut delay = time::delay_until(op_start + conf.timeout); +//! let mut sleep = time::sleep_until(op_start + conf.timeout); //! //! loop { //! tokio::select! { -//! _ = &mut delay => { +//! _ = &mut sleep => { //! // The operation elapsed. Restart it //! op.set(my_async_operation()); //! @@ -373,14 +371,14 @@ //! op_start = Instant::now(); //! //! // Restart the timeout -//! delay = time::delay_until(op_start + conf.timeout); +//! sleep = time::sleep_until(op_start + conf.timeout); //! } -//! new_conf = rx.recv() => { -//! conf = new_conf.unwrap(); +//! _ = rx.changed() => { +//! conf = rx.borrow().clone(); //! //! // The configuration has been updated. Update the -//! // `delay` using the new `timeout` value. -//! delay.reset(op_start + conf.timeout); +//! // `sleep` using the new `timeout` value. +//! sleep.reset(op_start + conf.timeout); //! } //! _ = &mut op => { //! // The operation completed! @@ -399,14 +397,14 @@ //! } //! ``` //! -//! [`watch` channel]: crate::sync::watch -//! [`broadcast` channel]: crate::sync::broadcast +//! [`watch` channel]: mod@crate::sync::watch +//! [`broadcast` channel]: mod@crate::sync::broadcast //! //! # State synchronization //! //! The remaining synchronization primitives focus on synchronizing state. //! These are asynchronous equivalents to versions provided by `std`. They -//! operate in a similar way as their `std` counterparts parts but will wait +//! operate in a similar way as their `std` counterparts but will wait //! asynchronously instead of blocking the thread. //! //! * [`Barrier`](Barrier) Ensures multiple tasks will wait for each other to @@ -434,23 +432,17 @@ cfg_sync! { pub mod broadcast; - cfg_unstable! { - mod cancellation_token; - pub use cancellation_token::{CancellationToken, WaitForCancellationFuture}; - } - pub mod mpsc; mod mutex; pub use mutex::{Mutex, MutexGuard, TryLockError, OwnedMutexGuard}; - mod notify; + pub(crate) mod notify; pub use notify::Notify; pub mod oneshot; pub(crate) mod batch_semaphore; - pub(crate) mod semaphore_ll; mod semaphore; pub use semaphore::{Semaphore, SemaphorePermit, OwnedSemaphorePermit}; @@ -464,20 +456,30 @@ cfg_sync! { } cfg_not_sync! { + #[cfg(any(feature = "fs", feature = "signal", all(unix, feature = "process")))] + pub(crate) mod batch_semaphore; + + cfg_fs! { + mod mutex; + pub(crate) use mutex::Mutex; + } + + #[cfg(any(feature = "rt", feature = "signal", all(unix, feature = "process")))] + pub(crate) mod notify; + cfg_atomic_waker_impl! { mod task; pub(crate) use task::AtomicWaker; } #[cfg(any( - feature = "rt-core", + feature = "rt", feature = "process", feature = "signal"))] pub(crate) mod oneshot; - cfg_signal! { + cfg_signal_internal! { pub(crate) mod mpsc; - pub(crate) mod semaphore_ll; } } diff --git a/src/sync/mpsc/block.rs b/src/sync/mpsc/block.rs index 7bf1619..e062f2b 100644 --- a/src/sync/mpsc/block.rs +++ b/src/sync/mpsc/block.rs @@ -1,8 +1,6 @@ -use crate::loom::{ - cell::UnsafeCell, - sync::atomic::{AtomicPtr, AtomicUsize}, - thread, -}; +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize}; +use crate::loom::thread; use std::mem::MaybeUninit; use std::ops; diff --git a/src/sync/mpsc/bounded.rs b/src/sync/mpsc/bounded.rs index afca8c5..06b3717 100644 --- a/src/sync/mpsc/bounded.rs +++ b/src/sync/mpsc/bounded.rs @@ -1,6 +1,6 @@ +use crate::sync::batch_semaphore::{self as semaphore, TryAcquireError}; use crate::sync::mpsc::chan; -use crate::sync::mpsc::error::{ClosedError, SendError, TryRecvError, TrySendError}; -use crate::sync::semaphore_ll as semaphore; +use crate::sync::mpsc::error::{SendError, TryRecvError, TrySendError}; cfg_time! { use crate::sync::mpsc::error::SendTimeoutError; @@ -8,6 +8,7 @@ cfg_time! { } use std::fmt; +#[cfg(any(feature = "signal", feature = "process", feature = "stream"))] use std::task::{Context, Poll}; /// Send values to the associated `Receiver`. @@ -17,20 +18,14 @@ pub struct Sender<T> { chan: chan::Tx<T, Semaphore>, } -impl<T> Clone for Sender<T> { - fn clone(&self) -> Self { - Sender { - chan: self.chan.clone(), - } - } -} - -impl<T> fmt::Debug for Sender<T> { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("Sender") - .field("chan", &self.chan) - .finish() - } +/// Permit to send one value into the channel. +/// +/// `Permit` values are returned by [`Sender::reserve()`] and are used to +/// guarantee channel capacity before generating a message to send. +/// +/// [`Sender::reserve()`]: Sender::reserve +pub struct Permit<'a, T> { + chan: &'a chan::Tx<T, Semaphore>, } /// Receive values from the associated `Sender`. @@ -41,16 +36,12 @@ pub struct Receiver<T> { chan: chan::Rx<T, Semaphore>, } -impl<T> fmt::Debug for Receiver<T> { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("Receiver") - .field("chan", &self.chan) - .finish() - } -} - -/// Creates a bounded mpsc channel for communicating between asynchronous tasks, -/// returning the sender/receiver halves. +/// Creates a bounded mpsc channel for communicating between asynchronous tasks +/// with backpressure. +/// +/// The channel will buffer up to the provided number of messages. Once the +/// buffer is full, attempts to `send` new messages will wait until a message is +/// received from the channel. The provided buffer capacity must be at least 1. /// /// All data sent on `Sender` will become available on `Receiver` in the same /// order as it was sent. @@ -62,6 +53,10 @@ impl<T> fmt::Debug for Receiver<T> { /// will return a `SendError`. Similarly, if `Sender` is disconnected while /// trying to `recv`, the `recv` method will return a `RecvError`. /// +/// # Panics +/// +/// Panics if the buffer capacity is 0. +/// /// # Examples /// /// ```rust @@ -69,7 +64,7 @@ impl<T> fmt::Debug for Receiver<T> { /// /// #[tokio::main] /// async fn main() { -/// let (mut tx, mut rx) = mpsc::channel(100); +/// let (tx, mut rx) = mpsc::channel(100); /// /// tokio::spawn(async move { /// for i in 0..10 { @@ -117,7 +112,7 @@ impl<T> Receiver<T> { /// /// #[tokio::main] /// async fn main() { - /// let (mut tx, mut rx) = mpsc::channel(100); + /// let (tx, mut rx) = mpsc::channel(100); /// /// tokio::spawn(async move { /// tx.send("hello").await.unwrap(); @@ -135,7 +130,7 @@ impl<T> Receiver<T> { /// /// #[tokio::main] /// async fn main() { - /// let (mut tx, mut rx) = mpsc::channel(100); + /// let (tx, mut rx) = mpsc::channel(100); /// /// tx.send("hello").await.unwrap(); /// tx.send("world").await.unwrap(); @@ -146,15 +141,48 @@ impl<T> Receiver<T> { /// ``` pub async fn recv(&mut self) -> Option<T> { use crate::future::poll_fn; - - poll_fn(|cx| self.poll_recv(cx)).await + poll_fn(|cx| self.chan.recv(cx)).await } - #[doc(hidden)] // TODO: document - pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { + #[cfg(any(feature = "signal", feature = "process"))] + pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { self.chan.recv(cx) } + /// Blocking receive to call outside of asynchronous contexts. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution + /// context. + /// + /// # Examples + /// + /// ``` + /// use std::thread; + /// use tokio::runtime::Runtime; + /// use tokio::sync::mpsc; + /// + /// fn main() { + /// let (tx, mut rx) = mpsc::channel::<u8>(10); + /// + /// let sync_code = thread::spawn(move || { + /// assert_eq!(Some(10), rx.blocking_recv()); + /// }); + /// + /// Runtime::new() + /// .unwrap() + /// .block_on(async move { + /// let _ = tx.send(10).await; + /// }); + /// sync_code.join().unwrap() + /// } + /// ``` + #[cfg(feature = "sync")] + pub fn blocking_recv(&mut self) -> Option<T> { + crate::future::block_on(self.recv()) + } + /// Attempts to return a pending value on this receiver without blocking. /// /// This method will never block the caller in order to wait for data to @@ -173,12 +201,53 @@ impl<T> Receiver<T> { /// Closes the receiving half of a channel, without dropping it. /// /// This prevents any further messages from being sent on the channel while - /// still enabling the receiver to drain messages that are buffered. + /// still enabling the receiver to drain messages that are buffered. Any + /// outstanding [`Permit`] values will still be able to send messages. + /// + /// In order to guarantee no messages are dropped, after calling `close()`, + /// `recv()` must be called until `None` is returned. + /// + /// [`Permit`]: Permit + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(20); + /// + /// tokio::spawn(async move { + /// let mut i = 0; + /// while let Ok(permit) = tx.reserve().await { + /// permit.send(i); + /// i += 1; + /// } + /// }); + /// + /// rx.close(); + /// + /// while let Some(msg) = rx.recv().await { + /// println!("got {}", msg); + /// } + /// + /// // Channel closed and no messages are lost. + /// } + /// ``` pub fn close(&mut self) { self.chan.close(); } } +impl<T> fmt::Debug for Receiver<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Receiver") + .field("chan", &self.chan) + .finish() + } +} + impl<T> Unpin for Receiver<T> {} cfg_stream! { @@ -186,7 +255,7 @@ cfg_stream! { type Item = T; fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> { - self.poll_recv(cx) + self.chan.recv(cx) } } } @@ -225,7 +294,7 @@ impl<T> Sender<T> { /// /// #[tokio::main] /// async fn main() { - /// let (mut tx, mut rx) = mpsc::channel(1); + /// let (tx, mut rx) = mpsc::channel(1); /// /// tokio::spawn(async move { /// for i in 0..10 { @@ -241,18 +310,49 @@ impl<T> Sender<T> { /// } /// } /// ``` - pub async fn send(&mut self, value: T) -> Result<(), SendError<T>> { - use crate::future::poll_fn; - - if poll_fn(|cx| self.poll_ready(cx)).await.is_err() { - return Err(SendError(value)); + pub async fn send(&self, value: T) -> Result<(), SendError<T>> { + match self.reserve().await { + Ok(permit) => { + permit.send(value); + Ok(()) + } + Err(_) => Err(SendError(value)), } + } - match self.try_send(value) { - Ok(()) => Ok(()), - Err(TrySendError::Full(_)) => unreachable!(), - Err(TrySendError::Closed(value)) => Err(SendError(value)), - } + /// Completes when the receiver has dropped. + /// + /// This allows the producers to get notified when interest in the produced + /// values is canceled and immediately stop doing work. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx1, rx) = mpsc::channel::<()>(1); + /// let tx2 = tx1.clone(); + /// let tx3 = tx1.clone(); + /// let tx4 = tx1.clone(); + /// let tx5 = tx1.clone(); + /// tokio::spawn(async move { + /// drop(rx); + /// }); + /// + /// futures::join!( + /// tx1.closed(), + /// tx2.closed(), + /// tx3.closed(), + /// tx4.closed(), + /// tx5.closed() + /// ); + //// println!("Receiver dropped"); + /// } + /// ``` + pub async fn closed(&self) { + self.chan.closed().await } /// Attempts to immediately send a message on this `Sender` @@ -262,9 +362,6 @@ impl<T> Sender<T> { /// with [`send`], this function has two failure cases instead of one (one for /// disconnection, one for a full buffer). /// - /// This function may be paired with [`poll_ready`] in order to wait for - /// channel capacity before trying to send a value. - /// /// # Errors /// /// If the channel capacity has been reached, i.e., the channel has `n` @@ -276,7 +373,6 @@ impl<T> Sender<T> { /// an error. The error includes the value passed to `send`. /// /// [`send`]: Sender::send - /// [`poll_ready`]: Sender::poll_ready /// [`channel`]: channel /// [`close`]: Receiver::close /// @@ -288,8 +384,8 @@ impl<T> Sender<T> { /// #[tokio::main] /// async fn main() { /// // Create a channel with buffer size 1 - /// let (mut tx1, mut rx) = mpsc::channel(1); - /// let mut tx2 = tx1.clone(); + /// let (tx1, mut rx) = mpsc::channel(1); + /// let tx2 = tx1.clone(); /// /// tokio::spawn(async move { /// tx1.send(1).await.unwrap(); @@ -317,8 +413,15 @@ impl<T> Sender<T> { /// } /// } /// ``` - pub fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> { - self.chan.try_send(message)?; + pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> { + match self.chan.semaphore().0.try_acquire(1) { + Ok(_) => {} + Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(message)), + Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(message)), + } + + // Send the message + self.chan.send(message); Ok(()) } @@ -346,11 +449,11 @@ impl<T> Sender<T> { /// /// ```rust /// use tokio::sync::mpsc; - /// use tokio::time::{delay_for, Duration}; + /// use tokio::time::{sleep, Duration}; /// /// #[tokio::main] /// async fn main() { - /// let (mut tx, mut rx) = mpsc::channel(1); + /// let (tx, mut rx) = mpsc::channel(1); /// /// tokio::spawn(async move { /// for i in 0..10 { @@ -363,117 +466,213 @@ impl<T> Sender<T> { /// /// while let Some(i) = rx.recv().await { /// println!("got = {}", i); - /// delay_for(Duration::from_millis(200)).await; + /// sleep(Duration::from_millis(200)).await; /// } /// } /// ``` #[cfg(feature = "time")] #[cfg_attr(docsrs, doc(cfg(feature = "time")))] pub async fn send_timeout( - &mut self, + &self, value: T, timeout: Duration, ) -> Result<(), SendTimeoutError<T>> { - use crate::future::poll_fn; - - match crate::time::timeout(timeout, poll_fn(|cx| self.poll_ready(cx))).await { + let permit = match crate::time::timeout(timeout, self.reserve()).await { Err(_) => { return Err(SendTimeoutError::Timeout(value)); } Ok(Err(_)) => { return Err(SendTimeoutError::Closed(value)); } - Ok(_) => {} - } + Ok(Ok(permit)) => permit, + }; - match self.try_send(value) { - Ok(()) => Ok(()), - Err(TrySendError::Full(_)) => unreachable!(), - Err(TrySendError::Closed(value)) => Err(SendTimeoutError::Closed(value)), - } + permit.send(value); + Ok(()) } - /// Returns `Poll::Ready(Ok(()))` when the channel is able to accept another item. + /// Blocking send to call outside of asynchronous contexts. /// - /// If the channel is full, then `Poll::Pending` is returned and the task is notified when a - /// slot becomes available. + /// # Panics /// - /// Once `poll_ready` returns `Poll::Ready(Ok(()))`, a call to `try_send` will succeed unless - /// the channel has since been closed. To provide this guarantee, the channel reserves one slot - /// in the channel for the coming send. This reserved slot is not available to other `Sender` - /// instances, so you need to be careful to not end up with deadlocks by blocking after calling - /// `poll_ready` but before sending an element. + /// This function panics if called within an asynchronous execution + /// context. /// - /// If, after `poll_ready` succeeds, you decide you do not wish to send an item after all, you - /// can use [`disarm`](Sender::disarm) to release the reserved slot. + /// # Examples /// - /// Until an item is sent or [`disarm`](Sender::disarm) is called, repeated calls to - /// `poll_ready` will return either `Poll::Ready(Ok(()))` or `Poll::Ready(Err(_))` if channel - /// is closed. - pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), ClosedError>> { - self.chan.poll_ready(cx).map_err(|_| ClosedError::new()) + /// ``` + /// use std::thread; + /// use tokio::runtime::Runtime; + /// use tokio::sync::mpsc; + /// + /// fn main() { + /// let (tx, mut rx) = mpsc::channel::<u8>(1); + /// + /// let sync_code = thread::spawn(move || { + /// tx.blocking_send(10).unwrap(); + /// }); + /// + /// Runtime::new().unwrap().block_on(async move { + /// assert_eq!(Some(10), rx.recv().await); + /// }); + /// sync_code.join().unwrap() + /// } + /// ``` + #[cfg(feature = "sync")] + pub fn blocking_send(&self, value: T) -> Result<(), SendError<T>> { + crate::future::block_on(self.send(value)) } - /// Undo a successful call to `poll_ready`. + /// Checks if the channel has been closed. This happens when the + /// [`Receiver`] is dropped, or when the [`Receiver::close`] method is + /// called. /// - /// Once a call to `poll_ready` returns `Poll::Ready(Ok(()))`, it holds up one slot in the - /// channel to make room for the coming send. `disarm` allows you to give up that slot if you - /// decide you do not wish to send an item after all. After calling `disarm`, you must call - /// `poll_ready` until it returns `Poll::Ready(Ok(()))` before attempting to send again. + /// [`Receiver`]: crate::sync::mpsc::Receiver + /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close /// - /// Returns `false` if no slot is reserved for this sender (usually because `poll_ready` was - /// not previously called, or did not succeed). + /// ``` + /// let (tx, rx) = tokio::sync::mpsc::channel::<()>(42); + /// assert!(!tx.is_closed()); /// - /// # Motivation + /// let tx2 = tx.clone(); + /// assert!(!tx2.is_closed()); /// - /// Since `poll_ready` takes up one of the finite number of slots in a bounded channel, callers - /// need to send an item shortly after `poll_ready` succeeds. If they do not, idle senders may - /// take up all the slots of the channel, and prevent active senders from getting any requests - /// through. Consider this code that forwards from one channel to another: + /// drop(rx); + /// assert!(tx.is_closed()); + /// assert!(tx2.is_closed()); + /// ``` + pub fn is_closed(&self) -> bool { + self.chan.is_closed() + } + + /// Wait for channel capacity. Once capacity to send one message is + /// available, it is reserved for the caller. + /// + /// If the channel is full, the function waits for the number of unreceived + /// messages to become less than the channel capacity. Capacity to send one + /// message is reserved for the caller. A [`Permit`] is returned to track + /// the reserved capacity. The [`send`] function on [`Permit`] consumes the + /// reserved capacity. + /// + /// Dropping [`Permit`] without sending a message releases the capacity back + /// to the channel. + /// + /// [`Permit`]: Permit + /// [`send`]: Permit::send + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Reserve capacity + /// let permit = tx.reserve().await.unwrap(); /// - /// ```rust,ignore - /// loop { - /// ready!(tx.poll_ready(cx))?; - /// if let Some(item) = ready!(rx.poll_recv(cx)) { - /// tx.try_send(item)?; - /// } else { - /// break; - /// } + /// // Trying to send directly on the `tx` will fail due to no + /// // available capacity. + /// assert!(tx.try_send(123).is_err()); + /// + /// // Sending on the permit succeeds + /// permit.send(456); + /// + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); /// } /// ``` + pub async fn reserve(&self) -> Result<Permit<'_, T>, SendError<()>> { + match self.chan.semaphore().0.acquire(1).await { + Ok(_) => {} + Err(_) => return Err(SendError(())), + } + + Ok(Permit { chan: &self.chan }) + } +} + +impl<T> Clone for Sender<T> { + fn clone(&self) -> Self { + Sender { + chan: self.chan.clone(), + } + } +} + +impl<T> fmt::Debug for Sender<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Sender") + .field("chan", &self.chan) + .finish() + } +} + +// ===== impl Permit ===== + +impl<T> Permit<'_, T> { + /// Sends a value using the reserved capacity. + /// + /// Capacity for the message has already been reserved. The message is sent + /// to the receiver and the permit is consumed. The operation will succeed + /// even if the receiver half has been closed. See [`Receiver::close`] for + /// more details on performing a clean shutdown. + /// + /// [`Receiver::close`]: Receiver::close + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Reserve capacity + /// let permit = tx.reserve().await.unwrap(); /// - /// If many such forwarders exist, and they all forward into a single (cloned) `Sender`, then - /// any number of forwarders may be waiting for `rx.poll_recv` at the same time. While they do, - /// they are effectively each reducing the channel's capacity by 1. If enough of these - /// forwarders are idle, forwarders whose `rx` _do_ have elements will be unable to find a spot - /// for them through `poll_ready`, and the system will deadlock. - /// - /// `disarm` solves this problem by allowing you to give up the reserved slot if you find that - /// you have to block. We can then fix the code above by writing: - /// - /// ```rust,ignore - /// loop { - /// ready!(tx.poll_ready(cx))?; - /// let item = rx.poll_recv(cx); - /// if let Poll::Ready(Ok(_)) = item { - /// // we're going to send the item below, so don't disarm - /// } else { - /// // give up our send slot, we won't need it for a while - /// tx.disarm(); - /// } - /// if let Some(item) = ready!(item) { - /// tx.try_send(item)?; - /// } else { - /// break; - /// } + /// // Trying to send directly on the `tx` will fail due to no + /// // available capacity. + /// assert!(tx.try_send(123).is_err()); + /// + /// // Send a message on the permit + /// permit.send(456); + /// + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); /// } /// ``` - pub fn disarm(&mut self) -> bool { - if self.chan.is_ready() { - self.chan.disarm(); - true - } else { - false + pub fn send(self, value: T) { + use std::mem; + + self.chan.send(value); + + // Avoid the drop logic + mem::forget(self); + } +} + +impl<T> Drop for Permit<'_, T> { + fn drop(&mut self) { + use chan::Semaphore; + + let semaphore = self.chan.semaphore(); + + // Add the permit back to the semaphore + semaphore.add_permit(); + + if semaphore.is_closed() && semaphore.is_idle() { + self.chan.wake_rx(); } } } + +impl<T> fmt::Debug for Permit<'_, T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Permit") + .field("chan", &self.chan) + .finish() + } +} diff --git a/src/sync/mpsc/chan.rs b/src/sync/mpsc/chan.rs index 148ee3a..c78fb50 100644 --- a/src/sync/mpsc/chan.rs +++ b/src/sync/mpsc/chan.rs @@ -2,8 +2,9 @@ use crate::loom::cell::UnsafeCell; use crate::loom::future::AtomicWaker; use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::Arc; -use crate::sync::mpsc::error::{ClosedError, TryRecvError}; -use crate::sync::mpsc::{error, list}; +use crate::sync::mpsc::error::TryRecvError; +use crate::sync::mpsc::list; +use crate::sync::notify::Notify; use std::fmt; use std::process; @@ -12,21 +13,13 @@ use std::task::Poll::{Pending, Ready}; use std::task::{Context, Poll}; /// Channel sender -pub(crate) struct Tx<T, S: Semaphore> { +pub(crate) struct Tx<T, S> { inner: Arc<Chan<T, S>>, - permit: S::Permit, } -impl<T, S: Semaphore> fmt::Debug for Tx<T, S> -where - S::Permit: fmt::Debug, - S: fmt::Debug, -{ +impl<T, S: fmt::Debug> fmt::Debug for Tx<T, S> { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("Tx") - .field("inner", &self.inner) - .field("permit", &self.permit) - .finish() + fmt.debug_struct("Tx").field("inner", &self.inner).finish() } } @@ -35,70 +28,26 @@ pub(crate) struct Rx<T, S: Semaphore> { inner: Arc<Chan<T, S>>, } -impl<T, S: Semaphore> fmt::Debug for Rx<T, S> -where - S: fmt::Debug, -{ +impl<T, S: Semaphore + fmt::Debug> fmt::Debug for Rx<T, S> { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Rx").field("inner", &self.inner).finish() } } -#[derive(Debug, Eq, PartialEq)] -pub(crate) enum TrySendError { - Closed, - Full, -} - -impl<T> From<(T, TrySendError)> for error::SendError<T> { - fn from(src: (T, TrySendError)) -> error::SendError<T> { - match src.1 { - TrySendError::Closed => error::SendError(src.0), - TrySendError::Full => unreachable!(), - } - } -} - -impl<T> From<(T, TrySendError)> for error::TrySendError<T> { - fn from(src: (T, TrySendError)) -> error::TrySendError<T> { - match src.1 { - TrySendError::Closed => error::TrySendError::Closed(src.0), - TrySendError::Full => error::TrySendError::Full(src.0), - } - } -} - pub(crate) trait Semaphore { - type Permit; - - fn new_permit() -> Self::Permit; - - /// The permit is dropped without a value being sent. In this case, the - /// permit must be returned to the semaphore. - fn drop_permit(&self, permit: &mut Self::Permit); - fn is_idle(&self) -> bool; fn add_permit(&self); - fn poll_acquire( - &self, - cx: &mut Context<'_>, - permit: &mut Self::Permit, - ) -> Poll<Result<(), ClosedError>>; - - fn try_acquire(&self, permit: &mut Self::Permit) -> Result<(), TrySendError>; - - /// A value was sent into the channel and the permit held by `tx` is - /// dropped. In this case, the permit should not immeditely be returned to - /// the semaphore. Instead, the permit is returnred to the semaphore once - /// the sent value is read by the rx handle. - fn forget(&self, permit: &mut Self::Permit); - fn close(&self); + + fn is_closed(&self) -> bool; } struct Chan<T, S> { + /// Notifies all tasks listening for the receiver being dropped + notify_rx_closed: Notify, + /// Handle to the push half of the lock-free list. tx: list::Tx<T>, @@ -153,13 +102,11 @@ impl<T> fmt::Debug for RxFields<T> { unsafe impl<T: Send, S: Send> Send for Chan<T, S> {} unsafe impl<T: Send, S: Sync> Sync for Chan<T, S> {} -pub(crate) fn channel<T, S>(semaphore: S) -> (Tx<T, S>, Rx<T, S>) -where - S: Semaphore, -{ +pub(crate) fn channel<T, S: Semaphore>(semaphore: S) -> (Tx<T, S>, Rx<T, S>) { let (tx, rx) = list::channel(); let chan = Arc::new(Chan { + notify_rx_closed: Notify::new(), tx, semaphore, rx_waker: AtomicWaker::new(), @@ -175,48 +122,60 @@ where // ===== impl Tx ===== -impl<T, S> Tx<T, S> -where - S: Semaphore, -{ +impl<T, S> Tx<T, S> { fn new(chan: Arc<Chan<T, S>>) -> Tx<T, S> { - Tx { - inner: chan, - permit: S::new_permit(), - } - } - - pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), ClosedError>> { - self.inner.semaphore.poll_acquire(cx, &mut self.permit) + Tx { inner: chan } } - pub(crate) fn disarm(&mut self) { - // TODO: should this error if not acquired? - self.inner.semaphore.drop_permit(&mut self.permit) + pub(super) fn semaphore(&self) -> &S { + &self.inner.semaphore } /// Send a message and notify the receiver. - pub(crate) fn try_send(&mut self, value: T) -> Result<(), (T, TrySendError)> { - self.inner.try_send(value, &mut self.permit) + pub(crate) fn send(&self, value: T) { + self.inner.send(value); } -} -impl<T> Tx<T, (crate::sync::semaphore_ll::Semaphore, usize)> { - pub(crate) fn is_ready(&self) -> bool { - self.permit.is_acquired() + /// Wake the receive half + pub(crate) fn wake_rx(&self) { + self.inner.rx_waker.wake(); } } -impl<T> Tx<T, AtomicUsize> { - pub(crate) fn send_unbounded(&self, value: T) -> Result<(), (T, TrySendError)> { - self.inner.try_send(value, &mut ()) +impl<T, S: Semaphore> Tx<T, S> { + pub(crate) fn is_closed(&self) -> bool { + self.inner.semaphore.is_closed() + } + + pub(crate) async fn closed(&self) { + use std::future::Future; + use std::pin::Pin; + use std::task::Poll; + + // In order to avoid a race condition, we first request a notification, + // **then** check the current value's version. If a new version exists, + // the notification request is dropped. Requesting the notification + // requires polling the future once. + let notified = self.inner.notify_rx_closed.notified(); + pin!(notified); + + // Polling the future once is guaranteed to return `Pending` as `watch` + // only notifies using `notify_waiters`. + crate::future::poll_fn(|cx| { + let res = Pin::new(&mut notified).poll(cx); + assert!(!res.is_ready()); + Poll::Ready(()) + }) + .await; + + if self.inner.semaphore.is_closed() { + return; + } + notified.await; } } -impl<T, S> Clone for Tx<T, S> -where - S: Semaphore, -{ +impl<T, S> Clone for Tx<T, S> { fn clone(&self) -> Tx<T, S> { // Using a Relaxed ordering here is sufficient as the caller holds a // strong ref to `self`, preventing a concurrent decrement to zero. @@ -224,18 +183,12 @@ where Tx { inner: self.inner.clone(), - permit: S::new_permit(), } } } -impl<T, S> Drop for Tx<T, S> -where - S: Semaphore, -{ +impl<T, S> Drop for Tx<T, S> { fn drop(&mut self) { - self.inner.semaphore.drop_permit(&mut self.permit); - if self.inner.tx_count.fetch_sub(1, AcqRel) != 1 { return; } @@ -244,16 +197,13 @@ where self.inner.tx.close(); // Notify the receiver - self.inner.rx_waker.wake(); + self.wake_rx(); } } // ===== impl Rx ===== -impl<T, S> Rx<T, S> -where - S: Semaphore, -{ +impl<T, S: Semaphore> Rx<T, S> { fn new(chan: Arc<Chan<T, S>>) -> Rx<T, S> { Rx { inner: chan } } @@ -270,6 +220,7 @@ where }); self.inner.semaphore.close(); + self.inner.notify_rx_closed.notify_waiters(); } /// Receive the next value @@ -341,10 +292,7 @@ where } } -impl<T, S> Drop for Rx<T, S> -where - S: Semaphore, -{ +impl<T, S: Semaphore> Drop for Rx<T, S> { fn drop(&mut self) { use super::block::Read::Value; @@ -362,25 +310,13 @@ where // ===== impl Chan ===== -impl<T, S> Chan<T, S> -where - S: Semaphore, -{ - fn try_send(&self, value: T, permit: &mut S::Permit) -> Result<(), (T, TrySendError)> { - if let Err(e) = self.semaphore.try_acquire(permit) { - return Err((value, e)); - } - +impl<T, S> Chan<T, S> { + fn send(&self, value: T) { // Push the value self.tx.push(value); // Notify the rx task self.rx_waker.wake(); - - // Release the permit - self.semaphore.forget(permit); - - Ok(()) } } @@ -399,72 +335,24 @@ impl<T, S> Drop for Chan<T, S> { } } -use crate::sync::semaphore_ll::TryAcquireError; - -impl From<TryAcquireError> for TrySendError { - fn from(src: TryAcquireError) -> TrySendError { - if src.is_closed() { - TrySendError::Closed - } else if src.is_no_permits() { - TrySendError::Full - } else { - unreachable!(); - } - } -} - // ===== impl Semaphore for (::Semaphore, capacity) ===== -use crate::sync::semaphore_ll::Permit; - -impl Semaphore for (crate::sync::semaphore_ll::Semaphore, usize) { - type Permit = Permit; - - fn new_permit() -> Permit { - Permit::new() - } - - fn drop_permit(&self, permit: &mut Permit) { - permit.release(1, &self.0); - } - +impl Semaphore for (crate::sync::batch_semaphore::Semaphore, usize) { fn add_permit(&self) { - self.0.add_permits(1) + self.0.release(1) } fn is_idle(&self) -> bool { self.0.available_permits() == self.1 } - fn poll_acquire( - &self, - cx: &mut Context<'_>, - permit: &mut Permit, - ) -> Poll<Result<(), ClosedError>> { - // Keep track of task budget - let coop = ready!(crate::coop::poll_proceed(cx)); - - permit - .poll_acquire(cx, 1, &self.0) - .map_err(|_| ClosedError::new()) - .map(move |r| { - coop.made_progress(); - r - }) - } - - fn try_acquire(&self, permit: &mut Permit) -> Result<(), TrySendError> { - permit.try_acquire(1, &self.0)?; - Ok(()) - } - - fn forget(&self, permit: &mut Self::Permit) { - permit.forget(1); - } - fn close(&self) { self.0.close(); } + + fn is_closed(&self) -> bool { + self.0.is_closed() + } } // ===== impl Semaphore for AtomicUsize ===== @@ -473,12 +361,6 @@ use std::sync::atomic::Ordering::{Acquire, Release}; use std::usize; impl Semaphore for AtomicUsize { - type Permit = (); - - fn new_permit() {} - - fn drop_permit(&self, _permit: &mut ()) {} - fn add_permit(&self) { let prev = self.fetch_sub(2, Release); @@ -492,40 +374,11 @@ impl Semaphore for AtomicUsize { self.load(Acquire) >> 1 == 0 } - fn poll_acquire( - &self, - _cx: &mut Context<'_>, - permit: &mut (), - ) -> Poll<Result<(), ClosedError>> { - Ready(self.try_acquire(permit).map_err(|_| ClosedError::new())) - } - - fn try_acquire(&self, _permit: &mut ()) -> Result<(), TrySendError> { - let mut curr = self.load(Acquire); - - loop { - if curr & 1 == 1 { - return Err(TrySendError::Closed); - } - - if curr == usize::MAX ^ 1 { - // Overflowed the ref count. There is no safe way to recover, so - // abort the process. In practice, this should never happen. - process::abort() - } - - match self.compare_exchange(curr, curr + 2, AcqRel, Acquire) { - Ok(_) => return Ok(()), - Err(actual) => { - curr = actual; - } - } - } - } - - fn forget(&self, _permit: &mut ()) {} - fn close(&self) { self.fetch_or(1, Release); } + + fn is_closed(&self) -> bool { + self.load(Acquire) & 1 == 1 + } } diff --git a/src/sync/mpsc/error.rs b/src/sync/mpsc/error.rs index 72c42aa..7705452 100644 --- a/src/sync/mpsc/error.rs +++ b/src/sync/mpsc/error.rs @@ -94,26 +94,6 @@ impl fmt::Display for TryRecvError { impl Error for TryRecvError {} -// ===== ClosedError ===== - -/// Error returned by [`Sender::poll_ready`](super::Sender::poll_ready). -#[derive(Debug)] -pub struct ClosedError(()); - -impl ClosedError { - pub(crate) fn new() -> ClosedError { - ClosedError(()) - } -} - -impl fmt::Display for ClosedError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "channel closed") - } -} - -impl Error for ClosedError {} - cfg_time! { // ===== SendTimeoutError ===== diff --git a/src/sync/mpsc/list.rs b/src/sync/mpsc/list.rs index 53f82a2..2f4c532 100644 --- a/src/sync/mpsc/list.rs +++ b/src/sync/mpsc/list.rs @@ -1,9 +1,7 @@ //! A concurrent, lock-free, FIFO list. -use crate::loom::{ - sync::atomic::{AtomicPtr, AtomicUsize}, - thread, -}; +use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize}; +use crate::loom::thread; use crate::sync::mpsc::block::{self, Block}; use std::fmt; diff --git a/src/sync/mpsc/mod.rs b/src/sync/mpsc/mod.rs index c489c9f..a2bcf83 100644 --- a/src/sync/mpsc/mod.rs +++ b/src/sync/mpsc/mod.rs @@ -1,23 +1,29 @@ #![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))] -//! A multi-producer, single-consumer queue for sending values across +//! A multi-producer, single-consumer queue for sending values between //! asynchronous tasks. //! -//! Similar to `std`, channel creation provides [`Receiver`] and [`Sender`] -//! handles. [`Receiver`] implements `Stream` and allows a task to read values -//! out of the channel. If there is no message to read, the current task will be -//! notified when a new value is sent. If the channel is at capacity, the send -//! is rejected and the task will be notified when additional capacity is -//! available. In other words, the channel provides backpressure. -//! //! This module provides two variants of the channel: bounded and unbounded. The //! bounded variant has a limit on the number of messages that the channel can //! store, and if this limit is reached, trying to send another message will //! wait until a message is received from the channel. An unbounded channel has -//! an infinite capacity, so the `send` method never does any kind of sleeping. +//! an infinite capacity, so the `send` method will always complete immediately. //! This makes the [`UnboundedSender`] usable from both synchronous and //! asynchronous code. //! +//! Similar to the `mpsc` channels provided by `std`, the channel constructor +//! functions provide separate send and receive handles, [`Sender`] and +//! [`Receiver`] for the bounded channel, [`UnboundedSender`] and +//! [`UnboundedReceiver`] for the unbounded channel. Both [`Receiver`] and +//! [`UnboundedReceiver`] implement [`Stream`] and allow a task to read +//! values out of the channel. If there is no message to read, the current task +//! will be notified when a new value is sent. [`Sender`] and +//! [`UnboundedSender`] allow sending values into the channel. If the bounded +//! channel is at capacity, the send is rejected and the task will be notified +//! when additional capacity is available. In other words, the channel provides +//! backpressure. +//! +//! //! # Disconnection //! //! When all [`Sender`] handles have been dropped, it is no longer @@ -43,11 +49,10 @@ //! are two situations to consider: //! //! **Bounded channel**: If you need a bounded channel, you should use a bounded -//! Tokio `mpsc` channel for both directions of communication. To call the async -//! [`send`][bounded-send] or [`recv`][bounded-recv] methods in sync code, you -//! will need to use [`Handle::block_on`], which allow you to execute an async -//! method in synchronous code. This is necessary because a bounded channel may -//! need to wait for additional capacity to become available. +//! Tokio `mpsc` channel for both directions of communication. Instead of calling +//! the async [`send`][bounded-send] or [`recv`][bounded-recv] methods, in +//! synchronous code you will need to use the [`blocking_send`][blocking-send] or +//! [`blocking_recv`][blocking-recv] methods. //! //! **Unbounded channel**: You should use the kind of channel that matches where //! the receiver is. So for sending a message _from async to sync_, you should @@ -57,9 +62,13 @@ //! //! [`Sender`]: crate::sync::mpsc::Sender //! [`Receiver`]: crate::sync::mpsc::Receiver +//! [`Stream`]: crate::stream::Stream //! [bounded-send]: crate::sync::mpsc::Sender::send() //! [bounded-recv]: crate::sync::mpsc::Receiver::recv() +//! [blocking-send]: crate::sync::mpsc::Sender::blocking_send() +//! [blocking-recv]: crate::sync::mpsc::Receiver::blocking_recv() //! [`UnboundedSender`]: crate::sync::mpsc::UnboundedSender +//! [`UnboundedReceiver`]: crate::sync::mpsc::UnboundedReceiver //! [`Handle::block_on`]: crate::runtime::Handle::block_on() //! [std-unbounded]: std::sync::mpsc::channel //! [crossbeam-unbounded]: https://docs.rs/crossbeam/*/crossbeam/channel/fn.unbounded.html @@ -67,7 +76,7 @@ pub(super) mod block; mod bounded; -pub use self::bounded::{channel, Receiver, Sender}; +pub use self::bounded::{channel, Permit, Receiver, Sender}; mod chan; diff --git a/src/sync/mpsc/unbounded.rs b/src/sync/mpsc/unbounded.rs index 1b2288a..fe882d5 100644 --- a/src/sync/mpsc/unbounded.rs +++ b/src/sync/mpsc/unbounded.rs @@ -47,7 +47,7 @@ impl<T> fmt::Debug for UnboundedReceiver<T> { } /// Creates an unbounded mpsc channel for communicating between asynchronous -/// tasks. +/// tasks without backpressure. /// /// A `send` on this channel will always succeed as long as the receive half has /// not been closed. If the receiver falls behind, messages will be arbitrarily @@ -73,8 +73,7 @@ impl<T> UnboundedReceiver<T> { UnboundedReceiver { chan } } - #[doc(hidden)] // TODO: doc - pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { + fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { self.chan.recv(cx) } @@ -174,7 +173,97 @@ impl<T> UnboundedSender<T> { /// [`close`]: UnboundedReceiver::close /// [`UnboundedReceiver`]: UnboundedReceiver pub fn send(&self, message: T) -> Result<(), SendError<T>> { - self.chan.send_unbounded(message)?; + if !self.inc_num_messages() { + return Err(SendError(message)); + } + + self.chan.send(message); Ok(()) } + + fn inc_num_messages(&self) -> bool { + use std::process; + use std::sync::atomic::Ordering::{AcqRel, Acquire}; + + let mut curr = self.chan.semaphore().load(Acquire); + + loop { + if curr & 1 == 1 { + return false; + } + + if curr == usize::MAX ^ 1 { + // Overflowed the ref count. There is no safe way to recover, so + // abort the process. In practice, this should never happen. + process::abort() + } + + match self + .chan + .semaphore() + .compare_exchange(curr, curr + 2, AcqRel, Acquire) + { + Ok(_) => return true, + Err(actual) => { + curr = actual; + } + } + } + } + + /// Completes when the receiver has dropped. + /// + /// This allows the producers to get notified when interest in the produced + /// values is canceled and immediately stop doing work. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx1, rx) = mpsc::unbounded_channel::<()>(); + /// let tx2 = tx1.clone(); + /// let tx3 = tx1.clone(); + /// let tx4 = tx1.clone(); + /// let tx5 = tx1.clone(); + /// tokio::spawn(async move { + /// drop(rx); + /// }); + /// + /// futures::join!( + /// tx1.closed(), + /// tx2.closed(), + /// tx3.closed(), + /// tx4.closed(), + /// tx5.closed() + /// ); + //// println!("Receiver dropped"); + /// } + /// ``` + pub async fn closed(&self) { + self.chan.closed().await + } + /// Checks if the channel has been closed. This happens when the + /// [`UnboundedReceiver`] is dropped, or when the + /// [`UnboundedReceiver::close`] method is called. + /// + /// [`UnboundedReceiver`]: crate::sync::mpsc::UnboundedReceiver + /// [`UnboundedReceiver::close`]: crate::sync::mpsc::UnboundedReceiver::close + /// + /// ``` + /// let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<()>(); + /// assert!(!tx.is_closed()); + /// + /// let tx2 = tx.clone(); + /// assert!(!tx2.is_closed()); + /// + /// drop(rx); + /// assert!(tx.is_closed()); + /// assert!(tx2.is_closed()); + /// ``` + pub fn is_closed(&self) -> bool { + self.chan.is_closed() + } } diff --git a/src/sync/mutex.rs b/src/sync/mutex.rs index 642058b..21e44ca 100644 --- a/src/sync/mutex.rs +++ b/src/sync/mutex.rs @@ -1,3 +1,5 @@ +#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))] + use crate::sync::batch_semaphore as semaphore; use std::cell::UnsafeCell; @@ -115,7 +117,6 @@ use std::sync::Arc; /// [`std::sync::Mutex`]: struct@std::sync::Mutex /// [`Send`]: trait@std::marker::Send /// [`lock`]: method@Mutex::lock -#[derive(Debug)] pub struct Mutex<T: ?Sized> { s: semaphore::Semaphore, c: UnsafeCell<T>, @@ -220,6 +221,27 @@ impl<T: ?Sized> Mutex<T> { } } + /// Creates a new lock in an unlocked state ready for use. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// + /// static LOCK: Mutex<i32> = Mutex::const_new(5); + /// ``` + #[cfg(all(feature = "parking_lot", not(all(loom, test)),))] + #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] + pub const fn const_new(t: T) -> Self + where + T: Sized, + { + Self { + c: UnsafeCell::new(t), + s: semaphore::Semaphore::const_new(1), + } + } + /// Locks this mutex, causing the current task /// to yield until the lock has been acquired. /// When the lock has been acquired, function returns a [`MutexGuard`]. @@ -305,6 +327,30 @@ impl<T: ?Sized> Mutex<T> { } } + /// Returns a mutable reference to the underlying data. + /// + /// Since this call borrows the `Mutex` mutably, no actual locking needs to + /// take place -- the mutable borrow statically guarantees no locks exist. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// + /// fn main() { + /// let mut mutex = Mutex::new(1); + /// + /// let n = mutex.get_mut(); + /// *n = 2; + /// } + /// ``` + pub fn get_mut(&mut self) -> &mut T { + unsafe { + // Safety: This is https://github.com/rust-lang/rust/pull/76936 + &mut *self.c.get() + } + } + /// Attempts to acquire the lock, and returns [`TryLockError`] if the lock /// is currently held somewhere else. /// @@ -373,6 +419,20 @@ where } } +impl<T> std::fmt::Debug for Mutex<T> +where + T: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut d = f.debug_struct("Mutex"); + match self.try_lock() { + Ok(inner) => d.field("data", &*inner), + Err(_) => d.field("data", &format_args!("<locked>")), + }; + d.finish() + } +} + // === impl MutexGuard === impl<T: ?Sized> Drop for MutexGuard<'_, T> { diff --git a/src/sync/notify.rs b/src/sync/notify.rs index 5cb41e8..922f109 100644 --- a/src/sync/notify.rs +++ b/src/sync/notify.rs @@ -1,3 +1,10 @@ +// Allow `unreachable_pub` warnings when sync is not enabled +// due to the usage of `Notify` within the `rt` feature set. +// When this module is compiled with `sync` enabled we will warn on +// this lint. When `rt` is enabled we use `pub(crate)` which +// triggers this warning but it is safe to ignore in this case. +#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))] + use crate::loom::sync::atomic::AtomicU8; use crate::loom::sync::Mutex; use crate::util::linked_list::{self, LinkedList}; @@ -10,6 +17,8 @@ use std::ptr::NonNull; use std::sync::atomic::Ordering::SeqCst; use std::task::{Context, Poll, Waker}; +type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>; + /// Notify a single task to wake up. /// /// `Notify` provides a basic mechanism to notify a single task of an event. @@ -17,20 +26,20 @@ use std::task::{Context, Poll, Waker}; /// another task to perform an operation. /// /// `Notify` can be thought of as a [`Semaphore`] starting with 0 permits. -/// [`notified().await`] waits for a permit to become available, and [`notify()`] +/// [`notified().await`] waits for a permit to become available, and [`notify_one()`] /// sets a permit **if there currently are no available permits**. /// /// The synchronization details of `Notify` are similar to /// [`thread::park`][park] and [`Thread::unpark`][unpark] from std. A [`Notify`] /// value contains a single permit. [`notified().await`] waits for the permit to -/// be made available, consumes the permit, and resumes. [`notify()`] sets the +/// be made available, consumes the permit, and resumes. [`notify_one()`] sets the /// permit, waking a pending task if there is one. /// -/// If `notify()` is called **before** `notfied().await`, then the next call to +/// If `notify_one()` is called **before** `notified().await`, then the next call to /// `notified().await` will complete immediately, consuming the permit. Any /// subsequent calls to `notified().await` will wait for a new permit. /// -/// If `notify()` is called **multiple** times before `notified().await`, only a +/// If `notify_one()` is called **multiple** times before `notified().await`, only a /// **single** permit is stored. The next call to `notified().await` will /// complete immediately, but the one after will wait for a new permit. /// @@ -53,7 +62,7 @@ use std::task::{Context, Poll, Waker}; /// }); /// /// println!("sending notification"); -/// notify.notify(); +/// notify.notify_one(); /// } /// ``` /// @@ -76,7 +85,7 @@ use std::task::{Context, Poll, Waker}; /// .push_back(value); /// /// // Notify the consumer a value is available -/// self.notify.notify(); +/// self.notify.notify_one(); /// } /// /// pub async fn recv(&self) -> T { @@ -96,12 +105,20 @@ use std::task::{Context, Poll, Waker}; /// [park]: std::thread::park /// [unpark]: std::thread::Thread::unpark /// [`notified().await`]: Notify::notified() -/// [`notify()`]: Notify::notify() +/// [`notify_one()`]: Notify::notify_one() /// [`Semaphore`]: crate::sync::Semaphore #[derive(Debug)] pub struct Notify { state: AtomicU8, - waiters: Mutex<LinkedList<Waiter>>, + waiters: Mutex<WaitList>, +} + +#[derive(Debug, Clone, Copy)] +enum NotificationType { + // Notification triggered by calling `notify_waiters` + AllWaiters, + // Notification triggered by calling `notify_one` + OneWaiter, } #[derive(Debug)] @@ -113,7 +130,7 @@ struct Waiter { waker: Option<Waker>, /// `true` if the notification has been assigned to this waiter. - notified: bool, + notified: Option<NotificationType>, /// Should not be `Unpin`. _p: PhantomPinned, @@ -121,7 +138,7 @@ struct Waiter { /// Future returned from `notified()` #[derive(Debug)] -struct Notified<'a> { +pub struct Notified<'a> { /// The `Notify` being received on. notify: &'a Notify, @@ -168,14 +185,38 @@ impl Notify { } } + /// Create a new `Notify`, initialized without a permit. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Notify; + /// + /// static NOTIFY: Notify = Notify::const_new(); + /// ``` + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] + #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] + pub const fn const_new() -> Notify { + Notify { + state: AtomicU8::new(0), + waiters: Mutex::const_new(LinkedList::new()), + } + } + /// Wait for a notification. /// + /// Equivalent to: + /// + /// ```ignore + /// async fn notified(&self); + /// ``` + /// /// Each `Notify` value holds a single permit. If a permit is available from - /// an earlier call to [`notify()`], then `notified().await` will complete + /// an earlier call to [`notify_one()`], then `notified().await` will complete /// immediately, consuming that permit. Otherwise, `notified().await` waits - /// for a permit to be made available by the next call to `notify()`. + /// for a permit to be made available by the next call to `notify_one()`. /// - /// [`notify()`]: Notify::notify + /// [`notify_one()`]: Notify::notify_one /// /// # Examples /// @@ -194,21 +235,20 @@ impl Notify { /// }); /// /// println!("sending notification"); - /// notify.notify(); + /// notify.notify_one(); /// } /// ``` - pub async fn notified(&self) { + pub fn notified(&self) -> Notified<'_> { Notified { notify: self, state: State::Init, waiter: UnsafeCell::new(Waiter { pointers: linked_list::Pointers::new(), waker: None, - notified: false, + notified: None, _p: PhantomPinned, }), } - .await } /// Notifies a waiting task @@ -216,10 +256,10 @@ impl Notify { /// If a task is currently waiting, that task is notified. Otherwise, a /// permit is stored in this `Notify` value and the **next** call to /// [`notified().await`] will complete immediately consuming the permit made - /// available by this call to `notify()`. + /// available by this call to `notify_one()`. /// /// At most one permit may be stored by `Notify`. Many sequential calls to - /// `notify` will result in a single permit being stored. The next call to + /// `notify_one` will result in a single permit being stored. The next call to /// `notified().await` will complete immediately, but the one after that /// will wait. /// @@ -242,10 +282,10 @@ impl Notify { /// }); /// /// println!("sending notification"); - /// notify.notify(); + /// notify.notify_one(); /// } /// ``` - pub fn notify(&self) { + pub fn notify_one(&self) { // Load the current state let mut curr = self.state.load(SeqCst); @@ -266,7 +306,7 @@ impl Notify { } // There are waiters, the lock must be acquired to notify. - let mut waiters = self.waiters.lock().unwrap(); + let mut waiters = self.waiters.lock(); // The state must be reloaded while the lock is held. The state may only // transition out of WAITING while the lock is held. @@ -277,6 +317,45 @@ impl Notify { waker.wake(); } } + + /// Notifies all waiting tasks + pub(crate) fn notify_waiters(&self) { + // There are waiters, the lock must be acquired to notify. + let mut waiters = self.waiters.lock(); + + // The state must be reloaded while the lock is held. The state may only + // transition out of WAITING while the lock is held. + let curr = self.state.load(SeqCst); + + if let EMPTY | NOTIFIED = curr { + // There are no waiting tasks. In this case, no synchronization is + // established between `notify` and `notified().await`. + return; + } + + // At this point, it is guaranteed that the state will not + // concurrently change, as holding the lock is required to + // transition **out** of `WAITING`. + // + // Get pending waiters + while let Some(mut waiter) = waiters.pop_back() { + // Safety: `waiters` lock is still held. + let waiter = unsafe { waiter.as_mut() }; + + assert!(waiter.notified.is_none()); + + waiter.notified = Some(NotificationType::AllWaiters); + + if let Some(waker) = waiter.waker.take() { + waker.wake(); + } + } + + // All waiters have been notified, the state must be transitioned to + // `EMPTY`. As transitioning **from** `WAITING` requires the lock to be + // held, a `store` is sufficient. + self.state.store(EMPTY, SeqCst); + } } impl Default for Notify { @@ -285,7 +364,7 @@ impl Default for Notify { } } -fn notify_locked(waiters: &mut LinkedList<Waiter>, state: &AtomicU8, curr: u8) -> Option<Waker> { +fn notify_locked(waiters: &mut WaitList, state: &AtomicU8, curr: u8) -> Option<Waker> { loop { match curr { EMPTY | NOTIFIED => { @@ -311,9 +390,9 @@ fn notify_locked(waiters: &mut LinkedList<Waiter>, state: &AtomicU8, curr: u8) - // Safety: `waiters` lock is still held. let waiter = unsafe { waiter.as_mut() }; - assert!(!waiter.notified); + assert!(waiter.notified.is_none()); - waiter.notified = true; + waiter.notified = Some(NotificationType::OneWaiter); let waker = waiter.waker.take(); if waiters.is_empty() { @@ -373,7 +452,7 @@ impl Future for Notified<'_> { // Acquire the lock and attempt to transition to the waiting // state. - let mut waiters = notify.waiters.lock().unwrap(); + let mut waiters = notify.waiters.lock(); // Reload the state with the lock held let mut curr = notify.state.load(SeqCst); @@ -428,6 +507,8 @@ impl Future for Notified<'_> { waiters.push_front(unsafe { NonNull::new_unchecked(waiter.get()) }); *state = Waiting; + + return Poll::Pending; } Waiting => { // Currently in the "Waiting" state, implying the caller has @@ -435,16 +516,16 @@ impl Future for Notified<'_> { // `notify.waiters`). In order to access the waker fields, // we must hold the lock. - let waiters = notify.waiters.lock().unwrap(); + let waiters = notify.waiters.lock(); // Safety: called while locked let w = unsafe { &mut *waiter.get() }; - if w.notified { + if w.notified.is_some() { // Our waker has been notified. Reset the fields and // remove it from the list. w.waker = None; - w.notified = false; + w.notified = None; *state = Done; } else { @@ -483,12 +564,12 @@ impl Drop for Notified<'_> { // longer stored in the linked list. if let Waiting = *state { let mut notify_state = WAITING; - let mut waiters = notify.waiters.lock().unwrap(); + let mut waiters = notify.waiters.lock(); // `Notify.state` may be in any of the three states (Empty, Waiting, // Notified). It doesn't actually matter what the atomic is set to // at this point. We hold the lock and will ensure the atomic is in - // the correct state once th elock is dropped. + // the correct state once the lock is dropped. // // Because the atomic state is not checked, at first glance, it may // seem like this routine does not handle the case where the @@ -516,14 +597,13 @@ impl Drop for Notified<'_> { notify.state.store(EMPTY, SeqCst); } - // See if the node was notified but not received. In this case, the - // notification must be sent to another waiter. + // See if the node was notified but not received. In this case, if + // the notification was triggered via `notify_one`, it must be sent + // to the next waiter. // // Safety: with the entry removed from the linked list, there can be // no concurrent access to the entry - let notified = unsafe { (*waiter.get()).notified }; - - if notified { + if let Some(NotificationType::OneWaiter) = unsafe { (*waiter.get()).notified } { if let Some(waker) = notify_locked(&mut waiters, ¬ify.state, notify_state) { drop(waiters); waker.wake(); diff --git a/src/sync/oneshot.rs b/src/sync/oneshot.rs index 17767e7..951ab71 100644 --- a/src/sync/oneshot.rs +++ b/src/sync/oneshot.rs @@ -124,7 +124,6 @@ struct State(usize); /// } /// ``` pub fn channel<T>() -> (Sender<T>, Receiver<T>) { - #[allow(deprecated)] let inner = Arc::new(Inner { state: AtomicUsize::new(State::new().as_usize()), value: UnsafeCell::new(None), @@ -197,8 +196,7 @@ impl<T> Sender<T> { Ok(()) } - #[doc(hidden)] // TODO: remove - pub fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> { + fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> { // Keep track of task budget let coop = ready!(crate::coop::poll_proceed(cx)); diff --git a/src/sync/rwlock.rs b/src/sync/rwlock.rs index 3d2a2f7..a84c4c1 100644 --- a/src/sync/rwlock.rs +++ b/src/sync/rwlock.rs @@ -1,5 +1,8 @@ -use crate::sync::batch_semaphore::{AcquireError, Semaphore}; +use crate::sync::batch_semaphore::Semaphore; use std::cell::UnsafeCell; +use std::fmt; +use std::marker; +use std::mem; use std::ops; #[cfg(not(loom))] @@ -8,7 +11,7 @@ const MAX_READS: usize = 32; #[cfg(loom)] const MAX_READS: usize = 10; -/// An asynchronous reader-writer lock +/// An asynchronous reader-writer lock. /// /// This type of lock allows a number of readers or at most one writer at any /// point in time. The write portion of this lock typically allows modification @@ -83,10 +86,140 @@ pub struct RwLock<T: ?Sized> { /// [`RwLock`]. /// /// [`read`]: method@RwLock::read -#[derive(Debug)] +/// [`RwLock`]: struct@RwLock pub struct RwLockReadGuard<'a, T: ?Sized> { - permit: ReleasingPermit<'a, T>, - lock: &'a RwLock<T>, + s: &'a Semaphore, + data: *const T, + marker: marker::PhantomData<&'a T>, +} + +impl<'a, T> RwLockReadGuard<'a, T> { + /// Make a new `RwLockReadGuard` for a component of the locked data. + /// + /// This operation cannot fail as the `RwLockReadGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be + /// used as `RwLockReadGuard::map(...)`. A method would interfere with + /// methods of the same name on the contents of the locked data. + /// + /// This is an asynchronous version of [`RwLockReadGuard::map`] from the + /// [`parking_lot` crate]. + /// + /// [`RwLockReadGuard::map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockReadGuard.html#method.map + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockReadGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// let guard = lock.read().await; + /// let guard = RwLockReadGuard::map(guard, |f| &f.0); + /// + /// assert_eq!(1, *guard); + /// # } + /// ``` + #[inline] + pub fn map<F, U: ?Sized>(this: Self, f: F) -> RwLockReadGuard<'a, U> + where + F: FnOnce(&T) -> &U, + { + let data = f(&*this) as *const U; + let s = this.s; + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + RwLockReadGuard { + s, + data, + marker: marker::PhantomData, + } + } + + /// Attempts to make a new [`RwLockReadGuard`] for a component of the + /// locked data. The original guard is returned if the closure returns + /// `None`. + /// + /// This operation cannot fail as the `RwLockReadGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be used as + /// `RwLockReadGuard::try_map(..)`. A method would interfere with methods of the + /// same name on the contents of the locked data. + /// + /// This is an asynchronous version of [`RwLockReadGuard::try_map`] from the + /// [`parking_lot` crate]. + /// + /// [`RwLockReadGuard::try_map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockReadGuard.html#method.try_map + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockReadGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// let guard = lock.read().await; + /// let guard = RwLockReadGuard::try_map(guard, |f| Some(&f.0)).expect("should not fail"); + /// + /// assert_eq!(1, *guard); + /// # } + /// ``` + #[inline] + pub fn try_map<F, U: ?Sized>(this: Self, f: F) -> Result<RwLockReadGuard<'a, U>, Self> + where + F: FnOnce(&T) -> Option<&U>, + { + let data = match f(&*this) { + Some(data) => data as *const U, + None => return Err(this), + }; + let s = this.s; + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + Ok(RwLockReadGuard { + s, + data, + marker: marker::PhantomData, + }) + } +} + +impl<'a, T: ?Sized> fmt::Debug for RwLockReadGuard<'a, T> +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<'a, T: ?Sized> fmt::Display for RwLockReadGuard<'a, T> +where + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +impl<'a, T: ?Sized> Drop for RwLockReadGuard<'a, T> { + fn drop(&mut self) { + self.s.release(1); + } } /// RAII structure used to release the exclusive write access of a lock when @@ -97,32 +230,195 @@ pub struct RwLockReadGuard<'a, T: ?Sized> { /// /// [`write`]: method@RwLock::write /// [`RwLock`]: struct@RwLock -#[derive(Debug)] pub struct RwLockWriteGuard<'a, T: ?Sized> { - permit: ReleasingPermit<'a, T>, - lock: &'a RwLock<T>, + s: &'a Semaphore, + data: *mut T, + marker: marker::PhantomData<&'a mut T>, } -// Wrapper arround Permit that releases on Drop -#[derive(Debug)] -struct ReleasingPermit<'a, T: ?Sized> { - num_permits: u16, - lock: &'a RwLock<T>, +impl<'a, T: ?Sized> RwLockWriteGuard<'a, T> { + /// Make a new `RwLockWriteGuard` for a component of the locked data. + /// + /// This operation cannot fail as the `RwLockWriteGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be used as + /// `RwLockWriteGuard::map(..)`. A method would interfere with methods of + /// the same name on the contents of the locked data. + /// + /// This is an asynchronous version of [`RwLockWriteGuard::map`] from the + /// [`parking_lot` crate]. + /// + /// [`RwLockWriteGuard::map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.map + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// { + /// let mut mapped = RwLockWriteGuard::map(lock.write().await, |f| &mut f.0); + /// *mapped = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn map<F, U: ?Sized>(mut this: Self, f: F) -> RwLockWriteGuard<'a, U> + where + F: FnOnce(&mut T) -> &mut U, + { + let data = f(&mut *this) as *mut U; + let s = this.s; + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + RwLockWriteGuard { + s, + data, + marker: marker::PhantomData, + } + } + + /// Attempts to make a new [`RwLockWriteGuard`] for a component of + /// the locked data. The original guard is returned if the closure returns + /// `None`. + /// + /// This operation cannot fail as the `RwLockWriteGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be + /// used as `RwLockWriteGuard::try_map(...)`. A method would interfere with + /// methods of the same name on the contents of the locked data. + /// + /// This is an asynchronous version of [`RwLockWriteGuard::try_map`] from + /// the [`parking_lot` crate]. + /// + /// [`RwLockWriteGuard::try_map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.try_map + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// { + /// let guard = lock.write().await; + /// let mut guard = RwLockWriteGuard::try_map(guard, |f| Some(&mut f.0)).expect("should not fail"); + /// *guard = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn try_map<F, U: ?Sized>(mut this: Self, f: F) -> Result<RwLockWriteGuard<'a, U>, Self> + where + F: FnOnce(&mut T) -> Option<&mut U>, + { + let data = match f(&mut *this) { + Some(data) => data as *mut U, + None => return Err(this), + }; + let s = this.s; + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + Ok(RwLockWriteGuard { + s, + data, + marker: marker::PhantomData, + }) + } + + /// Atomically downgrades a write lock into a read lock without allowing + /// any writers to take exclusive access of the lock in the meantime. + /// + /// **Note:** This won't *necessarily* allow any additional readers to acquire + /// locks, since [`RwLock`] is fair and it is possible that a writer is next + /// in line. + /// + /// Returns an RAII guard which will drop the read access of this rwlock + /// when dropped. + /// + /// # Examples + /// + /// ``` + /// # use tokio::sync::RwLock; + /// # use std::sync::Arc; + /// # + /// # #[tokio::main] + /// # async fn main() { + /// let lock = Arc::new(RwLock::new(1)); + /// + /// let n = lock.write().await; + /// + /// let cloned_lock = lock.clone(); + /// let handle = tokio::spawn(async move { + /// *cloned_lock.write().await = 2; + /// }); + /// + /// let n = n.downgrade(); + /// assert_eq!(*n, 1, "downgrade is atomic"); + /// + /// assert_eq!(*lock.read().await, 1, "additional readers can obtain locks"); + /// + /// drop(n); + /// handle.await.unwrap(); + /// assert_eq!(*lock.read().await, 2, "second writer obtained write lock"); + /// # } + /// ``` + /// + /// [`RwLock`]: struct@RwLock + pub fn downgrade(self) -> RwLockReadGuard<'a, T> { + let RwLockWriteGuard { s, data, .. } = self; + + // Release all but one of the permits held by the write guard + s.release(MAX_READS - 1); + + RwLockReadGuard { + s, + data, + marker: marker::PhantomData, + } + } } -impl<'a, T: ?Sized> ReleasingPermit<'a, T> { - async fn acquire( - lock: &'a RwLock<T>, - num_permits: u16, - ) -> Result<ReleasingPermit<'a, T>, AcquireError> { - lock.s.acquire(num_permits.into()).await?; - Ok(Self { num_permits, lock }) +impl<'a, T: ?Sized> fmt::Debug for RwLockWriteGuard<'a, T> +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) } } -impl<T: ?Sized> Drop for ReleasingPermit<'_, T> { +impl<'a, T: ?Sized> fmt::Display for RwLockWriteGuard<'a, T> +where + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +impl<'a, T: ?Sized> Drop for RwLockWriteGuard<'a, T> { fn drop(&mut self) { - self.lock.s.release(self.num_permits as usize); + self.s.release(MAX_READS); } } @@ -139,9 +435,11 @@ fn bounds() { check_sync::<RwLock<u32>>(); check_unpin::<RwLock<u32>>(); + check_send::<RwLockReadGuard<'_, u32>>(); check_sync::<RwLockReadGuard<'_, u32>>(); check_unpin::<RwLockReadGuard<'_, u32>>(); + check_send::<RwLockWriteGuard<'_, u32>>(); check_sync::<RwLockWriteGuard<'_, u32>>(); check_unpin::<RwLockWriteGuard<'_, u32>>(); @@ -155,8 +453,17 @@ fn bounds() { // RwLock<T>. unsafe impl<T> Send for RwLock<T> where T: ?Sized + Send {} unsafe impl<T> Sync for RwLock<T> where T: ?Sized + Send + Sync {} +// NB: These impls need to be explicit since we're storing a raw pointer. +// Safety: Stores a raw pointer to `T`, so if `T` is `Sync`, the lock guard over +// `T` is `Send`. +unsafe impl<T> Send for RwLockReadGuard<'_, T> where T: ?Sized + Sync {} unsafe impl<T> Sync for RwLockReadGuard<'_, T> where T: ?Sized + Send + Sync {} unsafe impl<T> Sync for RwLockWriteGuard<'_, T> where T: ?Sized + Send + Sync {} +// Safety: Stores a raw pointer to `T`, so if `T` is `Sync`, the lock guard over +// `T` is `Send` - but since this is also provides mutable access, we need to +// make sure that `T` is `Send` since its value can be sent across thread +// boundaries. +unsafe impl<T> Send for RwLockWriteGuard<'_, T> where T: ?Sized + Send + Sync {} impl<T: ?Sized> RwLock<T> { /// Creates a new instance of an `RwLock<T>` which is unlocked. @@ -178,6 +485,27 @@ impl<T: ?Sized> RwLock<T> { } } + /// Creates a new instance of an `RwLock<T>` which is unlocked. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::RwLock; + /// + /// static LOCK: RwLock<i32> = RwLock::const_new(5); + /// ``` + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] + #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] + pub const fn const_new(value: T) -> RwLock<T> + where + T: Sized, + { + RwLock { + c: UnsafeCell::new(value), + s: Semaphore::const_new(MAX_READS), + } + } + /// Locks this rwlock with shared read access, causing the current task /// to yield until the lock has been acquired. /// @@ -210,12 +538,16 @@ impl<T: ?Sized> RwLock<T> { ///} /// ``` pub async fn read(&self) -> RwLockReadGuard<'_, T> { - let permit = ReleasingPermit::acquire(self, 1).await.unwrap_or_else(|_| { + self.s.acquire(1).await.unwrap_or_else(|_| { // The semaphore was closed. but, we never explicitly close it, and we have a // handle to it through the Arc, which means that this can never happen. unreachable!() }); - RwLockReadGuard { lock: self, permit } + RwLockReadGuard { + s: &self.s, + data: self.c.get(), + marker: marker::PhantomData, + } } /// Locks this rwlock with exclusive write access, causing the current task @@ -241,15 +573,40 @@ impl<T: ?Sized> RwLock<T> { ///} /// ``` pub async fn write(&self) -> RwLockWriteGuard<'_, T> { - let permit = ReleasingPermit::acquire(self, MAX_READS as u16) - .await - .unwrap_or_else(|_| { - // The semaphore was closed. but, we never explicitly close it, and we have a - // handle to it through the Arc, which means that this can never happen. - unreachable!() - }); - - RwLockWriteGuard { lock: self, permit } + self.s.acquire(MAX_READS as u32).await.unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and we have a + // handle to it through the Arc, which means that this can never happen. + unreachable!() + }); + RwLockWriteGuard { + s: &self.s, + data: self.c.get(), + marker: marker::PhantomData, + } + } + + /// Returns a mutable reference to the underlying data. + /// + /// Since this call borrows the `RwLock` mutably, no actual locking needs to + /// take place -- the mutable borrow statically guarantees no locks exist. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::RwLock; + /// + /// fn main() { + /// let mut lock = RwLock::new(1); + /// + /// let n = lock.get_mut(); + /// *n = 2; + /// } + /// ``` + pub fn get_mut(&mut self) -> &mut T { + unsafe { + // Safety: This is https://github.com/rust-lang/rust/pull/76936 + &mut *self.c.get() + } } /// Consumes the lock, returning the underlying data. @@ -265,7 +622,7 @@ impl<T: ?Sized> ops::Deref for RwLockReadGuard<'_, T> { type Target = T; fn deref(&self) -> &T { - unsafe { &*self.lock.c.get() } + unsafe { &*self.data } } } @@ -273,13 +630,13 @@ impl<T: ?Sized> ops::Deref for RwLockWriteGuard<'_, T> { type Target = T; fn deref(&self) -> &T { - unsafe { &*self.lock.c.get() } + unsafe { &*self.data } } } impl<T: ?Sized> ops::DerefMut for RwLockWriteGuard<'_, T> { fn deref_mut(&mut self) -> &mut T { - unsafe { &mut *self.lock.c.get() } + unsafe { &mut *self.data } } } @@ -289,7 +646,7 @@ impl<T> From<T> for RwLock<T> { } } -impl<T> Default for RwLock<T> +impl<T: ?Sized> Default for RwLock<T> where T: Default, { diff --git a/src/sync/semaphore.rs b/src/sync/semaphore.rs index 2489d34..43dd976 100644 --- a/src/sync/semaphore.rs +++ b/src/sync/semaphore.rs @@ -1,7 +1,7 @@ use super::batch_semaphore as ll; // low level implementation use std::sync::Arc; -/// Counting semaphore performing asynchronous permit aquisition. +/// Counting semaphore performing asynchronous permit acquisition. /// /// A semaphore maintains a set of permits. Permits are used to synchronize /// access to a shared resource. A semaphore differs from a mutex in that it @@ -74,6 +74,15 @@ impl Semaphore { } } + /// Creates a new semaphore with the initial number of permits. + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] + #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] + pub const fn const_new(permits: usize) -> Self { + Self { + ll_sem: ll::Semaphore::const_new(permits), + } + } + /// Returns the current number of available permits. pub fn available_permits(&self) -> usize { self.ll_sem.available_permits() @@ -114,7 +123,7 @@ impl Semaphore { pub async fn acquire_owned(self: Arc<Self>) -> OwnedSemaphorePermit { self.ll_sem.acquire(1).await.unwrap(); OwnedSemaphorePermit { - sem: self.clone(), + sem: self, permits: 1, } } @@ -127,7 +136,7 @@ impl Semaphore { pub fn try_acquire_owned(self: Arc<Self>) -> Result<OwnedSemaphorePermit, TryAcquireError> { match self.ll_sem.try_acquire(1) { Ok(_) => Ok(OwnedSemaphorePermit { - sem: self.clone(), + sem: self, permits: 1, }), Err(_) => Err(TryAcquireError(())), diff --git a/src/sync/semaphore_ll.rs b/src/sync/semaphore_ll.rs deleted file mode 100644 index 25d25ac..0000000 --- a/src/sync/semaphore_ll.rs +++ /dev/null @@ -1,1221 +0,0 @@ -#![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))] - -//! Thread-safe, asynchronous counting semaphore. -//! -//! A `Semaphore` instance holds a set of permits. Permits are used to -//! synchronize access to a shared resource. -//! -//! Before accessing the shared resource, callers acquire a permit from the -//! semaphore. Once the permit is acquired, the caller then enters the critical -//! section. If no permits are available, then acquiring the semaphore returns -//! `Pending`. The task is woken once a permit becomes available. - -use crate::loom::cell::UnsafeCell; -use crate::loom::future::AtomicWaker; -use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize}; -use crate::loom::thread; - -use std::cmp; -use std::fmt; -use std::ptr::{self, NonNull}; -use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Relaxed, Release}; -use std::task::Poll::{Pending, Ready}; -use std::task::{Context, Poll}; -use std::usize; - -/// Futures-aware semaphore. -pub(crate) struct Semaphore { - /// Tracks both the waiter queue tail pointer and the number of remaining - /// permits. - state: AtomicUsize, - - /// waiter queue head pointer. - head: UnsafeCell<NonNull<Waiter>>, - - /// Coordinates access to the queue head. - rx_lock: AtomicUsize, - - /// Stub waiter node used as part of the MPSC channel algorithm. - stub: Box<Waiter>, -} - -/// A semaphore permit -/// -/// Tracks the lifecycle of a semaphore permit. -/// -/// An instance of `Permit` is intended to be used with a **single** instance of -/// `Semaphore`. Using a single instance of `Permit` with multiple semaphore -/// instances will result in unexpected behavior. -/// -/// `Permit` does **not** release the permit back to the semaphore on drop. It -/// is the user's responsibility to ensure that `Permit::release` is called -/// before dropping the permit. -#[derive(Debug)] -pub(crate) struct Permit { - waiter: Option<Box<Waiter>>, - state: PermitState, -} - -/// Error returned by `Permit::poll_acquire`. -#[derive(Debug)] -pub(crate) struct AcquireError(()); - -/// Error returned by `Permit::try_acquire`. -#[derive(Debug)] -pub(crate) enum TryAcquireError { - Closed, - NoPermits, -} - -/// Node used to notify the semaphore waiter when permit is available. -#[derive(Debug)] -struct Waiter { - /// Stores waiter state. - /// - /// See `WaiterState` for more details. - state: AtomicUsize, - - /// Task to wake when a permit is made available. - waker: AtomicWaker, - - /// Next pointer in the queue of waiting senders. - next: AtomicPtr<Waiter>, -} - -/// Semaphore state -/// -/// The 2 low bits track the modes. -/// -/// - Closed -/// - Full -/// -/// When not full, the rest of the `usize` tracks the total number of messages -/// in the channel. When full, the rest of the `usize` is a pointer to the tail -/// of the "waiting senders" queue. -#[derive(Copy, Clone)] -struct SemState(usize); - -/// Permit state -#[derive(Debug, Copy, Clone)] -enum PermitState { - /// Currently waiting for permits to be made available and assigned to the - /// waiter. - Waiting(u16), - - /// The number of acquired permits - Acquired(u16), -} - -/// State for an individual waker node -#[derive(Debug, Copy, Clone)] -struct WaiterState(usize); - -/// Waiter node is in the semaphore queue -const QUEUED: usize = 0b001; - -/// Semaphore has been closed, no more permits will be issued. -const CLOSED: usize = 0b10; - -/// The permit that owns the `Waiter` dropped. -const DROPPED: usize = 0b100; - -/// Represents "one requested permit" in the waiter state -const PERMIT_ONE: usize = 0b1000; - -/// Masks the waiter state to only contain bits tracking number of requested -/// permits. -const PERMIT_MASK: usize = usize::MAX - (PERMIT_ONE - 1); - -/// How much to shift a permit count to pack it into the waker state -const PERMIT_SHIFT: u32 = PERMIT_ONE.trailing_zeros(); - -/// Flag differentiating between available permits and waiter pointers. -/// -/// If we assume pointers are properly aligned, then the least significant bit -/// will always be zero. So, we use that bit to track if the value represents a -/// number. -const NUM_FLAG: usize = 0b01; - -/// Signal the semaphore is closed -const CLOSED_FLAG: usize = 0b10; - -/// Maximum number of permits a semaphore can manage -const MAX_PERMITS: usize = usize::MAX >> NUM_SHIFT; - -/// When representing "numbers", the state has to be shifted this much (to get -/// rid of the flag bit). -const NUM_SHIFT: usize = 2; - -// ===== impl Semaphore ===== - -impl Semaphore { - /// Creates a new semaphore with the initial number of permits - /// - /// # Panics - /// - /// Panics if `permits` is zero. - pub(crate) fn new(permits: usize) -> Semaphore { - let stub = Box::new(Waiter::new()); - let ptr = NonNull::from(&*stub); - - // Allocations are aligned - debug_assert!(ptr.as_ptr() as usize & NUM_FLAG == 0); - - let state = SemState::new(permits, &stub); - - Semaphore { - state: AtomicUsize::new(state.to_usize()), - head: UnsafeCell::new(ptr), - rx_lock: AtomicUsize::new(0), - stub, - } - } - - /// Returns the current number of available permits - pub(crate) fn available_permits(&self) -> usize { - let curr = SemState(self.state.load(Acquire)); - curr.available_permits() - } - - /// Tries to acquire the requested number of permits, registering the waiter - /// if not enough permits are available. - fn poll_acquire( - &self, - cx: &mut Context<'_>, - num_permits: u16, - permit: &mut Permit, - ) -> Poll<Result<(), AcquireError>> { - self.poll_acquire2(num_permits, || { - let waiter = permit.waiter.get_or_insert_with(|| Box::new(Waiter::new())); - - waiter.waker.register_by_ref(cx.waker()); - - Some(NonNull::from(&**waiter)) - }) - } - - fn try_acquire(&self, num_permits: u16) -> Result<(), TryAcquireError> { - match self.poll_acquire2(num_permits, || None) { - Poll::Ready(res) => res.map_err(to_try_acquire), - Poll::Pending => Err(TryAcquireError::NoPermits), - } - } - - /// Polls for a permit - /// - /// Tries to acquire available permits first. If unable to acquire a - /// sufficient number of permits, the caller's waiter is pushed onto the - /// semaphore's wait queue. - fn poll_acquire2<F>( - &self, - num_permits: u16, - mut get_waiter: F, - ) -> Poll<Result<(), AcquireError>> - where - F: FnMut() -> Option<NonNull<Waiter>>, - { - let num_permits = num_permits as usize; - - // Load the current state - let mut curr = SemState(self.state.load(Acquire)); - - // Saves a ref to the waiter node - let mut maybe_waiter: Option<NonNull<Waiter>> = None; - - /// Used in branches where we attempt to push the waiter into the wait - /// queue but fail due to permits becoming available or the wait queue - /// transitioning to "closed". In this case, the waiter must be - /// transitioned back to the "idle" state. - macro_rules! revert_to_idle { - () => { - if let Some(waiter) = maybe_waiter { - unsafe { waiter.as_ref() }.revert_to_idle(); - } - }; - } - - loop { - let mut next = curr; - - if curr.is_closed() { - revert_to_idle!(); - return Ready(Err(AcquireError::closed())); - } - - let acquired = next.acquire_permits(num_permits, &self.stub); - - if !acquired { - // There are not enough available permits to satisfy the - // request. The permit transitions to a waiting state. - debug_assert!(curr.waiter().is_some() || curr.available_permits() < num_permits); - - if let Some(waiter) = maybe_waiter.as_ref() { - // Safety: the caller owns the waiter. - let w = unsafe { waiter.as_ref() }; - w.set_permits_to_acquire(num_permits - curr.available_permits()); - } else { - // Get the waiter for the permit. - if let Some(waiter) = get_waiter() { - // Safety: the caller owns the waiter. - let w = unsafe { waiter.as_ref() }; - - // If there are any currently available permits, the - // waiter acquires those immediately and waits for the - // remaining permits to become available. - if !w.to_queued(num_permits - curr.available_permits()) { - // The node is alrady queued, there is no further work - // to do. - return Pending; - } - - maybe_waiter = Some(waiter); - } else { - // No waiter, this indicates the caller does not wish to - // "wait", so there is nothing left to do. - return Pending; - } - } - - next.set_waiter(maybe_waiter.unwrap()); - } - - debug_assert_ne!(curr.0, 0); - debug_assert_ne!(next.0, 0); - - match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => { - if acquired { - // Successfully acquire permits **without** queuing the - // waiter node. The waiter node is not currently in the - // queue. - revert_to_idle!(); - return Ready(Ok(())); - } else { - // The node is pushed into the queue, the final step is - // to set the node's "next" pointer to return the wait - // queue into a consistent state. - - let prev_waiter = - curr.waiter().unwrap_or_else(|| NonNull::from(&*self.stub)); - - let waiter = maybe_waiter.unwrap(); - - // Link the nodes. - // - // Safety: the mpsc algorithm guarantees the old tail of - // the queue is not removed from the queue during the - // push process. - unsafe { - prev_waiter.as_ref().store_next(waiter); - } - - return Pending; - } - } - Err(actual) => { - curr = SemState(actual); - } - } - } - } - - /// Closes the semaphore. This prevents the semaphore from issuing new - /// permits and notifies all pending waiters. - pub(crate) fn close(&self) { - // Acquire the `rx_lock`, setting the "closed" flag on the lock. - let prev = self.rx_lock.fetch_or(1, AcqRel); - - if prev != 0 { - // Another thread has the lock and will be responsible for notifying - // pending waiters. - return; - } - - self.add_permits_locked(0, true); - } - /// Adds `n` new permits to the semaphore. - /// - /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded. - pub(crate) fn add_permits(&self, n: usize) { - if n == 0 { - return; - } - - // TODO: Handle overflow. A panic is not sufficient, the process must - // abort. - let prev = self.rx_lock.fetch_add(n << 1, AcqRel); - - if prev != 0 { - // Another thread has the lock and will be responsible for notifying - // pending waiters. - return; - } - - self.add_permits_locked(n, false); - } - - fn add_permits_locked(&self, mut rem: usize, mut closed: bool) { - while rem > 0 || closed { - if closed { - SemState::fetch_set_closed(&self.state, AcqRel); - } - - // Release the permits and notify - self.add_permits_locked2(rem, closed); - - let n = rem << 1; - - let actual = if closed { - let actual = self.rx_lock.fetch_sub(n | 1, AcqRel); - closed = false; - actual - } else { - let actual = self.rx_lock.fetch_sub(n, AcqRel); - closed = actual & 1 == 1; - actual - }; - - rem = (actual >> 1) - rem; - } - } - - /// Releases a specific amount of permits to the semaphore - /// - /// This function is called by `add_permits` after the add lock has been - /// acquired. - fn add_permits_locked2(&self, mut n: usize, closed: bool) { - // If closing the semaphore, we want to drain the entire queue. The - // number of permits being assigned doesn't matter. - if closed { - n = usize::MAX; - } - - 'outer: while n > 0 { - unsafe { - let mut head = self.head.with(|head| *head); - let mut next_ptr = head.as_ref().next.load(Acquire); - - let stub = self.stub(); - - if head == stub { - // The stub node indicates an empty queue. Any remaining - // permits get assigned back to the semaphore. - let next = match NonNull::new(next_ptr) { - Some(next) => next, - None => { - // This loop is not part of the standard intrusive mpsc - // channel algorithm. This is where we atomically pop - // the last task and add `n` to the remaining capacity. - // - // This modification to the pop algorithm works because, - // at this point, we have not done any work (only done - // reading). We have a *pretty* good idea that there is - // no concurrent pusher. - // - // The capacity is then atomically added by doing an - // AcqRel CAS on `state`. The `state` cell is the - // linchpin of the algorithm. - // - // By successfully CASing `head` w/ AcqRel, we ensure - // that, if any thread was racing and entered a push, we - // see that and abort pop, retrying as it is - // "inconsistent". - let mut curr = SemState::load(&self.state, Acquire); - - loop { - if curr.has_waiter(&self.stub) { - // A waiter is being added concurrently. - // This is the MPSC queue's "inconsistent" - // state and we must loop and try again. - thread::yield_now(); - continue 'outer; - } - - // If closing, nothing more to do. - if closed { - debug_assert!(curr.is_closed(), "state = {:?}", curr); - return; - } - - let mut next = curr; - next.release_permits(n, &self.stub); - - match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => return, - Err(actual) => { - curr = SemState(actual); - } - } - } - } - }; - - self.head.with_mut(|head| *head = next); - head = next; - next_ptr = next.as_ref().next.load(Acquire); - } - - // `head` points to a waiter assign permits to the waiter. If - // all requested permits are satisfied, then we can continue, - // otherwise the node stays in the wait queue. - if !head.as_ref().assign_permits(&mut n, closed) { - assert_eq!(n, 0); - return; - } - - if let Some(next) = NonNull::new(next_ptr) { - self.head.with_mut(|head| *head = next); - - self.remove_queued(head, closed); - continue 'outer; - } - - let state = SemState::load(&self.state, Acquire); - - // This must always be a pointer as the wait list is not empty. - let tail = state.waiter().unwrap(); - - if tail != head { - // Inconsistent - thread::yield_now(); - continue 'outer; - } - - self.push_stub(closed); - - next_ptr = head.as_ref().next.load(Acquire); - - if let Some(next) = NonNull::new(next_ptr) { - self.head.with_mut(|head| *head = next); - - self.remove_queued(head, closed); - continue 'outer; - } - - // Inconsistent state, loop - thread::yield_now(); - } - } - } - - /// The wait node has had all of its permits assigned and has been removed - /// from the wait queue. - /// - /// Attempt to remove the QUEUED bit from the node. If additional permits - /// are concurrently requested, the node must be pushed back into the wait - /// queued. - fn remove_queued(&self, waiter: NonNull<Waiter>, closed: bool) { - let mut curr = WaiterState(unsafe { waiter.as_ref() }.state.load(Acquire)); - - loop { - if curr.is_dropped() { - // The Permit dropped, it is on us to release the memory - let _ = unsafe { Box::from_raw(waiter.as_ptr()) }; - return; - } - - // The node is removed from the queue. We attempt to unset the - // queued bit, but concurrently the waiter has requested more - // permits. When the waiter requested more permits, it saw the - // queued bit set so took no further action. This requires us to - // push the node back into the queue. - if curr.permits_to_acquire() > 0 { - // More permits are requested. The waiter must be re-queued - unsafe { - self.push_waiter(waiter, closed); - } - return; - } - - let mut next = curr; - next.unset_queued(); - - let w = unsafe { waiter.as_ref() }; - - match w.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => return, - Err(actual) => { - curr = WaiterState(actual); - } - } - } - } - - unsafe fn push_stub(&self, closed: bool) { - self.push_waiter(self.stub(), closed); - } - - unsafe fn push_waiter(&self, waiter: NonNull<Waiter>, closed: bool) { - // Set the next pointer. This does not require an atomic operation as - // this node is not accessible. The write will be flushed with the next - // operation - waiter.as_ref().next.store(ptr::null_mut(), Relaxed); - - // Update the tail to point to the new node. We need to see the previous - // node in order to update the next pointer as well as release `task` - // to any other threads calling `push`. - let next = SemState::new_ptr(waiter, closed); - let prev = SemState(self.state.swap(next.0, AcqRel)); - - debug_assert_eq!(closed, prev.is_closed()); - - // This function is only called when there are pending tasks. Because of - // this, the state must *always* be in pointer mode. - let prev = prev.waiter().unwrap(); - - // No cycles plz - debug_assert_ne!(prev, waiter); - - // Release `task` to the consume end. - prev.as_ref().next.store(waiter.as_ptr(), Release); - } - - fn stub(&self) -> NonNull<Waiter> { - unsafe { NonNull::new_unchecked(&*self.stub as *const _ as *mut _) } - } -} - -impl Drop for Semaphore { - fn drop(&mut self) { - self.close(); - } -} - -impl fmt::Debug for Semaphore { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("Semaphore") - .field("state", &SemState::load(&self.state, Relaxed)) - .field("head", &self.head.with(|ptr| ptr)) - .field("rx_lock", &self.rx_lock.load(Relaxed)) - .field("stub", &self.stub) - .finish() - } -} - -unsafe impl Send for Semaphore {} -unsafe impl Sync for Semaphore {} - -// ===== impl Permit ===== - -impl Permit { - /// Creates a new `Permit`. - /// - /// The permit begins in the "unacquired" state. - pub(crate) fn new() -> Permit { - use PermitState::Acquired; - - Permit { - waiter: None, - state: Acquired(0), - } - } - - /// Returns `true` if the permit has been acquired - #[allow(dead_code)] // may be used later - pub(crate) fn is_acquired(&self) -> bool { - match self.state { - PermitState::Acquired(num) if num > 0 => true, - _ => false, - } - } - - /// Tries to acquire the permit. If no permits are available, the current task - /// is notified once a new permit becomes available. - pub(crate) fn poll_acquire( - &mut self, - cx: &mut Context<'_>, - num_permits: u16, - semaphore: &Semaphore, - ) -> Poll<Result<(), AcquireError>> { - use std::cmp::Ordering::*; - use PermitState::*; - - match self.state { - Waiting(requested) => { - // There must be a waiter - let waiter = self.waiter.as_ref().unwrap(); - - match requested.cmp(&num_permits) { - Less => { - let delta = num_permits - requested; - - // Request additional permits. If the waiter has been - // dequeued, it must be re-queued. - if !waiter.try_inc_permits_to_acquire(delta as usize) { - let waiter = NonNull::from(&**waiter); - - // Ignore the result. The check for - // `permits_to_acquire()` will converge the state as - // needed - let _ = semaphore.poll_acquire2(delta, || Some(waiter))?; - } - - self.state = Waiting(num_permits); - } - Greater => { - let delta = requested - num_permits; - let to_release = waiter.try_dec_permits_to_acquire(delta as usize); - - semaphore.add_permits(to_release); - self.state = Waiting(num_permits); - } - Equal => {} - } - - if waiter.permits_to_acquire()? == 0 { - self.state = Acquired(requested); - return Ready(Ok(())); - } - - waiter.waker.register_by_ref(cx.waker()); - - if waiter.permits_to_acquire()? == 0 { - self.state = Acquired(requested); - return Ready(Ok(())); - } - - Pending - } - Acquired(acquired) => { - if acquired >= num_permits { - Ready(Ok(())) - } else { - match semaphore.poll_acquire(cx, num_permits - acquired, self)? { - Ready(()) => { - self.state = Acquired(num_permits); - Ready(Ok(())) - } - Pending => { - self.state = Waiting(num_permits); - Pending - } - } - } - } - } - } - - /// Tries to acquire the permit. - pub(crate) fn try_acquire( - &mut self, - num_permits: u16, - semaphore: &Semaphore, - ) -> Result<(), TryAcquireError> { - use PermitState::*; - - match self.state { - Waiting(requested) => { - // There must be a waiter - let waiter = self.waiter.as_ref().unwrap(); - - if requested > num_permits { - let delta = requested - num_permits; - let to_release = waiter.try_dec_permits_to_acquire(delta as usize); - - semaphore.add_permits(to_release); - self.state = Waiting(num_permits); - } - - let res = waiter.permits_to_acquire().map_err(to_try_acquire)?; - - if res == 0 { - if requested < num_permits { - // Try to acquire the additional permits - semaphore.try_acquire(num_permits - requested)?; - } - - self.state = Acquired(num_permits); - Ok(()) - } else { - Err(TryAcquireError::NoPermits) - } - } - Acquired(acquired) => { - if acquired < num_permits { - semaphore.try_acquire(num_permits - acquired)?; - self.state = Acquired(num_permits); - } - - Ok(()) - } - } - } - - /// Releases a permit back to the semaphore - pub(crate) fn release(&mut self, n: u16, semaphore: &Semaphore) { - let n = self.forget(n); - semaphore.add_permits(n as usize); - } - - /// Forgets the permit **without** releasing it back to the semaphore. - /// - /// After calling `forget`, `poll_acquire` is able to acquire new permit - /// from the semaphore. - /// - /// Repeatedly calling `forget` without associated calls to `add_permit` - /// will result in the semaphore losing all permits. - /// - /// Will forget **at most** the number of acquired permits. This number is - /// returned. - pub(crate) fn forget(&mut self, n: u16) -> u16 { - use PermitState::*; - - match self.state { - Waiting(requested) => { - let n = cmp::min(n, requested); - - // Decrement - let acquired = self - .waiter - .as_ref() - .unwrap() - .try_dec_permits_to_acquire(n as usize) as u16; - - if n == requested { - self.state = Acquired(0); - } else if acquired == requested - n { - self.state = Waiting(acquired); - } else { - self.state = Waiting(requested - n); - } - - acquired - } - Acquired(acquired) => { - let n = cmp::min(n, acquired); - self.state = Acquired(acquired - n); - n - } - } - } -} - -impl Default for Permit { - fn default() -> Self { - Self::new() - } -} - -impl Drop for Permit { - fn drop(&mut self) { - if let Some(waiter) = self.waiter.take() { - // Set the dropped flag - let state = WaiterState(waiter.state.fetch_or(DROPPED, AcqRel)); - - if state.is_queued() { - // The waiter is stored in the queue. The semaphore will drop it - std::mem::forget(waiter); - } - } - } -} - -// ===== impl AcquireError ==== - -impl AcquireError { - fn closed() -> AcquireError { - AcquireError(()) - } -} - -fn to_try_acquire(_: AcquireError) -> TryAcquireError { - TryAcquireError::Closed -} - -impl fmt::Display for AcquireError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "semaphore closed") - } -} - -impl std::error::Error for AcquireError {} - -// ===== impl TryAcquireError ===== - -impl TryAcquireError { - /// Returns `true` if the error was caused by a closed semaphore. - pub(crate) fn is_closed(&self) -> bool { - match self { - TryAcquireError::Closed => true, - _ => false, - } - } - - /// Returns `true` if the error was caused by calling `try_acquire` on a - /// semaphore with no available permits. - pub(crate) fn is_no_permits(&self) -> bool { - match self { - TryAcquireError::NoPermits => true, - _ => false, - } - } -} - -impl fmt::Display for TryAcquireError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - TryAcquireError::Closed => write!(fmt, "semaphore closed"), - TryAcquireError::NoPermits => write!(fmt, "no permits available"), - } - } -} - -impl std::error::Error for TryAcquireError {} - -// ===== impl Waiter ===== - -impl Waiter { - fn new() -> Waiter { - Waiter { - state: AtomicUsize::new(0), - waker: AtomicWaker::new(), - next: AtomicPtr::new(ptr::null_mut()), - } - } - - fn permits_to_acquire(&self) -> Result<usize, AcquireError> { - let state = WaiterState(self.state.load(Acquire)); - - if state.is_closed() { - Err(AcquireError(())) - } else { - Ok(state.permits_to_acquire()) - } - } - - /// Only increments the number of permits *if* the waiter is currently - /// queued. - /// - /// # Returns - /// - /// `true` if the number of permits to acquire has been incremented. `false` - /// otherwise. On `false`, the caller should use `Semaphore::poll_acquire`. - fn try_inc_permits_to_acquire(&self, n: usize) -> bool { - let mut curr = WaiterState(self.state.load(Acquire)); - - loop { - if !curr.is_queued() { - assert_eq!(0, curr.permits_to_acquire()); - return false; - } - - let mut next = curr; - next.set_permits_to_acquire(n + curr.permits_to_acquire()); - - match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => return true, - Err(actual) => curr = WaiterState(actual), - } - } - } - - /// Try to decrement the number of permits to acquire. This returns the - /// actual number of permits that were decremented. The delta betweeen `n` - /// and the return has been assigned to the permit and the caller must - /// assign these back to the semaphore. - fn try_dec_permits_to_acquire(&self, n: usize) -> usize { - let mut curr = WaiterState(self.state.load(Acquire)); - - loop { - if !curr.is_queued() { - assert_eq!(0, curr.permits_to_acquire()); - } - - let delta = cmp::min(n, curr.permits_to_acquire()); - let rem = curr.permits_to_acquire() - delta; - - let mut next = curr; - next.set_permits_to_acquire(rem); - - match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => return n - delta, - Err(actual) => curr = WaiterState(actual), - } - } - } - - /// Store the number of remaining permits needed to satisfy the waiter and - /// transition to the "QUEUED" state. - /// - /// # Returns - /// - /// `true` if the `QUEUED` bit was set as part of the transition. - fn to_queued(&self, num_permits: usize) -> bool { - let mut curr = WaiterState(self.state.load(Acquire)); - - // The waiter should **not** be waiting for any permits. - debug_assert_eq!(curr.permits_to_acquire(), 0); - - loop { - let mut next = curr; - next.set_permits_to_acquire(num_permits); - next.set_queued(); - - match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => { - if curr.is_queued() { - return false; - } else { - // Make sure the next pointer is null - self.next.store(ptr::null_mut(), Relaxed); - return true; - } - } - Err(actual) => curr = WaiterState(actual), - } - } - } - - /// Set the number of permits to acquire. - /// - /// This function is only called when the waiter is being inserted into the - /// wait queue. Because of this, there are no concurrent threads that can - /// modify the state and using `store` is safe. - fn set_permits_to_acquire(&self, num_permits: usize) { - debug_assert!(WaiterState(self.state.load(Acquire)).is_queued()); - - let mut state = WaiterState(QUEUED); - state.set_permits_to_acquire(num_permits); - - self.state.store(state.0, Release); - } - - /// Assign permits to the waiter. - /// - /// Returns `true` if the waiter should be removed from the queue - fn assign_permits(&self, n: &mut usize, closed: bool) -> bool { - let mut curr = WaiterState(self.state.load(Acquire)); - - loop { - let mut next = curr; - - // Number of permits to assign to this waiter - let assign = cmp::min(curr.permits_to_acquire(), *n); - - // Assign the permits - next.set_permits_to_acquire(curr.permits_to_acquire() - assign); - - if closed { - next.set_closed(); - } - - match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => { - // Update `n` - *n -= assign; - - if next.permits_to_acquire() == 0 { - if curr.permits_to_acquire() > 0 { - self.waker.wake(); - } - - return true; - } else { - return false; - } - } - Err(actual) => curr = WaiterState(actual), - } - } - } - - fn revert_to_idle(&self) { - // An idle node is not waiting on any permits - self.state.store(0, Relaxed); - } - - fn store_next(&self, next: NonNull<Waiter>) { - self.next.store(next.as_ptr(), Release); - } -} - -// ===== impl SemState ===== - -impl SemState { - /// Returns a new default `State` value. - fn new(permits: usize, stub: &Waiter) -> SemState { - assert!(permits <= MAX_PERMITS); - - if permits > 0 { - SemState((permits << NUM_SHIFT) | NUM_FLAG) - } else { - SemState(stub as *const _ as usize) - } - } - - /// Returns a `State` tracking `ptr` as the tail of the queue. - fn new_ptr(tail: NonNull<Waiter>, closed: bool) -> SemState { - let mut val = tail.as_ptr() as usize; - - if closed { - val |= CLOSED_FLAG; - } - - SemState(val) - } - - /// Returns the amount of remaining capacity - fn available_permits(self) -> usize { - if !self.has_available_permits() { - return 0; - } - - self.0 >> NUM_SHIFT - } - - /// Returns `true` if the state has permits that can be claimed by a waiter. - fn has_available_permits(self) -> bool { - self.0 & NUM_FLAG == NUM_FLAG - } - - fn has_waiter(self, stub: &Waiter) -> bool { - !self.has_available_permits() && !self.is_stub(stub) - } - - /// Tries to atomically acquire specified number of permits. - /// - /// # Return - /// - /// Returns `true` if the specified number of permits were acquired, `false` - /// otherwise. Returning false does not mean that there are no more - /// available permits. - fn acquire_permits(&mut self, num: usize, stub: &Waiter) -> bool { - debug_assert!(num > 0); - - if self.available_permits() < num { - return false; - } - - debug_assert!(self.waiter().is_none()); - - self.0 -= num << NUM_SHIFT; - - if self.0 == NUM_FLAG { - // Set the state to the stub pointer. - self.0 = stub as *const _ as usize; - } - - true - } - - /// Releases permits - /// - /// Returns `true` if the permits were accepted. - fn release_permits(&mut self, permits: usize, stub: &Waiter) { - debug_assert!(permits > 0); - - if self.is_stub(stub) { - self.0 = (permits << NUM_SHIFT) | NUM_FLAG | (self.0 & CLOSED_FLAG); - return; - } - - debug_assert!(self.has_available_permits()); - - self.0 += permits << NUM_SHIFT; - } - - fn is_waiter(self) -> bool { - self.0 & NUM_FLAG == 0 - } - - /// Returns the waiter, if one is set. - fn waiter(self) -> Option<NonNull<Waiter>> { - if self.is_waiter() { - let waiter = NonNull::new(self.as_ptr()).expect("null pointer stored"); - - Some(waiter) - } else { - None - } - } - - /// Assumes `self` represents a pointer - fn as_ptr(self) -> *mut Waiter { - (self.0 & !CLOSED_FLAG) as *mut Waiter - } - - /// Sets to a pointer to a waiter. - /// - /// This can only be done from the full state. - fn set_waiter(&mut self, waiter: NonNull<Waiter>) { - let waiter = waiter.as_ptr() as usize; - debug_assert!(!self.is_closed()); - - self.0 = waiter; - } - - fn is_stub(self, stub: &Waiter) -> bool { - self.as_ptr() as usize == stub as *const _ as usize - } - - /// Loads the state from an AtomicUsize. - fn load(cell: &AtomicUsize, ordering: Ordering) -> SemState { - let value = cell.load(ordering); - SemState(value) - } - - fn fetch_set_closed(cell: &AtomicUsize, ordering: Ordering) -> SemState { - let value = cell.fetch_or(CLOSED_FLAG, ordering); - SemState(value) - } - - fn is_closed(self) -> bool { - self.0 & CLOSED_FLAG == CLOSED_FLAG - } - - /// Converts the state into a `usize` representation. - fn to_usize(self) -> usize { - self.0 - } -} - -impl fmt::Debug for SemState { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut fmt = fmt.debug_struct("SemState"); - - if self.is_waiter() { - fmt.field("state", &"<waiter>"); - } else { - fmt.field("permits", &self.available_permits()); - } - - fmt.finish() - } -} - -// ===== impl WaiterState ===== - -impl WaiterState { - fn permits_to_acquire(self) -> usize { - self.0 >> PERMIT_SHIFT - } - - fn set_permits_to_acquire(&mut self, val: usize) { - self.0 = (val << PERMIT_SHIFT) | (self.0 & !PERMIT_MASK) - } - - fn is_queued(self) -> bool { - self.0 & QUEUED == QUEUED - } - - fn set_queued(&mut self) { - self.0 |= QUEUED; - } - - fn is_closed(self) -> bool { - self.0 & CLOSED == CLOSED - } - - fn set_closed(&mut self) { - self.0 |= CLOSED; - } - - fn unset_queued(&mut self) { - assert!(self.is_queued()); - self.0 -= QUEUED; - } - - fn is_dropped(self) -> bool { - self.0 & DROPPED == DROPPED - } -} diff --git a/src/sync/task/atomic_waker.rs b/src/sync/task/atomic_waker.rs index 73b1745..ae4cac7 100644 --- a/src/sync/task/atomic_waker.rs +++ b/src/sync/task/atomic_waker.rs @@ -141,13 +141,12 @@ impl AtomicWaker { } } + /* /// Registers the current waker to be notified on calls to `wake`. - /// - /// This is the same as calling `register_task` with `task::current()`. - #[cfg(feature = "io-driver")] pub(crate) fn register(&self, waker: Waker) { self.do_register(waker); } + */ /// Registers the provided waker to be notified on calls to `wake`. /// diff --git a/src/sync/tests/loom_broadcast.rs b/src/sync/tests/loom_broadcast.rs index da12fb9..4b1f034 100644 --- a/src/sync/tests/loom_broadcast.rs +++ b/src/sync/tests/loom_broadcast.rs @@ -1,5 +1,5 @@ use crate::sync::broadcast; -use crate::sync::broadcast::RecvError::{Closed, Lagged}; +use crate::sync::broadcast::error::RecvError::{Closed, Lagged}; use loom::future::block_on; use loom::sync::Arc; diff --git a/src/sync/tests/loom_cancellation_token.rs b/src/sync/tests/loom_cancellation_token.rs deleted file mode 100644 index e9c9f3d..0000000 --- a/src/sync/tests/loom_cancellation_token.rs +++ /dev/null @@ -1,155 +0,0 @@ -use crate::sync::CancellationToken; - -use loom::{future::block_on, thread}; -use tokio_test::assert_ok; - -#[test] -fn cancel_token() { - loom::model(|| { - let token = CancellationToken::new(); - let token1 = token.clone(); - - let th1 = thread::spawn(move || { - block_on(async { - token1.cancelled().await; - }); - }); - - let th2 = thread::spawn(move || { - token.cancel(); - }); - - assert_ok!(th1.join()); - assert_ok!(th2.join()); - }); -} - -#[test] -fn cancel_with_child() { - loom::model(|| { - let token = CancellationToken::new(); - let token1 = token.clone(); - let token2 = token.clone(); - let child_token = token.child_token(); - - let th1 = thread::spawn(move || { - block_on(async { - token1.cancelled().await; - }); - }); - - let th2 = thread::spawn(move || { - token2.cancel(); - }); - - let th3 = thread::spawn(move || { - block_on(async { - child_token.cancelled().await; - }); - }); - - assert_ok!(th1.join()); - assert_ok!(th2.join()); - assert_ok!(th3.join()); - }); -} - -#[test] -fn drop_token_no_child() { - loom::model(|| { - let token = CancellationToken::new(); - let token1 = token.clone(); - let token2 = token.clone(); - - let th1 = thread::spawn(move || { - drop(token1); - }); - - let th2 = thread::spawn(move || { - drop(token2); - }); - - let th3 = thread::spawn(move || { - drop(token); - }); - - assert_ok!(th1.join()); - assert_ok!(th2.join()); - assert_ok!(th3.join()); - }); -} - -#[test] -fn drop_token_with_childs() { - loom::model(|| { - let token1 = CancellationToken::new(); - let child_token1 = token1.child_token(); - let child_token2 = token1.child_token(); - - let th1 = thread::spawn(move || { - drop(token1); - }); - - let th2 = thread::spawn(move || { - drop(child_token1); - }); - - let th3 = thread::spawn(move || { - drop(child_token2); - }); - - assert_ok!(th1.join()); - assert_ok!(th2.join()); - assert_ok!(th3.join()); - }); -} - -#[test] -fn drop_and_cancel_token() { - loom::model(|| { - let token1 = CancellationToken::new(); - let token2 = token1.clone(); - let child_token = token1.child_token(); - - let th1 = thread::spawn(move || { - drop(token1); - }); - - let th2 = thread::spawn(move || { - token2.cancel(); - }); - - let th3 = thread::spawn(move || { - drop(child_token); - }); - - assert_ok!(th1.join()); - assert_ok!(th2.join()); - assert_ok!(th3.join()); - }); -} - -#[test] -fn cancel_parent_and_child() { - loom::model(|| { - let token1 = CancellationToken::new(); - let token2 = token1.clone(); - let child_token = token1.child_token(); - - let th1 = thread::spawn(move || { - drop(token1); - }); - - let th2 = thread::spawn(move || { - token2.cancel(); - }); - - let th3 = thread::spawn(move || { - child_token.cancel(); - }); - - assert_ok!(th1.join()); - assert_ok!(th2.join()); - assert_ok!(th3.join()); - }); -} diff --git a/src/sync/tests/loom_mpsc.rs b/src/sync/tests/loom_mpsc.rs index 6a1a6ab..c12313b 100644 --- a/src/sync/tests/loom_mpsc.rs +++ b/src/sync/tests/loom_mpsc.rs @@ -2,22 +2,24 @@ use crate::sync::mpsc; use futures::future::poll_fn; use loom::future::block_on; +use loom::sync::Arc; use loom::thread; +use tokio_test::assert_ok; #[test] fn closing_tx() { loom::model(|| { - let (mut tx, mut rx) = mpsc::channel(16); + let (tx, mut rx) = mpsc::channel(16); thread::spawn(move || { tx.try_send(()).unwrap(); drop(tx); }); - let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + let v = block_on(rx.recv()); assert!(v.is_some()); - let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + let v = block_on(rx.recv()); assert!(v.is_none()); }); } @@ -32,15 +34,70 @@ fn closing_unbounded_tx() { drop(tx); }); - let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + let v = block_on(rx.recv()); assert!(v.is_some()); - let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + let v = block_on(rx.recv()); assert!(v.is_none()); }); } #[test] +fn closing_bounded_rx() { + loom::model(|| { + let (tx1, rx) = mpsc::channel::<()>(16); + let tx2 = tx1.clone(); + thread::spawn(move || { + drop(rx); + }); + + block_on(tx1.closed()); + block_on(tx2.closed()); + }); +} + +#[test] +fn closing_and_sending() { + loom::model(|| { + let (tx1, mut rx) = mpsc::channel::<()>(16); + let tx1 = Arc::new(tx1); + let tx2 = tx1.clone(); + + let th1 = thread::spawn(move || { + tx1.try_send(()).unwrap(); + }); + + let th2 = thread::spawn(move || { + block_on(tx2.closed()); + }); + + let th3 = thread::spawn(move || { + let v = block_on(rx.recv()); + assert!(v.is_some()); + drop(rx); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn closing_unbounded_rx() { + loom::model(|| { + let (tx1, rx) = mpsc::unbounded_channel::<()>(); + let tx2 = tx1.clone(); + thread::spawn(move || { + drop(rx); + }); + + block_on(tx1.closed()); + block_on(tx2.closed()); + }); +} + +#[test] fn dropping_tx() { loom::model(|| { let (tx, mut rx) = mpsc::channel::<()>(16); @@ -53,7 +110,7 @@ fn dropping_tx() { } drop(tx); - let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + let v = block_on(rx.recv()); assert!(v.is_none()); }); } @@ -71,7 +128,7 @@ fn dropping_unbounded_tx() { } drop(tx); - let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + let v = block_on(rx.recv()); assert!(v.is_none()); }); } diff --git a/src/sync/tests/loom_notify.rs b/src/sync/tests/loom_notify.rs index 60981d4..79a5bf8 100644 --- a/src/sync/tests/loom_notify.rs +++ b/src/sync/tests/loom_notify.rs @@ -16,7 +16,7 @@ fn notify_one() { }); }); - tx.notify(); + tx.notify_one(); th.join().unwrap(); }); } @@ -34,12 +34,12 @@ fn notify_multi() { ths.push(thread::spawn(move || { block_on(async { notify.notified().await; - notify.notify(); + notify.notify_one(); }) })); } - notify.notify(); + notify.notify_one(); for th in ths.drain(..) { th.join().unwrap(); @@ -67,7 +67,7 @@ fn notify_drop() { block_on(poll_fn(|cx| { if recv.as_mut().poll(cx).is_ready() { - rx1.notify(); + rx1.notify_one(); } Poll::Ready(()) })); @@ -77,12 +77,12 @@ fn notify_drop() { block_on(async { rx2.notified().await; // Trigger second notification - rx2.notify(); + rx2.notify_one(); rx2.notified().await; }); }); - notify.notify(); + notify.notify_one(); th1.join().unwrap(); th2.join().unwrap(); diff --git a/src/sync/tests/loom_oneshot.rs b/src/sync/tests/loom_oneshot.rs index dfa7459..9729cfb 100644 --- a/src/sync/tests/loom_oneshot.rs +++ b/src/sync/tests/loom_oneshot.rs @@ -75,8 +75,10 @@ impl Future for OnClose<'_> { type Output = bool; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<bool> { - let res = self.get_mut().tx.poll_closed(cx); - Ready(res.is_ready()) + let fut = self.get_mut().tx.closed(); + crate::pin!(fut); + + Ready(fut.poll(cx).is_ready()) } } diff --git a/src/sync/tests/loom_semaphore_ll.rs b/src/sync/tests/loom_semaphore_ll.rs deleted file mode 100644 index b5e5efb..0000000 --- a/src/sync/tests/loom_semaphore_ll.rs +++ /dev/null @@ -1,192 +0,0 @@ -use crate::sync::semaphore_ll::*; - -use futures::future::poll_fn; -use loom::future::block_on; -use loom::thread; -use std::future::Future; -use std::pin::Pin; -use std::sync::atomic::AtomicUsize; -use std::sync::atomic::Ordering::SeqCst; -use std::sync::Arc; -use std::task::Poll::Ready; -use std::task::{Context, Poll}; - -#[test] -fn basic_usage() { - const NUM: usize = 2; - - struct Actor { - waiter: Permit, - shared: Arc<Shared>, - } - - struct Shared { - semaphore: Semaphore, - active: AtomicUsize, - } - - impl Future for Actor { - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - let me = &mut *self; - - ready!(me.waiter.poll_acquire(cx, 1, &me.shared.semaphore)).unwrap(); - - let actual = me.shared.active.fetch_add(1, SeqCst); - assert!(actual <= NUM - 1); - - let actual = me.shared.active.fetch_sub(1, SeqCst); - assert!(actual <= NUM); - - me.waiter.release(1, &me.shared.semaphore); - - Ready(()) - } - } - - loom::model(|| { - let shared = Arc::new(Shared { - semaphore: Semaphore::new(NUM), - active: AtomicUsize::new(0), - }); - - for _ in 0..NUM { - let shared = shared.clone(); - - thread::spawn(move || { - block_on(Actor { - waiter: Permit::new(), - shared, - }); - }); - } - - block_on(Actor { - waiter: Permit::new(), - shared, - }); - }); -} - -#[test] -fn release() { - loom::model(|| { - let semaphore = Arc::new(Semaphore::new(1)); - - { - let semaphore = semaphore.clone(); - thread::spawn(move || { - let mut permit = Permit::new(); - - block_on(poll_fn(|cx| permit.poll_acquire(cx, 1, &semaphore))).unwrap(); - - permit.release(1, &semaphore); - }); - } - - let mut permit = Permit::new(); - - block_on(poll_fn(|cx| permit.poll_acquire(cx, 1, &semaphore))).unwrap(); - - permit.release(1, &semaphore); - }); -} - -#[test] -fn basic_closing() { - const NUM: usize = 2; - - loom::model(|| { - let semaphore = Arc::new(Semaphore::new(1)); - - for _ in 0..NUM { - let semaphore = semaphore.clone(); - - thread::spawn(move || { - let mut permit = Permit::new(); - - for _ in 0..2 { - block_on(poll_fn(|cx| { - permit.poll_acquire(cx, 1, &semaphore).map_err(|_| ()) - }))?; - - permit.release(1, &semaphore); - } - - Ok::<(), ()>(()) - }); - } - - semaphore.close(); - }); -} - -#[test] -fn concurrent_close() { - const NUM: usize = 3; - - loom::model(|| { - let semaphore = Arc::new(Semaphore::new(1)); - - for _ in 0..NUM { - let semaphore = semaphore.clone(); - - thread::spawn(move || { - let mut permit = Permit::new(); - - block_on(poll_fn(|cx| { - permit.poll_acquire(cx, 1, &semaphore).map_err(|_| ()) - }))?; - - permit.release(1, &semaphore); - - semaphore.close(); - - Ok::<(), ()>(()) - }); - } - }); -} - -#[test] -fn batch() { - let mut b = loom::model::Builder::new(); - b.preemption_bound = Some(1); - - b.check(|| { - let semaphore = Arc::new(Semaphore::new(10)); - let active = Arc::new(AtomicUsize::new(0)); - let mut ths = vec![]; - - for _ in 0..2 { - let semaphore = semaphore.clone(); - let active = active.clone(); - - ths.push(thread::spawn(move || { - let mut permit = Permit::new(); - - for n in &[4, 10, 8] { - block_on(poll_fn(|cx| permit.poll_acquire(cx, *n, &semaphore))).unwrap(); - - active.fetch_add(*n as usize, SeqCst); - - let num_active = active.load(SeqCst); - assert!(num_active <= 10); - - thread::yield_now(); - - active.fetch_sub(*n as usize, SeqCst); - - permit.release(*n, &semaphore); - } - })); - } - - for th in ths.into_iter() { - th.join().unwrap(); - } - - assert_eq!(10, semaphore.available_permits()); - }); -} diff --git a/src/sync/tests/loom_watch.rs b/src/sync/tests/loom_watch.rs new file mode 100644 index 0000000..c575b5b --- /dev/null +++ b/src/sync/tests/loom_watch.rs @@ -0,0 +1,36 @@ +use crate::sync::watch; + +use loom::future::block_on; +use loom::thread; + +#[test] +fn smoke() { + loom::model(|| { + let (tx, mut rx1) = watch::channel(1); + let mut rx2 = rx1.clone(); + let mut rx3 = rx1.clone(); + let mut rx4 = rx1.clone(); + let mut rx5 = rx1.clone(); + + let th = thread::spawn(move || { + tx.send(2).unwrap(); + }); + + block_on(rx1.changed()).unwrap(); + assert_eq!(*rx1.borrow(), 2); + + block_on(rx2.changed()).unwrap(); + assert_eq!(*rx2.borrow(), 2); + + block_on(rx3.changed()).unwrap(); + assert_eq!(*rx3.borrow(), 2); + + block_on(rx4.changed()).unwrap(); + assert_eq!(*rx4.borrow(), 2); + + block_on(rx5.changed()).unwrap(); + assert_eq!(*rx5.borrow(), 2); + + th.join().unwrap(); + }) +} diff --git a/src/sync/tests/mod.rs b/src/sync/tests/mod.rs index 6ba8c1f..a78be6f 100644 --- a/src/sync/tests/mod.rs +++ b/src/sync/tests/mod.rs @@ -1,18 +1,15 @@ cfg_not_loom! { mod atomic_waker; - mod semaphore_ll; mod semaphore_batch; } cfg_loom! { mod loom_atomic_waker; mod loom_broadcast; - #[cfg(tokio_unstable)] - mod loom_cancellation_token; mod loom_list; mod loom_mpsc; mod loom_notify; mod loom_oneshot; mod loom_semaphore_batch; - mod loom_semaphore_ll; + mod loom_watch; } diff --git a/src/sync/tests/semaphore_ll.rs b/src/sync/tests/semaphore_ll.rs deleted file mode 100644 index bfb0757..0000000 --- a/src/sync/tests/semaphore_ll.rs +++ /dev/null @@ -1,470 +0,0 @@ -use crate::sync::semaphore_ll::{Permit, Semaphore}; -use tokio_test::*; - -#[test] -fn poll_acquire_one_available() { - let s = Semaphore::new(100); - assert_eq!(s.available_permits(), 100); - - // Polling for a permit succeeds immediately - let mut permit = task::spawn(Permit::new()); - assert!(!permit.is_acquired()); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_eq!(s.available_permits(), 99); - assert!(permit.is_acquired()); - - // Polling again on the same waiter does not claim a new permit - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_eq!(s.available_permits(), 99); - assert!(permit.is_acquired()); -} - -#[test] -fn poll_acquire_many_available() { - let s = Semaphore::new(100); - assert_eq!(s.available_permits(), 100); - - // Polling for a permit succeeds immediately - let mut permit = task::spawn(Permit::new()); - assert!(!permit.is_acquired()); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 5, &s))); - assert_eq!(s.available_permits(), 95); - assert!(permit.is_acquired()); - - // Polling again on the same waiter does not claim a new permit - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_eq!(s.available_permits(), 95); - assert!(permit.is_acquired()); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 5, &s))); - assert_eq!(s.available_permits(), 95); - assert!(permit.is_acquired()); - - // Polling for a larger number of permits acquires more - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 8, &s))); - assert_eq!(s.available_permits(), 92); - assert!(permit.is_acquired()); -} - -#[test] -fn try_acquire_one_available() { - let s = Semaphore::new(100); - assert_eq!(s.available_permits(), 100); - - // Polling for a permit succeeds immediately - let mut permit = Permit::new(); - assert!(!permit.is_acquired()); - - assert_ok!(permit.try_acquire(1, &s)); - assert_eq!(s.available_permits(), 99); - assert!(permit.is_acquired()); - - // Polling again on the same waiter does not claim a new permit - assert_ok!(permit.try_acquire(1, &s)); - assert_eq!(s.available_permits(), 99); - assert!(permit.is_acquired()); -} - -#[test] -fn try_acquire_many_available() { - let s = Semaphore::new(100); - assert_eq!(s.available_permits(), 100); - - // Polling for a permit succeeds immediately - let mut permit = Permit::new(); - assert!(!permit.is_acquired()); - - assert_ok!(permit.try_acquire(5, &s)); - assert_eq!(s.available_permits(), 95); - assert!(permit.is_acquired()); - - // Polling again on the same waiter does not claim a new permit - assert_ok!(permit.try_acquire(5, &s)); - assert_eq!(s.available_permits(), 95); - assert!(permit.is_acquired()); -} - -#[test] -fn poll_acquire_one_unavailable() { - let s = Semaphore::new(1); - - let mut permit_1 = task::spawn(Permit::new()); - let mut permit_2 = task::spawn(Permit::new()); - - // Acquire the first permit - assert_ready_ok!(permit_1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_eq!(s.available_permits(), 0); - - permit_2.enter(|cx, mut p| { - // Try to acquire the second permit - assert_pending!(p.poll_acquire(cx, 1, &s)); - }); - - permit_1.release(1, &s); - - assert_eq!(s.available_permits(), 0); - assert!(permit_2.is_woken()); - assert_ready_ok!(permit_2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - permit_2.release(1, &s); - assert_eq!(s.available_permits(), 1); -} - -#[test] -fn forget_acquired() { - let s = Semaphore::new(1); - - // Polling for a permit succeeds immediately - let mut permit = task::spawn(Permit::new()); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - assert_eq!(s.available_permits(), 0); - - permit.forget(1); - assert_eq!(s.available_permits(), 0); -} - -#[test] -fn forget_waiting() { - let s = Semaphore::new(0); - - // Polling for a permit succeeds immediately - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - assert_eq!(s.available_permits(), 0); - - permit.forget(1); - - s.add_permits(1); - - assert!(!permit.is_woken()); - assert_eq!(s.available_permits(), 1); -} - -#[test] -fn poll_acquire_many_unavailable() { - let s = Semaphore::new(5); - - let mut permit_1 = task::spawn(Permit::new()); - let mut permit_2 = task::spawn(Permit::new()); - let mut permit_3 = task::spawn(Permit::new()); - - // Acquire the first permit - assert_ready_ok!(permit_1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_eq!(s.available_permits(), 4); - - permit_2.enter(|cx, mut p| { - // Try to acquire the second permit - assert_pending!(p.poll_acquire(cx, 5, &s)); - }); - - assert_eq!(s.available_permits(), 0); - - permit_3.enter(|cx, mut p| { - // Try to acquire the third permit - assert_pending!(p.poll_acquire(cx, 3, &s)); - }); - - permit_1.release(1, &s); - - assert_eq!(s.available_permits(), 0); - assert!(permit_2.is_woken()); - assert_ready_ok!(permit_2.enter(|cx, mut p| p.poll_acquire(cx, 5, &s))); - - assert!(!permit_3.is_woken()); - assert_eq!(s.available_permits(), 0); - - permit_2.release(1, &s); - assert!(!permit_3.is_woken()); - assert_eq!(s.available_permits(), 0); - - permit_2.release(2, &s); - assert!(permit_3.is_woken()); - - assert_ready_ok!(permit_3.enter(|cx, mut p| p.poll_acquire(cx, 3, &s))); -} - -#[test] -fn try_acquire_one_unavailable() { - let s = Semaphore::new(1); - - let mut permit_1 = Permit::new(); - let mut permit_2 = Permit::new(); - - // Acquire the first permit - assert_ok!(permit_1.try_acquire(1, &s)); - assert_eq!(s.available_permits(), 0); - - assert_err!(permit_2.try_acquire(1, &s)); - - permit_1.release(1, &s); - - assert_eq!(s.available_permits(), 1); - assert_ok!(permit_2.try_acquire(1, &s)); - - permit_2.release(1, &s); - assert_eq!(s.available_permits(), 1); -} - -#[test] -fn try_acquire_many_unavailable() { - let s = Semaphore::new(5); - - let mut permit_1 = Permit::new(); - let mut permit_2 = Permit::new(); - - // Acquire the first permit - assert_ok!(permit_1.try_acquire(1, &s)); - assert_eq!(s.available_permits(), 4); - - assert_err!(permit_2.try_acquire(5, &s)); - - permit_1.release(1, &s); - assert_eq!(s.available_permits(), 5); - - assert_ok!(permit_2.try_acquire(5, &s)); - - permit_2.release(1, &s); - assert_eq!(s.available_permits(), 1); - - permit_2.release(1, &s); - assert_eq!(s.available_permits(), 2); -} - -#[test] -fn poll_acquire_one_zero_permits() { - let s = Semaphore::new(0); - assert_eq!(s.available_permits(), 0); - - let mut permit = task::spawn(Permit::new()); - - // Try to acquire the permit - permit.enter(|cx, mut p| { - assert_pending!(p.poll_acquire(cx, 1, &s)); - }); - - s.add_permits(1); - - assert!(permit.is_woken()); - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); -} - -#[test] -#[should_panic] -fn validates_max_permits() { - use std::usize; - Semaphore::new((usize::MAX >> 2) + 1); -} - -#[test] -fn close_semaphore_prevents_acquire() { - let s = Semaphore::new(5); - s.close(); - - assert_eq!(5, s.available_permits()); - - let mut permit_1 = task::spawn(Permit::new()); - let mut permit_2 = task::spawn(Permit::new()); - - assert_ready_err!(permit_1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_eq!(5, s.available_permits()); - - assert_ready_err!(permit_2.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - assert_eq!(5, s.available_permits()); -} - -#[test] -fn close_semaphore_notifies_permit1() { - let s = Semaphore::new(0); - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - s.close(); - - assert!(permit.is_woken()); - assert_ready_err!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); -} - -#[test] -fn close_semaphore_notifies_permit2() { - let s = Semaphore::new(2); - - let mut permit1 = task::spawn(Permit::new()); - let mut permit2 = task::spawn(Permit::new()); - let mut permit3 = task::spawn(Permit::new()); - let mut permit4 = task::spawn(Permit::new()); - - // Acquire a couple of permits - assert_ready_ok!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_ready_ok!(permit2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - assert_pending!(permit3.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_pending!(permit4.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - s.close(); - - assert!(permit3.is_woken()); - assert!(permit4.is_woken()); - - assert_ready_err!(permit3.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_ready_err!(permit4.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - assert_eq!(0, s.available_permits()); - - permit1.release(1, &s); - - assert_eq!(1, s.available_permits()); - - assert_ready_err!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - permit2.release(1, &s); - - assert_eq!(2, s.available_permits()); -} - -#[test] -fn poll_acquire_additional_permits_while_waiting_before_assigned() { - let s = Semaphore::new(1); - - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 3, &s))); - - s.add_permits(1); - assert!(!permit.is_woken()); - - s.add_permits(1); - assert!(permit.is_woken()); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 3, &s))); -} - -#[test] -fn try_acquire_additional_permits_while_waiting_before_assigned() { - let s = Semaphore::new(1); - - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - - assert_err!(permit.enter(|_, mut p| p.try_acquire(3, &s))); - - s.add_permits(1); - assert!(permit.is_woken()); - - assert_ok!(permit.enter(|_, mut p| p.try_acquire(2, &s))); -} - -#[test] -fn poll_acquire_additional_permits_while_waiting_after_assigned_success() { - let s = Semaphore::new(1); - - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - - s.add_permits(2); - - assert!(permit.is_woken()); - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 3, &s))); -} - -#[test] -fn poll_acquire_additional_permits_while_waiting_after_assigned_requeue() { - let s = Semaphore::new(1); - - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - - s.add_permits(2); - - assert!(permit.is_woken()); - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 4, &s))); - - s.add_permits(1); - - assert!(permit.is_woken()); - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 4, &s))); -} - -#[test] -fn poll_acquire_fewer_permits_while_waiting() { - let s = Semaphore::new(1); - - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - assert_eq!(s.available_permits(), 0); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_eq!(s.available_permits(), 0); -} - -#[test] -fn poll_acquire_fewer_permits_after_assigned() { - let s = Semaphore::new(1); - - let mut permit1 = task::spawn(Permit::new()); - let mut permit2 = task::spawn(Permit::new()); - - assert_pending!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 5, &s))); - assert_eq!(s.available_permits(), 0); - - assert_pending!(permit2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - s.add_permits(4); - assert!(permit1.is_woken()); - assert!(!permit2.is_woken()); - - assert_ready_ok!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 3, &s))); - - assert!(permit2.is_woken()); - assert_eq!(s.available_permits(), 1); - - assert_ready_ok!(permit2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); -} - -#[test] -fn forget_partial_1() { - let s = Semaphore::new(0); - - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - s.add_permits(1); - - assert_eq!(0, s.available_permits()); - - permit.release(1, &s); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - assert_eq!(s.available_permits(), 0); -} - -#[test] -fn forget_partial_2() { - let s = Semaphore::new(0); - - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - s.add_permits(1); - - assert_eq!(0, s.available_permits()); - - permit.release(1, &s); - - s.add_permits(1); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - assert_eq!(s.available_permits(), 0); -} diff --git a/src/sync/watch.rs b/src/sync/watch.rs index 13033d9..ec73832 100644 --- a/src/sync/watch.rs +++ b/src/sync/watch.rs @@ -6,13 +6,11 @@ //! //! # Usage //! -//! [`channel`] returns a [`Sender`] / [`Receiver`] pair. These are -//! the producer and sender halves of the channel. The channel is -//! created with an initial value. [`Receiver::recv`] will always -//! be ready upon creation and will yield either this initial value or -//! the latest value that has been sent by `Sender`. -//! -//! Calls to [`Receiver::recv`] will always yield the latest value. +//! [`channel`] returns a [`Sender`] / [`Receiver`] pair. These are the producer +//! and sender halves of the channel. The channel is created with an initial +//! value. The **latest** value stored in the channel is accessed with +//! [`Receiver::borrow()`]. Awaiting [`Receiver::changed()`] waits for a new +//! value to sent by the [`Sender`] half. //! //! # Examples //! @@ -23,21 +21,21 @@ //! let (tx, mut rx) = watch::channel("hello"); //! //! tokio::spawn(async move { -//! while let Some(value) = rx.recv().await { -//! println!("received = {:?}", value); +//! while rx.changed().await.is_ok() { +//! println!("received = {:?}", *rx.borrow()); //! } //! }); //! -//! tx.broadcast("world")?; +//! tx.send("world")?; //! # Ok(()) //! # } //! ``` //! //! # Closing //! -//! [`Sender::closed`] allows the producer to detect when all [`Receiver`] -//! handles have been dropped. This indicates that there is no further interest -//! in the values being produced and work can be stopped. +//! [`Sender::is_closed`] and [`Sender::closed`] allow the producer to detect +//! when all [`Receiver`] handles have been dropped. This indicates that there +//! is no further interest in the values being produced and work can be stopped. //! //! # Thread safety //! @@ -47,20 +45,18 @@ //! //! [`Sender`]: crate::sync::watch::Sender //! [`Receiver`]: crate::sync::watch::Receiver -//! [`Receiver::recv`]: crate::sync::watch::Receiver::recv +//! [`Receiver::changed()`]: crate::sync::watch::Receiver::changed +//! [`Receiver::borrow()`]: crate::sync::watch::Receiver::borrow //! [`channel`]: crate::sync::watch::channel +//! [`Sender::is_closed`]: crate::sync::watch::Sender::is_closed //! [`Sender::closed`]: crate::sync::watch::Sender::closed -use crate::future::poll_fn; -use crate::sync::task::AtomicWaker; +use crate::sync::Notify; -use fnv::FnvHashSet; use std::ops; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering::{Relaxed, SeqCst}; -use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard, Weak}; -use std::task::Poll::{Pending, Ready}; -use std::task::{Context, Poll}; +use std::sync::{Arc, RwLock, RwLockReadGuard}; /// Receives values from the associated [`Sender`](struct@Sender). /// @@ -70,8 +66,8 @@ pub struct Receiver<T> { /// Pointer to the shared state shared: Arc<Shared<T>>, - /// Pointer to the watcher's internal state - inner: Watcher, + /// Last observed version + version: usize, } /// Sends values to the associated [`Receiver`](struct@Receiver). @@ -79,7 +75,7 @@ pub struct Receiver<T> { /// Instances are created by the [`channel`](fn@channel) function. #[derive(Debug)] pub struct Sender<T> { - shared: Weak<Shared<T>>, + shared: Arc<Shared<T>>, } /// Returns a reference to the inner value @@ -92,6 +88,27 @@ pub struct Ref<'a, T> { inner: RwLockReadGuard<'a, T>, } +#[derive(Debug)] +struct Shared<T> { + /// The most recent value + value: RwLock<T>, + + /// The current version + /// + /// The lowest bit represents a "closed" state. The rest of the bits + /// represent the current version. + version: AtomicUsize, + + /// Tracks the number of `Receiver` instances + ref_count_rx: AtomicUsize, + + /// Notifies waiting receivers that the value changed. + notify_rx: Notify, + + /// Notifies any task listening for `Receiver` dropped events + notify_tx: Notify, +} + pub mod error { //! Watch error types @@ -112,37 +129,20 @@ pub mod error { } impl<T: fmt::Debug> std::error::Error for SendError<T> {} -} - -#[derive(Debug)] -struct Shared<T> { - /// The most recent value - value: RwLock<T>, - /// The current version - /// - /// The lowest bit represents a "closed" state. The rest of the bits - /// represent the current version. - version: AtomicUsize, - - /// All watchers - watchers: Mutex<Watchers>, - - /// Task to notify when all watchers drop - cancel: AtomicWaker, -} + /// Error produced when receiving a change notification. + #[derive(Debug)] + pub struct RecvError(pub(super) ()); -type Watchers = FnvHashSet<Watcher>; + // ===== impl RecvError ===== -/// The watcher's ID is based on the Arc's pointer. -#[derive(Clone, Debug)] -struct Watcher(Arc<WatchInner>); + impl fmt::Display for RecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } + } -#[derive(Debug)] -struct WatchInner { - /// Last observed version - version: AtomicUsize, - waker: AtomicWaker, + impl std::error::Error for RecvError {} } const CLOSED: usize = 1; @@ -162,41 +162,32 @@ const CLOSED: usize = 1; /// let (tx, mut rx) = watch::channel("hello"); /// /// tokio::spawn(async move { -/// while let Some(value) = rx.recv().await { -/// println!("received = {:?}", value); +/// while rx.changed().await.is_ok() { +/// println!("received = {:?}", *rx.borrow()); /// } /// }); /// -/// tx.broadcast("world")?; +/// tx.send("world")?; /// # Ok(()) /// # } /// ``` /// /// [`Sender`]: struct@Sender /// [`Receiver`]: struct@Receiver -pub fn channel<T: Clone>(init: T) -> (Sender<T>, Receiver<T>) { - const VERSION_0: usize = 0; - const VERSION_1: usize = 2; - - // We don't start knowing VERSION_1 - let inner = Watcher::new_version(VERSION_0); - - // Insert the watcher - let mut watchers = FnvHashSet::with_capacity_and_hasher(0, Default::default()); - watchers.insert(inner.clone()); - +pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) { let shared = Arc::new(Shared { value: RwLock::new(init), - version: AtomicUsize::new(VERSION_1), - watchers: Mutex::new(watchers), - cancel: AtomicWaker::new(), + version: AtomicUsize::new(0), + ref_count_rx: AtomicUsize::new(1), + notify_rx: Notify::new(), + notify_tx: Notify::new(), }); let tx = Sender { - shared: Arc::downgrade(&shared), + shared: shared.clone(), }; - let rx = Receiver { shared, inner }; + let rx = Receiver { shared, version: 0 }; (tx, rx) } @@ -221,39 +212,13 @@ impl<T> Receiver<T> { Ref { inner } } - // TODO: document - #[doc(hidden)] - pub fn poll_recv_ref<'a>(&'a mut self, cx: &mut Context<'_>) -> Poll<Option<Ref<'a, T>>> { - // Make sure the task is up to date - self.inner.waker.register_by_ref(cx.waker()); - - let state = self.shared.version.load(SeqCst); - let version = state & !CLOSED; - - if self.inner.version.swap(version, Relaxed) != version { - let inner = self.shared.value.read().unwrap(); - - return Ready(Some(Ref { inner })); - } - - if CLOSED == state & CLOSED { - // The `Store` handle has been dropped. - return Ready(None); - } - - Pending - } -} - -impl<T: Clone> Receiver<T> { - /// Attempts to clone the latest value sent via the channel. + /// Wait for a change notification /// - /// If this is the first time the function is called on a `Receiver` - /// instance, then the function completes immediately with the **current** - /// value held by the channel. On the next call, the function waits until - /// a new value is sent in the channel. + /// Returns when a new value has been sent by the [`Sender`] since the last + /// time `changed()` was called. When the `Sender` half is dropped, `Err` is + /// returned. /// - /// `None` is returned if the `Sender` half is dropped. + /// [`Sender`]: struct@Sender /// /// # Examples /// @@ -264,118 +229,170 @@ impl<T: Clone> Receiver<T> { /// async fn main() { /// let (tx, mut rx) = watch::channel("hello"); /// - /// let v = rx.recv().await.unwrap(); - /// assert_eq!(v, "hello"); - /// /// tokio::spawn(async move { - /// tx.broadcast("goodbye").unwrap(); + /// tx.send("goodbye").unwrap(); /// }); /// - /// // Waits for the new task to spawn and send the value. - /// let v = rx.recv().await.unwrap(); - /// assert_eq!(v, "goodbye"); + /// assert!(rx.changed().await.is_ok()); + /// assert_eq!(*rx.borrow(), "goodbye"); /// - /// let v = rx.recv().await; - /// assert!(v.is_none()); + /// // The `tx` handle has been dropped + /// assert!(rx.changed().await.is_err()); /// } /// ``` - pub async fn recv(&mut self) -> Option<T> { - poll_fn(|cx| { - let v_ref = ready!(self.poll_recv_ref(cx)); - Poll::Ready(v_ref.map(|v_ref| (*v_ref).clone())) + pub async fn changed(&mut self) -> Result<(), error::RecvError> { + use std::future::Future; + use std::pin::Pin; + use std::task::Poll; + + // In order to avoid a race condition, we first request a notification, + // **then** check the current value's version. If a new version exists, + // the notification request is dropped. Requesting the notification + // requires polling the future once. + let notified = self.shared.notify_rx.notified(); + pin!(notified); + + // Polling the future once is guaranteed to return `Pending` as `watch` + // only notifies using `notify_waiters`. + crate::future::poll_fn(|cx| { + let res = Pin::new(&mut notified).poll(cx); + assert!(!res.is_ready()); + Poll::Ready(()) }) - .await + .await; + + if let Some(ret) = maybe_changed(&self.shared, &mut self.version) { + return ret; + } + + notified.await; + + maybe_changed(&self.shared, &mut self.version) + .expect("[bug] failed to observe change after notificaton.") } } -#[cfg(feature = "stream")] -impl<T: Clone> crate::stream::Stream for Receiver<T> { - type Item = T; - - fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> { - let v_ref = ready!(self.poll_recv_ref(cx)); +fn maybe_changed<T>( + shared: &Shared<T>, + version: &mut usize, +) -> Option<Result<(), error::RecvError>> { + // Load the version from the state + let state = shared.version.load(SeqCst); + let new_version = state & !CLOSED; + + if *version != new_version { + // Observe the new version and return + *version = new_version; + return Some(Ok(())); + } - Poll::Ready(v_ref.map(|v_ref| (*v_ref).clone())) + if CLOSED == state & CLOSED { + // All receivers have dropped. + return Some(Err(error::RecvError(()))); } + + None } impl<T> Clone for Receiver<T> { fn clone(&self) -> Self { - let ver = self.inner.version.load(Relaxed); - let inner = Watcher::new_version(ver); + let version = self.version; let shared = self.shared.clone(); - shared.watchers.lock().unwrap().insert(inner.clone()); + // No synchronization necessary as this is only used as a counter and + // not memory access. + shared.ref_count_rx.fetch_add(1, Relaxed); - Receiver { shared, inner } + Receiver { version, shared } } } impl<T> Drop for Receiver<T> { fn drop(&mut self) { - self.shared.watchers.lock().unwrap().remove(&self.inner); + // No synchronization necessary as this is only used as a counter and + // not memory access. + if 1 == self.shared.ref_count_rx.fetch_sub(1, Relaxed) { + // This is the last `Receiver` handle, tasks waiting on `Sender::closed()` + self.shared.notify_tx.notify_waiters(); + } } } impl<T> Sender<T> { - /// Broadcasts a new value via the channel, notifying all receivers. - pub fn broadcast(&self, value: T) -> Result<(), error::SendError<T>> { - let shared = match self.shared.upgrade() { - Some(shared) => shared, - // All `Watch` handles have been canceled - None => return Err(error::SendError { inner: value }), - }; - - // Replace the value - { - let mut lock = shared.value.write().unwrap(); - *lock = value; + /// Sends a new value via the channel, notifying all receivers. + pub fn send(&self, value: T) -> Result<(), error::SendError<T>> { + // This is pretty much only useful as a hint anyway, so synchronization isn't critical. + if 0 == self.shared.ref_count_rx.load(Relaxed) { + return Err(error::SendError { inner: value }); } + *self.shared.value.write().unwrap() = value; + // Update the version. 2 is used so that the CLOSED bit is not set. - shared.version.fetch_add(2, SeqCst); + self.shared.version.fetch_add(2, SeqCst); // Notify all watchers - notify_all(&*shared); + self.shared.notify_rx.notify_waiters(); Ok(()) } + /// Checks if the channel has been closed. This happens when all receivers + /// have dropped. + /// + /// # Examples + /// + /// ``` + /// let (tx, rx) = tokio::sync::watch::channel(()); + /// assert!(!tx.is_closed()); + /// + /// drop(rx); + /// assert!(tx.is_closed()); + /// ``` + pub fn is_closed(&self) -> bool { + self.shared.ref_count_rx.load(Relaxed) == 0 + } + /// Completes when all receivers have dropped. /// /// This allows the producer to get notified when interest in the produced /// values is canceled and immediately stop doing work. - pub async fn closed(&mut self) { - poll_fn(|cx| self.poll_close(cx)).await - } + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = watch::channel("hello"); + /// + /// tokio::spawn(async move { + /// // use `rx` + /// drop(rx); + /// }); + /// + /// // Waits for `rx` to drop + /// tx.closed().await; + /// println!("the `rx` handles dropped") + /// } + /// ``` + pub async fn closed(&self) { + let notified = self.shared.notify_tx.notified(); - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<()> { - match self.shared.upgrade() { - Some(shared) => { - shared.cancel.register_by_ref(cx.waker()); - Pending - } - None => Ready(()), + if self.shared.ref_count_rx.load(Relaxed) == 0 { + return; } - } -} - -/// Notifies all watchers of a change -fn notify_all<T>(shared: &Shared<T>) { - let watchers = shared.watchers.lock().unwrap(); - for watcher in watchers.iter() { - // Notify the task - watcher.waker.wake(); + notified.await; + debug_assert_eq!(0, self.shared.ref_count_rx.load(Relaxed)); } } impl<T> Drop for Sender<T> { fn drop(&mut self) { - if let Some(shared) = self.shared.upgrade() { - shared.version.fetch_or(CLOSED, SeqCst); - notify_all(&*shared); - } + self.shared.version.fetch_or(CLOSED, SeqCst); + self.shared.notify_rx.notify_waiters(); } } @@ -388,44 +405,3 @@ impl<T> ops::Deref for Ref<'_, T> { self.inner.deref() } } - -// ===== impl Shared ===== - -impl<T> Drop for Shared<T> { - fn drop(&mut self) { - self.cancel.wake(); - } -} - -// ===== impl Watcher ===== - -impl Watcher { - fn new_version(version: usize) -> Self { - Watcher(Arc::new(WatchInner { - version: AtomicUsize::new(version), - waker: AtomicWaker::new(), - })) - } -} - -impl std::cmp::PartialEq for Watcher { - fn eq(&self, other: &Watcher) -> bool { - Arc::ptr_eq(&self.0, &other.0) - } -} - -impl std::cmp::Eq for Watcher {} - -impl std::hash::Hash for Watcher { - fn hash<H: std::hash::Hasher>(&self, state: &mut H) { - (&*self.0 as *const WatchInner).hash(state) - } -} - -impl std::ops::Deref for Watcher { - type Target = WatchInner; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} |