Diffstat (limited to 'src/sync')
-rw-r--r--  src/sync/barrier.rs                          6
-rw-r--r--  src/sync/batch_semaphore.rs                 70
-rw-r--r--  src/sync/broadcast.rs                      478
-rw-r--r--  src/sync/cancellation_token.rs             861
-rw-r--r--  src/sync/mod.rs                             62
-rw-r--r--  src/sync/mpsc/block.rs                       8
-rw-r--r--  src/sync/mpsc/bounded.rs                   465
-rw-r--r--  src/sync/mpsc/chan.rs                      289
-rw-r--r--  src/sync/mpsc/error.rs                      20
-rw-r--r--  src/sync/mpsc/list.rs                        6
-rw-r--r--  src/sync/mpsc/mod.rs                        39
-rw-r--r--  src/sync/mpsc/unbounded.rs                  97
-rw-r--r--  src/sync/mutex.rs                           62
-rw-r--r--  src/sync/notify.rs                         152
-rw-r--r--  src/sync/oneshot.rs                          4
-rw-r--r--  src/sync/rwlock.rs                         431
-rw-r--r--  src/sync/semaphore.rs                       15
-rw-r--r--  src/sync/semaphore_ll.rs                  1221
-rw-r--r--  src/sync/task/atomic_waker.rs                5
-rw-r--r--  src/sync/tests/loom_broadcast.rs             2
-rw-r--r--  src/sync/tests/loom_cancellation_token.rs  155
-rw-r--r--  src/sync/tests/loom_mpsc.rs                 71
-rw-r--r--  src/sync/tests/loom_notify.rs               12
-rw-r--r--  src/sync/tests/loom_oneshot.rs               6
-rw-r--r--  src/sync/tests/loom_semaphore_ll.rs        192
-rw-r--r--  src/sync/tests/loom_watch.rs                36
-rw-r--r--  src/sync/tests/mod.rs                        5
-rw-r--r--  src/sync/tests/semaphore_ll.rs             470
-rw-r--r--  src/sync/watch.rs                          384
29 files changed, 1684 insertions, 3940 deletions
diff --git a/src/sync/barrier.rs b/src/sync/barrier.rs
index 6286334..fddb3a5 100644
--- a/src/sync/barrier.rs
+++ b/src/sync/barrier.rs
@@ -96,7 +96,7 @@ impl Barrier {
// wake everyone, increment the generation, and return
state
.waker
- .broadcast(state.generation)
+ .send(state.generation)
.expect("there is at least one receiver");
state.arrived = 0;
state.generation += 1;
@@ -110,9 +110,11 @@ impl Barrier {
let mut wait = self.wait.clone();
loop {
+ let _ = wait.changed().await;
+
// note that the first time through the loop, this _will_ yield a generation
// immediately, since we cloned a receiver that has never seen any values.
- if wait.recv().await.expect("sender hasn't been closed") >= generation {
+ if *wait.borrow() >= generation {
break;
}
}
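
The rewritten wait loop relies on the `tokio::sync::watch` API: `changed()` resolves once a new value has been sent, and `borrow()` reads the latest value. A minimal sketch of that pattern, assuming the post-rename watch API (`wait_for_generation` and `generation` are illustrative names, not from the patch):

```rust
use tokio::sync::watch;

async fn wait_for_generation(mut rx: watch::Receiver<usize>, generation: usize) {
    loop {
        // `changed()` errors only if the sender is dropped; the barrier
        // keeps its sender alive, so the result can be ignored here.
        let _ = rx.changed().await;

        // `borrow()` yields the most recently sent value.
        if *rx.borrow() >= generation {
            break;
        }
    }
}
```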
diff --git a/src/sync/batch_semaphore.rs b/src/sync/batch_semaphore.rs
index 070cd20..0b50e4f 100644
--- a/src/sync/batch_semaphore.rs
+++ b/src/sync/batch_semaphore.rs
@@ -1,3 +1,4 @@
+#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))]
//! # Implementation Details
//!
//! The semaphore is implemented using an intrusive linked list of waiters. An
@@ -36,7 +37,7 @@ pub(crate) struct Semaphore {
}
struct Waitlist {
- queue: LinkedList<Waiter>,
+ queue: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
closed: bool,
}
@@ -96,10 +97,13 @@ impl Semaphore {
/// Note that this reserves three bits of flags in the permit counter, but
/// we only actually use one of them. However, the previous semaphore
/// implementation used three bits, so we will continue to reserve them to
- /// avoid a breaking change if additional flags need to be aadded in the
+ /// avoid a breaking change if additional flags need to be added in the
/// future.
pub(crate) const MAX_PERMITS: usize = std::usize::MAX >> 3;
const CLOSED: usize = 1;
+ // The least-significant bit in the number of permits is reserved for use
+ // as a flag indicating that the semaphore has been closed. Consequently,
+ // PERMIT_SHIFT is used to leave that bit for that purpose.
const PERMIT_SHIFT: usize = 1;
/// Creates a new semaphore with the initial number of permits
@@ -120,6 +124,27 @@ impl Semaphore {
}
}
+ /// Creates a new semaphore with the initial number of permits
+ ///
+ /// Maximum number of permits on 32-bit platforms is `1<<29`.
+ ///
+ /// If the specified number of permits exceeds the maximum permit amount,
+ /// then the value will get clamped to the maximum number of permits.
+ #[cfg(all(feature = "parking_lot", not(all(loom, test))))]
+ pub(crate) const fn const_new(mut permits: usize) -> Self {
+ // NOTE: assertions and by extension panics are still being worked on: https://github.com/rust-lang/rust/issues/74925
+ // currently we just clamp the permit count when it exceeds the max
+ permits &= Self::MAX_PERMITS;
+
+ Self {
+ permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT),
+ waiters: Mutex::const_new(Waitlist {
+ queue: LinkedList::new(),
+ closed: false,
+ }),
+ }
+ }
+
/// Returns the current number of available permits
pub(crate) fn available_permits(&self) -> usize {
self.permits.load(Acquire) >> Self::PERMIT_SHIFT
@@ -134,16 +159,15 @@ impl Semaphore {
}
// Assign permits to the wait queue
- self.add_permits_locked(added, self.waiters.lock().unwrap());
+ self.add_permits_locked(added, self.waiters.lock());
}
/// Closes the semaphore. This prevents the semaphore from issuing new
/// permits and notifies all pending waiters.
// This will be used once the bounded MPSC is updated to use the new
// semaphore implementation.
- #[allow(dead_code)]
pub(crate) fn close(&self) {
- let mut waiters = self.waiters.lock().unwrap();
+ let mut waiters = self.waiters.lock();
// If the semaphore's permits counter has enough permits for an
// unqueued waiter to acquire all the permits it needs immediately,
// it won't touch the wait list. Therefore, we have to set a bit on
@@ -161,6 +185,11 @@ impl Semaphore {
}
}
+ /// Returns true if the semaphore is closed
+ pub(crate) fn is_closed(&self) -> bool {
+ self.permits.load(Acquire) & Self::CLOSED == Self::CLOSED
+ }
+
pub(crate) fn try_acquire(&self, num_permits: u32) -> Result<(), TryAcquireError> {
assert!(
num_permits as usize <= Self::MAX_PERMITS,
@@ -170,8 +199,8 @@ impl Semaphore {
let num_permits = (num_permits as usize) << Self::PERMIT_SHIFT;
let mut curr = self.permits.load(Acquire);
loop {
- // Has the semaphore closed?git
- if curr & Self::CLOSED > 0 {
+ // Has the semaphore closed?
+ if curr & Self::CLOSED == Self::CLOSED {
return Err(TryAcquireError::Closed);
}
@@ -203,7 +232,7 @@ impl Semaphore {
let mut lock = Some(waiters);
let mut is_empty = false;
while rem > 0 {
- let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock().unwrap());
+ let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock());
'inner: for slot in &mut wakers[..] {
// Was the waiter assigned enough permits to wake it?
match waiters.queue.last() {
@@ -296,7 +325,7 @@ impl Semaphore {
// counter. Otherwise, if we subtract the permits and then
// acquire the lock, we might miss additional permits being
// added while waiting for the lock.
- lock = Some(self.waiters.lock().unwrap());
+ lock = Some(self.waiters.lock());
}
match self.permits.compare_exchange(curr, next, AcqRel, Acquire) {
@@ -306,7 +335,7 @@ impl Semaphore {
if !queued {
return Ready(Ok(()));
} else if lock.is_none() {
- break self.waiters.lock().unwrap();
+ break self.waiters.lock();
}
}
break lock.expect("lock must be acquired before waiting");
@@ -357,7 +386,7 @@ impl Semaphore {
impl fmt::Debug for Semaphore {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Semaphore")
- .field("permits", &self.permits.load(Relaxed))
+ .field("permits", &self.available_permits())
.finish()
}
}
@@ -456,14 +485,7 @@ impl Drop for Acquire<'_> {
// This is where we ensure safety. The future is being dropped,
// which means we must ensure that the waiter entry is no longer stored
// in the linked list.
- let mut waiters = match self.semaphore.waiters.lock() {
- Ok(lock) => lock,
- // Removing the node from the linked list is necessary to ensure
- // safety. Even if the lock was poisoned, we need to make sure it is
- // removed from the linked list before dropping it --- otherwise,
- // the list will contain a dangling pointer to this node.
- Err(e) => e.into_inner(),
- };
+ let mut waiters = self.semaphore.waiters.lock();
// remove the entry from the list
let node = NonNull::from(&mut self.node);
@@ -506,20 +528,14 @@ impl TryAcquireError {
/// Returns `true` if the error was caused by a closed semaphore.
#[allow(dead_code)] // may be used later!
pub(crate) fn is_closed(&self) -> bool {
- match self {
- TryAcquireError::Closed => true,
- _ => false,
- }
+ matches!(self, TryAcquireError::Closed)
}
/// Returns `true` if the error was caused by calling `try_acquire` on a
/// semaphore with no available permits.
#[allow(dead_code)] // may be used later!
pub(crate) fn is_no_permits(&self) -> bool {
- match self {
- TryAcquireError::NoPermits => true,
- _ => false,
- }
+ matches!(self, TryAcquireError::NoPermits)
}
}
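
Two themes in this file's changes are worth calling out. First, the waiter mutex is now a non-poisoning (`parking_lot`-style) lock, so every `lock().unwrap()` becomes a plain `lock()`. Second, the permit counter packs a closed flag into its low bit: permits are stored shifted left by `PERMIT_SHIFT`, and `CLOSED` marks the semaphore closed without disturbing the count. A standalone sketch of that encoding (mirroring the constants in the patch, not the actual `Semaphore` internals):

```rust
const CLOSED: usize = 1;                    // low bit flags a closed semaphore
const PERMIT_SHIFT: usize = 1;              // permit count lives above the flag bit
const MAX_PERMITS: usize = usize::MAX >> 3; // three bits reserved for flags

fn encode(permits: usize) -> usize {
    assert!(permits <= MAX_PERMITS);
    permits << PERMIT_SHIFT
}

fn available(state: usize) -> usize {
    state >> PERMIT_SHIFT
}

fn is_closed(state: usize) -> bool {
    state & CLOSED == CLOSED
}

fn main() {
    let state = encode(8);
    assert_eq!(available(state), 8);
    assert!(!is_closed(state));
    // Closing sets the flag while leaving the permit count intact.
    assert!(is_closed(state | CLOSED));
    assert_eq!(available(state | CLOSED), 8);
}
```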
diff --git a/src/sync/broadcast.rs b/src/sync/broadcast.rs
index 0c8716f..ee9aba0 100644
--- a/src/sync/broadcast.rs
+++ b/src/sync/broadcast.rs
@@ -21,7 +21,7 @@
//! ## Lagging
//!
//! As sent messages must be retained until **all** [`Receiver`] handles receive
-//! a clone, broadcast channels are suspectible to the "slow receiver" problem.
+//! a clone, broadcast channels are susceptible to the "slow receiver" problem.
//! In this case, all but one receiver are able to receive values at the rate
//! they are sent. Because one receiver is stalled, the channel starts to fill
//! up.
@@ -55,8 +55,8 @@
//! [`Sender::subscribe`]: crate::sync::broadcast::Sender::subscribe
//! [`Receiver`]: crate::sync::broadcast::Receiver
//! [`channel`]: crate::sync::broadcast::channel
-//! [`RecvError::Lagged`]: crate::sync::broadcast::RecvError::Lagged
-//! [`RecvError::Closed`]: crate::sync::broadcast::RecvError::Closed
+//! [`RecvError::Lagged`]: crate::sync::broadcast::error::RecvError::Lagged
+//! [`RecvError::Closed`]: crate::sync::broadcast::error::RecvError::Closed
//! [`recv`]: crate::sync::broadcast::Receiver::recv
//!
//! # Examples
@@ -107,6 +107,7 @@
//! assert_eq!(20, rx.recv().await.unwrap());
//! assert_eq!(30, rx.recv().await.unwrap());
//! }
+//! ```
use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::AtomicUsize;
@@ -194,58 +195,99 @@ pub struct Receiver<T> {
/// Next position to read from
next: u64,
-
- /// Used to support the deprecated `poll_recv` fn
- waiter: Option<Pin<Box<UnsafeCell<Waiter>>>>,
}
-/// Error returned by [`Sender::send`][Sender::send].
-///
-/// A **send** operation can only fail if there are no active receivers,
-/// implying that the message could never be received. The error contains the
-/// message being sent as a payload so it can be recovered.
-#[derive(Debug)]
-pub struct SendError<T>(pub T);
+pub mod error {
+ //! Broadcast error types
-/// An error returned from the [`recv`] function on a [`Receiver`].
-///
-/// [`recv`]: crate::sync::broadcast::Receiver::recv
-/// [`Receiver`]: crate::sync::broadcast::Receiver
-#[derive(Debug, PartialEq)]
-pub enum RecvError {
- /// There are no more active senders implying no further messages will ever
- /// be sent.
- Closed,
+ use std::fmt;
- /// The receiver lagged too far behind. Attempting to receive again will
- /// return the oldest message still retained by the channel.
+ /// Error returned by the [`send`] function on a [`Sender`].
+ ///
- /// Includes the number of skipped messages.
- Lagged(u64),
-}
+ /// A **send** operation can only fail if there are no active receivers,
+ /// implying that the message could never be received. The error contains the
+ /// message being sent as a payload so it can be recovered.
+ ///
+ /// [`send`]: crate::sync::broadcast::Sender::send
+ /// [`Sender`]: crate::sync::broadcast::Sender
+ #[derive(Debug)]
+ pub struct SendError<T>(pub T);
-/// An error returned from the [`try_recv`] function on a [`Receiver`].
-///
-/// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv
-/// [`Receiver`]: crate::sync::broadcast::Receiver
-#[derive(Debug, PartialEq)]
-pub enum TryRecvError {
- /// The channel is currently empty. There are still active
- /// [`Sender`][Sender] handles, so data may yet become available.
- Empty,
-
- /// There are no more active senders implying no further messages will ever
- /// be sent.
- Closed,
-
- /// The receiver lagged too far behind and has been forcibly disconnected.
- /// Attempting to receive again will return the oldest message still
- /// retained by the channel.
- ///
- /// Includes the number of skipped messages.
- Lagged(u64),
+ impl<T> fmt::Display for SendError<T> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "channel closed")
+ }
+ }
+
+ impl<T: fmt::Debug> std::error::Error for SendError<T> {}
+
+ /// An error returned from the [`recv`] function on a [`Receiver`].
+ ///
+ /// [`recv`]: crate::sync::broadcast::Receiver::recv
+ /// [`Receiver`]: crate::sync::broadcast::Receiver
+ #[derive(Debug, PartialEq)]
+ pub enum RecvError {
+ /// There are no more active senders implying no further messages will ever
+ /// be sent.
+ Closed,
+
+ /// The receiver lagged too far behind. Attempting to receive again will
+ /// return the oldest message still retained by the channel.
+ ///
+ /// Includes the number of skipped messages.
+ Lagged(u64),
+ }
+
+ impl fmt::Display for RecvError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ RecvError::Closed => write!(f, "channel closed"),
+ RecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt),
+ }
+ }
+ }
+
+ impl std::error::Error for RecvError {}
+
+ /// An error returned from the [`try_recv`] function on a [`Receiver`].
+ ///
+ /// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv
+ /// [`Receiver`]: crate::sync::broadcast::Receiver
+ #[derive(Debug, PartialEq)]
+ pub enum TryRecvError {
+ /// The channel is currently empty. There are still active
+ /// [`Sender`] handles, so data may yet become available.
+ ///
+ /// [`Sender`]: crate::sync::broadcast::Sender
+ Empty,
+
+ /// There are no more active senders implying no further messages will ever
+ /// be sent.
+ Closed,
+
+ /// The receiver lagged too far behind and has been forcibly disconnected.
+ /// Attempting to receive again will return the oldest message still
+ /// retained by the channel.
+ ///
+ /// Includes the number of skipped messages.
+ Lagged(u64),
+ }
+
+ impl fmt::Display for TryRecvError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ TryRecvError::Empty => write!(f, "channel empty"),
+ TryRecvError::Closed => write!(f, "channel closed"),
+ TryRecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt),
+ }
+ }
+ }
+
+ impl std::error::Error for TryRecvError {}
}
+use self::error::*;
+
/// Data shared between senders and receivers
struct Shared<T> {
/// slots in the channel
@@ -273,7 +315,7 @@ struct Tail {
closed: bool,
/// Receivers waiting for a value
- waiters: LinkedList<Waiter>,
+ waiters: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
}
/// Slot in the buffer
@@ -373,8 +415,8 @@ const MAX_RECEIVERS: usize = usize::MAX >> 2;
/// [`Sender::subscribe`]: crate::sync::broadcast::Sender::subscribe
/// [`Receiver`]: crate::sync::broadcast::Receiver
/// [`recv`]: crate::sync::broadcast::Receiver::recv
-/// [`SendError`]: crate::sync::broadcast::SendError
-/// [`RecvError`]: crate::sync::broadcast::RecvError
+/// [`SendError`]: crate::sync::broadcast::error::SendError
+/// [`RecvError`]: crate::sync::broadcast::error::RecvError
///
/// # Examples
///
@@ -400,7 +442,7 @@ const MAX_RECEIVERS: usize = usize::MAX >> 2;
/// tx.send(20).unwrap();
/// }
/// ```
-pub fn channel<T>(mut capacity: usize) -> (Sender<T>, Receiver<T>) {
+pub fn channel<T: Clone>(mut capacity: usize) -> (Sender<T>, Receiver<T>) {
assert!(capacity > 0, "capacity is empty");
assert!(capacity <= usize::MAX >> 1, "requested capacity too large");
@@ -433,7 +475,6 @@ pub fn channel<T>(mut capacity: usize) -> (Sender<T>, Receiver<T>) {
let rx = Receiver {
shared: shared.clone(),
next: 0,
- waiter: None,
};
let tx = Sender { shared };
@@ -528,23 +569,7 @@ impl<T> Sender<T> {
/// ```
pub fn subscribe(&self) -> Receiver<T> {
let shared = self.shared.clone();
-
- let mut tail = shared.tail.lock().unwrap();
-
- if tail.rx_cnt == MAX_RECEIVERS {
- panic!("max receivers");
- }
-
- tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow");
- let next = tail.pos;
-
- drop(tail);
-
- Receiver {
- shared,
- next,
- waiter: None,
- }
+ new_receiver(shared)
}
/// Returns the number of active receivers
@@ -584,12 +609,12 @@ impl<T> Sender<T> {
/// }
/// ```
pub fn receiver_count(&self) -> usize {
- let tail = self.shared.tail.lock().unwrap();
+ let tail = self.shared.tail.lock();
tail.rx_cnt
}
fn send2(&self, value: Option<T>) -> Result<usize, SendError<Option<T>>> {
- let mut tail = self.shared.tail.lock().unwrap();
+ let mut tail = self.shared.tail.lock();
if tail.rx_cnt == 0 {
return Err(SendError(value));
@@ -634,6 +659,22 @@ impl<T> Sender<T> {
}
}
+fn new_receiver<T>(shared: Arc<Shared<T>>) -> Receiver<T> {
+ let mut tail = shared.tail.lock();
+
+ if tail.rx_cnt == MAX_RECEIVERS {
+ panic!("max receivers");
+ }
+
+ tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow");
+
+ let next = tail.pos;
+
+ drop(tail);
+
+ Receiver { shared, next }
+}
+
impl Tail {
fn notify_rx(&mut self) {
while let Some(mut waiter) = self.waiters.pop_back() {
@@ -695,7 +736,7 @@ impl<T> Receiver<T> {
// the slot lock.
drop(slot);
- let mut tail = self.shared.tail.lock().unwrap();
+ let mut tail = self.shared.tail.lock();
// Acquire slot lock again
slot = self.shared.buffer[idx].read().unwrap();
@@ -784,106 +825,7 @@ impl<T> Receiver<T> {
}
}
-impl<T> Receiver<T>
-where
- T: Clone,
-{
- /// Attempts to return a pending value on this receiver without awaiting.
- ///
- /// This is useful for a flavor of "optimistic check" before deciding to
- /// await on a receiver.
- ///
- /// Compared with [`recv`], this function has three failure cases instead of one
- /// (one for closed, one for an empty buffer, one for a lagging receiver).
- ///
- /// `Err(TryRecvError::Closed)` is returned when all `Sender` halves have
- /// dropped, indicating that no further values can be sent on the channel.
- ///
- /// If the [`Receiver`] handle falls behind, once the channel is full, newly
- /// sent values will overwrite old values. At this point, a call to [`recv`]
- /// will return with `Err(TryRecvError::Lagged)` and the [`Receiver`]'s
- /// internal cursor is updated to point to the oldest value still held by
- /// the channel. A subsequent call to [`try_recv`] will return this value
- /// **unless** it has been since overwritten. If there are no values to
- /// receive, `Err(TryRecvError::Empty)` is returned.
- ///
- /// [`recv`]: crate::sync::broadcast::Receiver::recv
- /// [`Receiver`]: crate::sync::broadcast::Receiver
- ///
- /// # Examples
- ///
- /// ```
- /// use tokio::sync::broadcast;
- ///
- /// #[tokio::main]
- /// async fn main() {
- /// let (tx, mut rx) = broadcast::channel(16);
- ///
- /// assert!(rx.try_recv().is_err());
- ///
- /// tx.send(10).unwrap();
- ///
- /// let value = rx.try_recv().unwrap();
- /// assert_eq!(10, value);
- /// }
- /// ```
- pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
- let guard = self.recv_ref(None)?;
- guard.clone_value().ok_or(TryRecvError::Closed)
- }
-
- #[doc(hidden)]
- #[deprecated(since = "0.2.21", note = "use async fn recv()")]
- pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
- use Poll::{Pending, Ready};
-
- // The borrow checker prohibits calling `self.poll_ref` while passing in
- // a mutable ref to a field (as it should). To work around this,
- // `waiter` is first *removed* from `self` then `poll_recv` is called.
- //
- // However, for safety, we must ensure that `waiter` is **not** dropped.
- // It could be contained in the intrusive linked list. The `Receiver`
- // drop implementation handles cleanup.
- //
- // The guard pattern is used to ensure that, on return, even due to
- // panic, the waiter node is replaced on `self`.
-
- struct Guard<'a, T> {
- waiter: Option<Pin<Box<UnsafeCell<Waiter>>>>,
- receiver: &'a mut Receiver<T>,
- }
-
- impl<'a, T> Drop for Guard<'a, T> {
- fn drop(&mut self) {
- self.receiver.waiter = self.waiter.take();
- }
- }
-
- let waiter = self.waiter.take().or_else(|| {
- Some(Box::pin(UnsafeCell::new(Waiter {
- queued: false,
- waker: None,
- pointers: linked_list::Pointers::new(),
- _p: PhantomPinned,
- })))
- });
-
- let guard = Guard {
- waiter,
- receiver: self,
- };
- let res = guard
- .receiver
- .recv_ref(Some((&guard.waiter.as_ref().unwrap(), cx.waker())));
-
- match res {
- Ok(guard) => Ready(guard.clone_value().ok_or(RecvError::Closed)),
- Err(TryRecvError::Closed) => Ready(Err(RecvError::Closed)),
- Err(TryRecvError::Lagged(n)) => Ready(Err(RecvError::Lagged(n))),
- Err(TryRecvError::Empty) => Pending,
- }
- }
-
+impl<T: Clone> Receiver<T> {
/// Receives the next value for this receiver.
///
/// Each [`Receiver`] handle will receive a clone of all values sent
@@ -948,54 +890,103 @@ where
/// assert_eq!(20, rx.recv().await.unwrap());
/// assert_eq!(30, rx.recv().await.unwrap());
/// }
+ /// ```
pub async fn recv(&mut self) -> Result<T, RecvError> {
let fut = Recv::<_, T>::new(Borrow(self));
fut.await
}
-}
-#[cfg(feature = "stream")]
-#[doc(hidden)]
-#[deprecated(since = "0.2.21", note = "use `into_stream()`")]
-impl<T> crate::stream::Stream for Receiver<T>
-where
- T: Clone,
-{
- type Item = Result<T, RecvError>;
-
- fn poll_next(
- mut self: std::pin::Pin<&mut Self>,
- cx: &mut Context<'_>,
- ) -> Poll<Option<Result<T, RecvError>>> {
- #[allow(deprecated)]
- self.poll_recv(cx).map(|v| match v {
- Ok(v) => Some(Ok(v)),
- lag @ Err(RecvError::Lagged(_)) => Some(lag),
- Err(RecvError::Closed) => None,
- })
+ /// Attempts to return a pending value on this receiver without awaiting.
+ ///
+ /// This is useful for a flavor of "optimistic check" before deciding to
+ /// await on a receiver.
+ ///
+ /// Compared with [`recv`], this function has three failure cases instead of two
+ /// (one for closed, one for an empty buffer, one for a lagging receiver).
+ ///
+ /// `Err(TryRecvError::Closed)` is returned when all `Sender` halves have
+ /// dropped, indicating that no further values can be sent on the channel.
+ ///
+ /// If the [`Receiver`] handle falls behind, once the channel is full, newly
+ /// sent values will overwrite old values. At this point, a call to [`recv`]
+ /// will return with `Err(TryRecvError::Lagged)` and the [`Receiver`]'s
+ /// internal cursor is updated to point to the oldest value still held by
+ /// the channel. A subsequent call to [`try_recv`] will return this value
+ /// **unless** it has been since overwritten. If there are no values to
+ /// receive, `Err(TryRecvError::Empty)` is returned.
+ ///
+ /// [`recv`]: crate::sync::broadcast::Receiver::recv
+ /// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv
+ /// [`Receiver`]: crate::sync::broadcast::Receiver
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::broadcast;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, mut rx) = broadcast::channel(16);
+ ///
+ /// assert!(rx.try_recv().is_err());
+ ///
+ /// tx.send(10).unwrap();
+ ///
+ /// let value = rx.try_recv().unwrap();
+ /// assert_eq!(10, value);
+ /// }
+ /// ```
+ pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
+ let guard = self.recv_ref(None)?;
+ guard.clone_value().ok_or(TryRecvError::Closed)
+ }
+
+ /// Convert the receiver into a `Stream`.
+ ///
+ /// The conversion allows using `Receiver` with APIs that require stream
+ /// values.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::stream::StreamExt;
+ /// use tokio::sync::broadcast;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, rx) = broadcast::channel(128);
+ ///
+ /// tokio::spawn(async move {
+ /// for i in 0..10_i32 {
+ /// tx.send(i).unwrap();
+ /// }
+ /// });
+ ///
+ /// // Streams must be pinned to iterate.
+ /// tokio::pin! {
+ /// let stream = rx
+ /// .into_stream()
+ /// .filter(Result::is_ok)
+ /// .map(Result::unwrap)
+ /// .filter(|v| v % 2 == 0)
+ /// .map(|v| v + 1);
+ /// }
+ ///
+ /// while let Some(i) = stream.next().await {
+ /// println!("{}", i);
+ /// }
+ /// }
+ /// ```
+ #[cfg(feature = "stream")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
+ pub fn into_stream(self) -> impl Stream<Item = Result<T, RecvError>> {
+ Recv::new(Borrow(self))
}
}
impl<T> Drop for Receiver<T> {
fn drop(&mut self) {
- let mut tail = self.shared.tail.lock().unwrap();
-
- if let Some(waiter) = &self.waiter {
- // safety: tail lock is held
- let queued = waiter.with(|ptr| unsafe { (*ptr).queued });
-
- if queued {
- // Remove the node
- //
- // safety: tail lock is held and the wait node is verified to be in
- // the list.
- unsafe {
- waiter.with_mut(|ptr| {
- tail.waiters.remove((&mut *ptr).into());
- });
- }
- }
- }
+ let mut tail = self.shared.tail.lock();
tail.rx_cnt -= 1;
let until = tail.pos;
@@ -1070,48 +1061,6 @@ where
cfg_stream! {
use futures_core::Stream;
- impl<T: Clone> Receiver<T> {
- /// Convert the receiver into a `Stream`.
- ///
- /// The conversion allows using `Receiver` with APIs that require stream
- /// values.
- ///
- /// # Examples
- ///
- /// ```
- /// use tokio::stream::StreamExt;
- /// use tokio::sync::broadcast;
- ///
- /// #[tokio::main]
- /// async fn main() {
- /// let (tx, rx) = broadcast::channel(128);
- ///
- /// tokio::spawn(async move {
- /// for i in 0..10_i32 {
- /// tx.send(i).unwrap();
- /// }
- /// });
- ///
- /// // Streams must be pinned to iterate.
- /// tokio::pin! {
- /// let stream = rx
- /// .into_stream()
- /// .filter(Result::is_ok)
- /// .map(Result::unwrap)
- /// .filter(|v| v % 2 == 0)
- /// .map(|v| v + 1);
- /// }
- ///
- /// while let Some(i) = stream.next().await {
- /// println!("{}", i);
- /// }
- /// }
- /// ```
- pub fn into_stream(self) -> impl Stream<Item = Result<T, RecvError>> {
- Recv::new(Borrow(self))
- }
- }
-
impl<R, T: Clone> Stream for Recv<R, T>
where
R: AsMut<Receiver<T>>,
@@ -1141,7 +1090,7 @@ where
fn drop(&mut self) {
// Acquire the tail lock. This is required for safety before accessing
// the waiter node.
- let mut tail = self.receiver.as_mut().shared.tail.lock().unwrap();
+ let mut tail = self.receiver.as_mut().shared.tail.lock();
// safety: tail lock is held
let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued });
@@ -1211,27 +1160,4 @@ impl<'a, T> Drop for RecvGuard<'a, T> {
}
}
-impl fmt::Display for RecvError {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- match self {
- RecvError::Closed => write!(f, "channel closed"),
- RecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt),
- }
- }
-}
-
-impl std::error::Error for RecvError {}
-
-impl fmt::Display for TryRecvError {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- match self {
- TryRecvError::Empty => write!(f, "channel empty"),
- TryRecvError::Closed => write!(f, "channel closed"),
- TryRecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt),
- }
- }
-}
-
-impl std::error::Error for TryRecvError {}
-
fn is_unpin<T: Unpin>() {}
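
With the error types relocated to `broadcast::error` and the deprecated `poll_recv` removed, `recv().await` is the single receive path and lag is reported through `RecvError::Lagged`. A minimal sketch of a receive loop under the relocated types (the `drain` function is illustrative, not from the patch):

```rust
use tokio::sync::broadcast::{self, error::RecvError};

async fn drain(mut rx: broadcast::Receiver<u64>) {
    loop {
        match rx.recv().await {
            Ok(v) => println!("got {}", v),
            // The receiver fell behind; its cursor now points at the oldest
            // value still retained, so keep receiving.
            Err(RecvError::Lagged(skipped)) => println!("lagged, skipped {}", skipped),
            // All senders dropped: no further values will arrive.
            Err(RecvError::Closed) => break,
        }
    }
}
```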
diff --git a/src/sync/cancellation_token.rs b/src/sync/cancellation_token.rs
deleted file mode 100644
index d60d8e0..0000000
--- a/src/sync/cancellation_token.rs
+++ /dev/null
@@ -1,861 +0,0 @@
-//! An asynchronously awaitable `CancellationToken`.
-//! The token allows to signal a cancellation request to one or more tasks.
-
-use crate::loom::sync::atomic::AtomicUsize;
-use crate::loom::sync::Mutex;
-use crate::util::intrusive_double_linked_list::{LinkedList, ListNode};
-
-use core::future::Future;
-use core::pin::Pin;
-use core::ptr::NonNull;
-use core::sync::atomic::Ordering;
-use core::task::{Context, Poll, Waker};
-
-/// A token which can be used to signal a cancellation request to one or more
-/// tasks.
-///
-/// Tasks can call [`CancellationToken::cancelled()`] in order to
-/// obtain a Future which will be resolved when cancellation is requested.
-///
-/// Cancellation can be requested through the [`CancellationToken::cancel`] method.
-///
-/// # Examples
-///
-/// ```ignore
-/// use tokio::select;
-/// use tokio::scope::CancellationToken;
-///
-/// #[tokio::main]
-/// async fn main() {
-/// let token = CancellationToken::new();
-/// let cloned_token = token.clone();
-///
-/// let join_handle = tokio::spawn(async move {
-/// // Wait for either cancellation or a very long time
-/// select! {
-/// _ = cloned_token.cancelled() => {
-/// // The token was cancelled
-/// 5
-/// }
-/// _ = tokio::time::delay_for(std::time::Duration::from_secs(9999)) => {
-/// 99
-/// }
-/// }
-/// });
-///
-/// tokio::spawn(async move {
-/// tokio::time::delay_for(std::time::Duration::from_millis(10)).await;
-/// token.cancel();
-/// });
-///
-/// assert_eq!(5, join_handle.await.unwrap());
-/// }
-/// ```
-pub struct CancellationToken {
- inner: NonNull<CancellationTokenState>,
-}
-
-// Safety: The CancellationToken is thread-safe and can be moved between threads,
-// since all methods are internally synchronized.
-unsafe impl Send for CancellationToken {}
-unsafe impl Sync for CancellationToken {}
-
-/// A Future that is resolved once the corresponding [`CancellationToken`]
-/// was cancelled
-#[must_use = "futures do nothing unless polled"]
-pub struct WaitForCancellationFuture<'a> {
- /// The CancellationToken that is associated with this WaitForCancellationFuture
- cancellation_token: Option<&'a CancellationToken>,
- /// Node for waiting at the cancellation_token
- wait_node: ListNode<WaitQueueEntry>,
- /// Whether this future was registered at the token yet as a waiter
- is_registered: bool,
-}
-
-// Safety: Futures can be sent between threads as long as the underlying
-// cancellation_token is thread-safe (Sync),
-// which allows to poll/register/unregister from a different thread.
-unsafe impl<'a> Send for WaitForCancellationFuture<'a> {}
-
-// ===== impl CancellationToken =====
-
-impl core::fmt::Debug for CancellationToken {
- fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
- f.debug_struct("CancellationToken")
- .field("is_cancelled", &self.is_cancelled())
- .finish()
- }
-}
-
-impl Clone for CancellationToken {
- fn clone(&self) -> Self {
- // Safety: The state inside a `CancellationToken` is always valid, since
- // is reference counted
- let inner = self.state();
-
- // Tokens are cloned by increasing their refcount
- let current_state = inner.snapshot();
- inner.increment_refcount(current_state);
-
- CancellationToken { inner: self.inner }
- }
-}
-
-impl Drop for CancellationToken {
- fn drop(&mut self) {
- let token_state_pointer = self.inner;
-
- // Safety: The state inside a `CancellationToken` is always valid, since
- // is reference counted
- let inner = unsafe { &mut *self.inner.as_ptr() };
-
- let mut current_state = inner.snapshot();
-
- // We need to safe the parent, since the state might be released by the
- // next call
- let parent = inner.parent;
-
- // Drop our own refcount
- current_state = inner.decrement_refcount(current_state);
-
- // If this was the last reference, unregister from the parent
- if current_state.refcount == 0 {
- if let Some(mut parent) = parent {
- // Safety: Since we still retain a reference on the parent, it must be valid.
- let parent = unsafe { parent.as_mut() };
- parent.unregister_child(token_state_pointer, current_state);
- }
- }
- }
-}
-
-impl CancellationToken {
- /// Creates a new CancellationToken in the non-cancelled state.
- pub fn new() -> CancellationToken {
- let state = Box::new(CancellationTokenState::new(
- None,
- StateSnapshot {
- cancel_state: CancellationState::NotCancelled,
- has_parent_ref: false,
- refcount: 1,
- },
- ));
-
- // Safety: We just created the Box. The pointer is guaranteed to be
- // not null
- CancellationToken {
- inner: unsafe { NonNull::new_unchecked(Box::into_raw(state)) },
- }
- }
-
- /// Returns a reference to the utilized `CancellationTokenState`.
- fn state(&self) -> &CancellationTokenState {
- // Safety: The state inside a `CancellationToken` is always valid, since
- // is reference counted
- unsafe { &*self.inner.as_ptr() }
- }
-
- /// Creates a `CancellationToken` which will get cancelled whenever the
- /// current token gets cancelled.
- ///
- /// If the current token is already cancelled, the child token will get
- /// returned in cancelled state.
- ///
- /// # Examples
- ///
- /// ```ignore
- /// use tokio::select;
- /// use tokio::scope::CancellationToken;
- ///
- /// #[tokio::main]
- /// async fn main() {
- /// let token = CancellationToken::new();
- /// let child_token = token.child_token();
- ///
- /// let join_handle = tokio::spawn(async move {
- /// // Wait for either cancellation or a very long time
- /// select! {
- /// _ = child_token.cancelled() => {
- /// // The token was cancelled
- /// 5
- /// }
- /// _ = tokio::time::delay_for(std::time::Duration::from_secs(9999)) => {
- /// 99
- /// }
- /// }
- /// });
- ///
- /// tokio::spawn(async move {
- /// tokio::time::delay_for(std::time::Duration::from_millis(10)).await;
- /// token.cancel();
- /// });
- ///
- /// assert_eq!(5, join_handle.await.unwrap());
- /// }
- /// ```
- pub fn child_token(&self) -> CancellationToken {
- let inner = self.state();
-
- // Increment the refcount of this token. It will be referenced by the
- // child, independent of whether the child is immediately cancelled or
- // not.
- let _current_state = inner.increment_refcount(inner.snapshot());
-
- let mut unpacked_child_state = StateSnapshot {
- has_parent_ref: true,
- refcount: 1,
- cancel_state: CancellationState::NotCancelled,
- };
- let mut child_token_state = Box::new(CancellationTokenState::new(
- Some(self.inner),
- unpacked_child_state,
- ));
-
- {
- let mut guard = inner.synchronized.lock().unwrap();
- if guard.is_cancelled {
- // This task was already cancelled. In this case we should not
- // insert the child into the list, since it would never get removed
- // from the list.
- (*child_token_state.synchronized.lock().unwrap()).is_cancelled = true;
- unpacked_child_state.cancel_state = CancellationState::Cancelled;
- // Since it's not in the list, the parent doesn't need to retain
- // a reference to it.
- unpacked_child_state.has_parent_ref = false;
- child_token_state
- .state
- .store(unpacked_child_state.pack(), Ordering::SeqCst);
- } else {
- if let Some(mut first_child) = guard.first_child {
- child_token_state.from_parent.next_peer = Some(first_child);
- // Safety: We manipulate other child task inside the Mutex
- // and retain a parent reference on it. The child token can't
- // get invalidated while the Mutex is held.
- unsafe {
- first_child.as_mut().from_parent.prev_peer =
- Some((&mut *child_token_state).into())
- };
- }
- guard.first_child = Some((&mut *child_token_state).into());
- }
- };
-
- let child_token_ptr = Box::into_raw(child_token_state);
- // Safety: We just created the pointer from a `Box`
- CancellationToken {
- inner: unsafe { NonNull::new_unchecked(child_token_ptr) },
- }
- }
-
- /// Cancel the [`CancellationToken`] and all child tokens which had been
- /// derived from it.
- ///
- /// This will wake up all tasks which are waiting for cancellation.
- pub fn cancel(&self) {
- self.state().cancel();
- }
-
- /// Returns `true` if the `CancellationToken` had been cancelled
- pub fn is_cancelled(&self) -> bool {
- self.state().is_cancelled()
- }
-
- /// Returns a `Future` that gets fulfilled when cancellation is requested.
- pub fn cancelled(&self) -> WaitForCancellationFuture<'_> {
- WaitForCancellationFuture {
- cancellation_token: Some(self),
- wait_node: ListNode::new(WaitQueueEntry::new()),
- is_registered: false,
- }
- }
-
- unsafe fn register(
- &self,
- wait_node: &mut ListNode<WaitQueueEntry>,
- cx: &mut Context<'_>,
- ) -> Poll<()> {
- self.state().register(wait_node, cx)
- }
-
- fn check_for_cancellation(
- &self,
- wait_node: &mut ListNode<WaitQueueEntry>,
- cx: &mut Context<'_>,
- ) -> Poll<()> {
- self.state().check_for_cancellation(wait_node, cx)
- }
-
- fn unregister(&self, wait_node: &mut ListNode<WaitQueueEntry>) {
- self.state().unregister(wait_node)
- }
-}
-
-// ===== impl WaitForCancellationFuture =====
-
-impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> {
- fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
- f.debug_struct("WaitForCancellationFuture").finish()
- }
-}
-
-impl<'a> Future for WaitForCancellationFuture<'a> {
- type Output = ();
-
- fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
- // Safety: We do not move anything out of `WaitForCancellationFuture`
- let mut_self: &mut WaitForCancellationFuture<'_> = unsafe { Pin::get_unchecked_mut(self) };
-
- let cancellation_token = mut_self
- .cancellation_token
- .expect("polled WaitForCancellationFuture after completion");
-
- let poll_res = if !mut_self.is_registered {
- // Safety: The `ListNode` is pinned through the Future,
- // and we will unregister it in `WaitForCancellationFuture::drop`
- // before the Future is dropped and the memory reference is invalidated.
- unsafe { cancellation_token.register(&mut mut_self.wait_node, cx) }
- } else {
- cancellation_token.check_for_cancellation(&mut mut_self.wait_node, cx)
- };
-
- if let Poll::Ready(()) = poll_res {
- // The cancellation_token was signalled
- mut_self.cancellation_token = None;
- // A signalled Token means the Waker won't be enqueued anymore
- mut_self.is_registered = false;
- mut_self.wait_node.task = None;
- } else {
- // This `Future` and its stored `Waker` stay registered at the
- // `CancellationToken`
- mut_self.is_registered = true;
- }
-
- poll_res
- }
-}
-
-impl<'a> Drop for WaitForCancellationFuture<'a> {
- fn drop(&mut self) {
- // If this WaitForCancellationFuture has been polled and it was added to the
- // wait queue at the cancellation_token, it must be removed before dropping.
- // Otherwise the cancellation_token would access invalid memory.
- if let Some(token) = self.cancellation_token {
- if self.is_registered {
- token.unregister(&mut self.wait_node);
- }
- }
- }
-}
-
-/// Tracks how the future had interacted with the [`CancellationToken`]
-#[derive(Copy, Clone, Debug, PartialEq, Eq)]
-enum PollState {
- /// The task has never interacted with the [`CancellationToken`].
- New,
- /// The task was added to the wait queue at the [`CancellationToken`].
- Waiting,
- /// The task has been polled to completion.
- Done,
-}
-
-/// Tracks the WaitForCancellationFuture waiting state.
-/// Access to this struct is synchronized through the mutex in the CancellationToken.
-struct WaitQueueEntry {
- /// The task handle of the waiting task
- task: Option<Waker>,
- // Current polling state. This state is only updated inside the Mutex of
- // the CancellationToken.
- state: PollState,
-}
-
-impl WaitQueueEntry {
- /// Creates a new WaitQueueEntry
- fn new() -> WaitQueueEntry {
- WaitQueueEntry {
- task: None,
- state: PollState::New,
- }
- }
-}
-
-struct SynchronizedState {
- waiters: LinkedList<WaitQueueEntry>,
- first_child: Option<NonNull<CancellationTokenState>>,
- is_cancelled: bool,
-}
-
-impl SynchronizedState {
- fn new() -> Self {
- Self {
- waiters: LinkedList::new(),
- first_child: None,
- is_cancelled: false,
- }
- }
-}
-
-/// Information embedded in child tokens which is synchronized through the Mutex
-/// in their parent.
-struct SynchronizedThroughParent {
- next_peer: Option<NonNull<CancellationTokenState>>,
- prev_peer: Option<NonNull<CancellationTokenState>>,
-}
-
-/// Possible states of a `CancellationToken`
-#[derive(Debug, Copy, Clone, PartialEq, Eq)]
-enum CancellationState {
- NotCancelled = 0,
- Cancelling = 1,
- Cancelled = 2,
-}
-
-impl CancellationState {
- fn pack(self) -> usize {
- self as usize
- }
-
- fn unpack(value: usize) -> Self {
- match value {
- 0 => CancellationState::NotCancelled,
- 1 => CancellationState::Cancelling,
- 2 => CancellationState::Cancelled,
- _ => unreachable!("Invalid value"),
- }
- }
-}
-
-#[derive(Debug, Copy, Clone, PartialEq, Eq)]
-struct StateSnapshot {
- /// The amount of references to this particular CancellationToken.
- /// `CancellationToken` structs hold these references to a `CancellationTokenState`.
- /// Also the state is referenced by the state of each child.
- refcount: usize,
- /// Whether the state is still referenced by it's parent and can therefore
- /// not be freed.
- has_parent_ref: bool,
- /// Whether the token is cancelled
- cancel_state: CancellationState,
-}
-
-impl StateSnapshot {
- /// Packs the snapshot into a `usize`
- fn pack(self) -> usize {
- self.refcount << 3 | if self.has_parent_ref { 4 } else { 0 } | self.cancel_state.pack()
- }
-
- /// Unpacks the snapshot from a `usize`
- fn unpack(value: usize) -> Self {
- let refcount = value >> 3;
- let has_parent_ref = value & 4 != 0;
- let cancel_state = CancellationState::unpack(value & 0x03);
-
- StateSnapshot {
- refcount,
- has_parent_ref,
- cancel_state,
- }
- }
-
- /// Whether this `CancellationTokenState` is still referenced by any
- /// `CancellationToken`.
- fn has_refs(&self) -> bool {
- self.refcount != 0 || self.has_parent_ref
- }
-}
-
-/// The maximum permitted amount of references to a CancellationToken. This
-/// is derived from the intent to never use more than 32bit in the `Snapshot`.
-const MAX_REFS: u32 = (std::u32::MAX - 7) >> 3;
-
-/// Internal state of the `CancellationToken` pair above
-struct CancellationTokenState {
- state: AtomicUsize,
- parent: Option<NonNull<CancellationTokenState>>,
- from_parent: SynchronizedThroughParent,
- synchronized: Mutex<SynchronizedState>,
-}
-
-impl CancellationTokenState {
- fn new(
- parent: Option<NonNull<CancellationTokenState>>,
- state: StateSnapshot,
- ) -> CancellationTokenState {
- CancellationTokenState {
- parent,
- from_parent: SynchronizedThroughParent {
- prev_peer: None,
- next_peer: None,
- },
- state: AtomicUsize::new(state.pack()),
- synchronized: Mutex::new(SynchronizedState::new()),
- }
- }
-
- /// Returns a snapshot of the current atomic state of the token
- fn snapshot(&self) -> StateSnapshot {
- StateSnapshot::unpack(self.state.load(Ordering::SeqCst))
- }
-
- fn atomic_update_state<F>(&self, mut current_state: StateSnapshot, func: F) -> StateSnapshot
- where
- F: Fn(StateSnapshot) -> StateSnapshot,
- {
- let mut current_packed_state = current_state.pack();
- loop {
- let next_state = func(current_state);
- match self.state.compare_exchange(
- current_packed_state,
- next_state.pack(),
- Ordering::SeqCst,
- Ordering::SeqCst,
- ) {
- Ok(_) => {
- return next_state;
- }
- Err(actual) => {
- current_packed_state = actual;
- current_state = StateSnapshot::unpack(actual);
- }
- }
- }
- }
-
- fn increment_refcount(&self, current_state: StateSnapshot) -> StateSnapshot {
- self.atomic_update_state(current_state, |mut state: StateSnapshot| {
- if state.refcount >= MAX_REFS as usize {
- eprintln!("[ERROR] Maximum reference count for CancellationToken was exceeded");
- std::process::abort();
- }
- state.refcount += 1;
- state
- })
- }
-
- fn decrement_refcount(&self, current_state: StateSnapshot) -> StateSnapshot {
- let current_state = self.atomic_update_state(current_state, |mut state: StateSnapshot| {
- state.refcount -= 1;
- state
- });
-
- // Drop the State if it is not referenced anymore
- if !current_state.has_refs() {
- // Safety: `CancellationTokenState` is always stored in refcounted
- // Boxes
- let _ = unsafe { Box::from_raw(self as *const Self as *mut Self) };
- }
-
- current_state
- }
-
- fn remove_parent_ref(&self, current_state: StateSnapshot) -> StateSnapshot {
- let current_state = self.atomic_update_state(current_state, |mut state: StateSnapshot| {
- state.has_parent_ref = false;
- state
- });
-
- // Drop the State if it is not referenced anymore
- if !current_state.has_refs() {
- // Safety: `CancellationTokenState` is always stored in refcounted
- // Boxes
- let _ = unsafe { Box::from_raw(self as *const Self as *mut Self) };
- }
-
- current_state
- }
-
- /// Unregisters a child from the parent token.
- /// The child tokens state is not exactly known at this point in time.
- /// If the parent token is cancelled, the child token gets removed from the
- /// parents list, and might therefore already have been freed. If the parent
- /// token is not cancelled, the child token is still valid.
- fn unregister_child(
- &mut self,
- mut child_state: NonNull<CancellationTokenState>,
- current_child_state: StateSnapshot,
- ) {
- let removed_child = {
- // Remove the child toke from the parents linked list
- let mut guard = self.synchronized.lock().unwrap();
- if !guard.is_cancelled {
- // Safety: Since the token was not cancelled, the child must
- // still be in the list and valid.
- let mut child_state = unsafe { child_state.as_mut() };
- debug_assert!(child_state.snapshot().has_parent_ref);
-
- if guard.first_child == Some(child_state.into()) {
- guard.first_child = child_state.from_parent.next_peer;
- }
- // Safety: If peers wouldn't be valid anymore, they would try
- // to remove themselves from the list. This would require locking
- // the Mutex that we currently own.
- unsafe {
- if let Some(mut prev_peer) = child_state.from_parent.prev_peer {
- prev_peer.as_mut().from_parent.next_peer =
- child_state.from_parent.next_peer;
- }
- if let Some(mut next_peer) = child_state.from_parent.next_peer {
- next_peer.as_mut().from_parent.prev_peer =
- child_state.from_parent.prev_peer;
- }
- }
- child_state.from_parent.prev_peer = None;
- child_state.from_parent.next_peer = None;
-
- // The child is no longer referenced by the parent, since we were able
- // to remove its reference from the parents list.
- true
- } else {
- // Do not touch the linked list anymore. If the parent is cancelled
- // it will move all childs outside of the Mutex and manipulate
- // the pointers there. Manipulating the pointers here too could
- // lead to races. Therefore leave them just as as and let the
- // parent deal with it. The parent will make sure to retain a
- // reference to this state as long as it manipulates the list
- // pointers. Therefore the pointers are not dangling.
- false
- }
- };
-
- if removed_child {
- // If the token removed itself from the parents list, it can reset
- // the the parent ref status. If it is isn't able to do so, because the
- // parent removed it from the list, there is no need to do this.
- // The parent ref acts as as another reference count. Therefore
- // removing this reference can free the object.
- // Safety: The token was in the list. This means the parent wasn't
- // cancelled before, and the token must still be alive.
- unsafe { child_state.as_mut().remove_parent_ref(current_child_state) };
- }
-
- // Decrement the refcount on the parent and free it if necessary
- self.decrement_refcount(self.snapshot());
- }
-
- fn cancel(&self) {
- // Move the state of the CancellationToken from `NotCancelled` to `Cancelling`
- let mut current_state = self.snapshot();
-
- let state_after_cancellation = loop {
- if current_state.cancel_state != CancellationState::NotCancelled {
- // Another task already initiated the cancellation
- return;
- }
-
- let mut next_state = current_state;
- next_state.cancel_state = CancellationState::Cancelling;
- match self.state.compare_exchange(
- current_state.pack(),
- next_state.pack(),
- Ordering::SeqCst,
- Ordering::SeqCst,
- ) {
- Ok(_) => break next_state,
- Err(actual) => current_state = StateSnapshot::unpack(actual),
- }
- };
-
- // This task cancelled the token
-
- // Take the task list out of the Token
- // We do not want to cancel child token inside this lock. If one of the
- // child tasks would have additional child tokens, we would recursively
- // take locks.
-
- // Doing this action has an impact if the child token is dropped concurrently:
- // It will try to deregister itself from the parent task, but can not find
- // itself in the task list anymore. Therefore it needs to assume the parent
- // has extracted the list and will process it. It may not modify the list.
- // This is OK from a memory safety perspective, since the parent still
- // retains a reference to the child task until it finished iterating over
- // it.
-
- let mut first_child = {
- let mut guard = self.synchronized.lock().unwrap();
- // Save the cancellation also inside the Mutex
- // This allows child tokens which want to detach themselves to detect
- // that this is no longer required since the parent cleared the list.
- guard.is_cancelled = true;
-
- // Wakeup all waiters
- // This happens inside the lock to make cancellation reliable
- // If we would access waiters outside of the lock, the pointers
- // may no longer be valid.
- // Typically this shouldn't be an issue, since waking a task should
- // only move it from the blocked into the ready state and not have
- // further side effects.
-
- // Use a reverse iterator, so that the oldest waiter gets
- // scheduled first
- guard.waiters.reverse_drain(|waiter| {
- // We are not allowed to move the `Waker` out of the list node.
- // The `Future` relies on the fact that the old `Waker` stays there
- // as long as the `Future` has not completed in order to perform
- // the `will_wake()` check.
- // Therefore `wake_by_ref` is used instead of `wake()`
- if let Some(handle) = &mut waiter.task {
- handle.wake_by_ref();
- }
- // Mark the waiter to have been removed from the list.
- waiter.state = PollState::Done;
- });
-
- guard.first_child.take()
- };
-
- while let Some(mut child) = first_child {
- // Safety: We know this is a valid pointer since it is in our child pointer
- // list. It can't have been freed in between, since we retain a a reference
- // to each child.
- let mut_child = unsafe { child.as_mut() };
-
- // Get the next child and clean up list pointers
- first_child = mut_child.from_parent.next_peer;
- mut_child.from_parent.prev_peer = None;
- mut_child.from_parent.next_peer = None;
-
- // Cancel the child task
- mut_child.cancel();
-
- // Drop the parent reference. This `CancellationToken` is not interested
- // in interacting with the child anymore.
- // This is ONLY allowed once we promised not to touch the state anymore
- // after this interaction.
- mut_child.remove_parent_ref(mut_child.snapshot());
- }
-
- // The cancellation has completed
- // At this point in time tasks which registered a wait node can be sure
- // that this wait node already had been dequeued from the list without
- // needing to inspect the list.
- self.atomic_update_state(state_after_cancellation, |mut state| {
- state.cancel_state = CancellationState::Cancelled;
- state
- });
- }
-
- /// Returns `true` if the `CancellationToken` had been cancelled
- fn is_cancelled(&self) -> bool {
- let current_state = self.snapshot();
- current_state.cancel_state != CancellationState::NotCancelled
- }
-
- /// Registers a waiting task at the `CancellationToken`.
- /// Safety: This method is only safe as long as the waiting waiting task
- /// will properly unregister the wait node before it gets moved.
- unsafe fn register(
- &self,
- wait_node: &mut ListNode<WaitQueueEntry>,
- cx: &mut Context<'_>,
- ) -> Poll<()> {
- debug_assert_eq!(PollState::New, wait_node.state);
- let current_state = self.snapshot();
-
- // Perform an optimistic cancellation check before. This is not strictly
- // necessary since we also check for cancellation in the Mutex, but
- // reduces the necessary work to be performed for tasks which already
- // had been cancelled.
- if current_state.cancel_state != CancellationState::NotCancelled {
- return Poll::Ready(());
- }
-
- // So far the token is not cancelled. However it could be cancelld before
- // we get the chance to store the `Waker`. Therfore we need to check
- // for cancellation again inside the mutex.
- let mut guard = self.synchronized.lock().unwrap();
- if guard.is_cancelled {
- // Cancellation was signalled
- wait_node.state = PollState::Done;
- Poll::Ready(())
- } else {
- // Added the task to the wait queue
- wait_node.task = Some(cx.waker().clone());
- wait_node.state = PollState::Waiting;
- guard.waiters.add_front(wait_node);
- Poll::Pending
- }
- }
-
- fn check_for_cancellation(
- &self,
- wait_node: &mut ListNode<WaitQueueEntry>,
- cx: &mut Context<'_>,
- ) -> Poll<()> {
- debug_assert!(
- wait_node.task.is_some(),
- "Method can only be called after task had been registered"
- );
-
- let current_state = self.snapshot();
-
- if current_state.cancel_state != CancellationState::NotCancelled {
- // If the cancellation had been fully completed we know that our `Waker`
- // is no longer registered at the `CancellationToken`.
- // Otherwise the cancel call may or may not yet have iterated
- // through the waiters list and removed the wait nodes.
- // If it hasn't yet, we need to remove it. Otherwise an attempt to
- // reuse the `wait_node´ might get freed due to the `WaitForCancellationFuture`
- // getting dropped before the cancellation had interacted with it.
- if current_state.cancel_state != CancellationState::Cancelled {
- self.unregister(wait_node);
- }
- Poll::Ready(())
- } else {
- // Check if we need to swap the `Waker`. This will make the check more
- // expensive, since the `Waker` is synchronized through the Mutex.
- // If we don't need to perform a `Waker` update, an atomic check for
- // cancellation is sufficient.
- let need_waker_update = wait_node
- .task
- .as_ref()
- .map(|waker| waker.will_wake(cx.waker()))
- .unwrap_or(true);
-
- if need_waker_update {
- let guard = self.synchronized.lock().unwrap();
- if guard.is_cancelled {
- // Cancellation was signalled. Since this cancellation signal
- // is set inside the Mutex, the old waiter must already have
- // been removed from the waiting list
- debug_assert_eq!(PollState::Done, wait_node.state);
- wait_node.task = None;
- Poll::Ready(())
- } else {
- // The WaitForCancellationFuture is already in the queue.
- // The CancellationToken can't have been cancelled,
- // since this would change the is_cancelled flag inside the mutex.
- // Therefore we just have to update the Waker. A follow-up
- // cancellation will always use the new waker.
- wait_node.task = Some(cx.waker().clone());
- Poll::Pending
- }
- } else {
- // Do nothing. If the token gets cancelled, this task will get
- // woken again and can fetch the cancellation.
- Poll::Pending
- }
- }
- }
-
- fn unregister(&self, wait_node: &mut ListNode<WaitQueueEntry>) {
- debug_assert!(
- wait_node.task.is_some(),
- "waiter can not be active without task"
- );
-
- let mut guard = self.synchronized.lock().unwrap();
- // WaitForCancellationFuture only needs to get removed if it has been added to
- // the wait queue of the CancellationToken.
- // This has happened in the PollState::Waiting case.
- if let PollState::Waiting = wait_node.state {
- // Safety: Due to the state, we know that the node must be part
- // of the waiter list
- if !unsafe { guard.waiters.remove(wait_node) } {
- // Panic if the address isn't found. This can only happen if the contract was
- // violated, e.g. the WaitQueueEntry got moved after the initial poll.
- panic!("Future could not be removed from wait queue");
- }
- wait_node.state = PollState::Done;
- }
- wait_node.task = None;
- }
-}
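
For reference, the deleted implementation packed a reference count, a parent-reference flag, and a two-bit cancellation state into a single `usize` (see `StateSnapshot::pack`/`unpack` above). A standalone sketch of that layout, extracted from the removed code:

```rust
// Bit layout of the deleted StateSnapshot: refcount in bits 3.., the
// parent-reference flag in bit 2, and the cancel state in bits 0-1.
fn pack(refcount: usize, has_parent_ref: bool, cancel_state: usize) -> usize {
    debug_assert!(cancel_state <= 2); // NotCancelled / Cancelling / Cancelled
    refcount << 3 | if has_parent_ref { 4 } else { 0 } | cancel_state
}

fn unpack(value: usize) -> (usize, bool, usize) {
    (value >> 3, value & 4 != 0, value & 0x03)
}

fn main() {
    let packed = pack(5, true, 1); // refcount 5, parent ref held, Cancelling
    assert_eq!(unpack(packed), (5, true, 1));
}
```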
diff --git a/src/sync/mod.rs b/src/sync/mod.rs
index 3d96106..57ae277 100644
--- a/src/sync/mod.rs
+++ b/src/sync/mod.rs
@@ -20,7 +20,7 @@
//! few flavors of channels provided by Tokio. Each channel flavor supports
//! different message passing patterns. When a channel supports multiple
//! producers, many separate tasks may **send** messages. When a channel
-//! supports muliple consumers, many different separate tasks may **receive**
+//! supports multiple consumers, many different separate tasks may **receive**
//! messages.
//!
//! Tokio provides many different channel flavors as different message passing
@@ -106,7 +106,7 @@
//!
//! #[tokio::main]
//! async fn main() {
-//! let (mut tx, mut rx) = mpsc::channel(100);
+//! let (tx, mut rx) = mpsc::channel(100);
//!
//! tokio::spawn(async move {
//! for i in 0..10 {
@@ -150,7 +150,7 @@
//! for _ in 0..10 {
//! // Each task needs its own `tx` handle. This is done by cloning the
//! // original handle.
-//! let mut tx = tx.clone();
+//! let tx = tx.clone();
//!
//! tokio::spawn(async move {
//! tx.send(&b"data to write"[..]).await.unwrap();
@@ -213,7 +213,7 @@
//!
//! // Spawn tasks that will send the increment command.
//! for _ in 0..10 {
-//! let mut cmd_tx = cmd_tx.clone();
+//! let cmd_tx = cmd_tx.clone();
//!
//! join_handles.push(tokio::spawn(async move {
//! let (resp_tx, resp_rx) = oneshot::channel();
@@ -322,7 +322,7 @@
//! tokio::spawn(async move {
//! loop {
//! // Wait 10 seconds between checks
-//! time::delay_for(Duration::from_secs(10)).await;
+//! time::sleep(Duration::from_secs(10)).await;
//!
//! // Load the configuration file
//! let new_config = Config::load_from_file().await.unwrap();
@@ -330,7 +330,7 @@
//! // If the configuration changed, send the new config value
//! // on the watch channel.
//! if new_config != config {
-//! tx.broadcast(new_config.clone()).unwrap();
+//! tx.send(new_config.clone()).unwrap();
//! config = new_config;
//! }
//! }
@@ -355,17 +355,15 @@
//! let op = my_async_operation();
//! tokio::pin!(op);
//!
-//! // Receive the **initial** configuration value. As this is the
-//! // first time the config is received from the watch, it will
-//! // always complete immediatedly.
-//! let mut conf = rx.recv().await.unwrap();
+//! // Get the initial config value
+//! let mut conf = rx.borrow().clone();
//!
//! let mut op_start = Instant::now();
-//! let mut delay = time::delay_until(op_start + conf.timeout);
+//! let mut sleep = time::sleep_until(op_start + conf.timeout);
//!
//! loop {
//! tokio::select! {
-//! _ = &mut delay => {
+//! _ = &mut sleep => {
//! // The operation elapsed. Restart it
//! op.set(my_async_operation());
//!
@@ -373,14 +371,14 @@
//! op_start = Instant::now();
//!
//! // Restart the timeout
-//! delay = time::delay_until(op_start + conf.timeout);
+//! sleep = time::sleep_until(op_start + conf.timeout);
//! }
-//! new_conf = rx.recv() => {
-//! conf = new_conf.unwrap();
+//! _ = rx.changed() => {
+//! conf = rx.borrow().clone();
//!
//! // The configuration has been updated. Update the
-//! // `delay` using the new `timeout` value.
-//! delay.reset(op_start + conf.timeout);
+//! // `sleep` using the new `timeout` value.
+//! sleep.reset(op_start + conf.timeout);
//! }
//! _ = &mut op => {
//! // The operation completed!
@@ -399,14 +397,14 @@
//! }
//! ```
//!
-//! [`watch` channel]: crate::sync::watch
-//! [`broadcast` channel]: crate::sync::broadcast
+//! [`watch` channel]: mod@crate::sync::watch
+//! [`broadcast` channel]: mod@crate::sync::broadcast
//!
//! # State synchronization
//!
//! The remaining synchronization primitives focus on synchronizing state.
//! These are asynchronous equivalents to versions provided by `std`. They
-//! operate in a similar way as their `std` counterparts parts but will wait
+//! operate in a similar way as their `std` counterparts but will wait
//! asynchronously instead of blocking the thread.
//!
//! * [`Barrier`](Barrier) Ensures multiple tasks will wait for each other to
@@ -434,23 +432,17 @@ cfg_sync! {
pub mod broadcast;
- cfg_unstable! {
- mod cancellation_token;
- pub use cancellation_token::{CancellationToken, WaitForCancellationFuture};
- }
-
pub mod mpsc;
mod mutex;
pub use mutex::{Mutex, MutexGuard, TryLockError, OwnedMutexGuard};
- mod notify;
+ pub(crate) mod notify;
pub use notify::Notify;
pub mod oneshot;
pub(crate) mod batch_semaphore;
- pub(crate) mod semaphore_ll;
mod semaphore;
pub use semaphore::{Semaphore, SemaphorePermit, OwnedSemaphorePermit};
@@ -464,20 +456,30 @@ cfg_sync! {
}
cfg_not_sync! {
+ #[cfg(any(feature = "fs", feature = "signal", all(unix, feature = "process")))]
+ pub(crate) mod batch_semaphore;
+
+ cfg_fs! {
+ mod mutex;
+ pub(crate) use mutex::Mutex;
+ }
+
+ #[cfg(any(feature = "rt", feature = "signal", all(unix, feature = "process")))]
+ pub(crate) mod notify;
+
cfg_atomic_waker_impl! {
mod task;
pub(crate) use task::AtomicWaker;
}
#[cfg(any(
- feature = "rt-core",
+ feature = "rt",
feature = "process",
feature = "signal"))]
pub(crate) mod oneshot;
- cfg_signal! {
+ cfg_signal_internal! {
pub(crate) mod mpsc;
- pub(crate) mod semaphore_ll;
}
}
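
The doc rewrites in the hunks above track the watch channel's API rename: `Sender::broadcast` becomes `send`, and `Receiver::recv` is split into `changed` (wait for a new value) plus `borrow` (read the current one). A minimal standalone sketch of the new pattern, assuming the post-rename API shown in this diff:

```rust
use tokio::sync::watch;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = watch::channel("v1");

    tokio::spawn(async move {
        // `send` replaces the old `broadcast`; it fails only once every
        // receiver has been dropped.
        tx.send("v2").unwrap();
    });

    // `changed` replaces `recv`: it resolves when a new value is sent,
    // and the value itself is read through `borrow`.
    while rx.changed().await.is_ok() {
        println!("saw {}", *rx.borrow());
    }
}
```
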
diff --git a/src/sync/mpsc/block.rs b/src/sync/mpsc/block.rs
index 7bf1619..e062f2b 100644
--- a/src/sync/mpsc/block.rs
+++ b/src/sync/mpsc/block.rs
@@ -1,8 +1,6 @@
-use crate::loom::{
- cell::UnsafeCell,
- sync::atomic::{AtomicPtr, AtomicUsize},
- thread,
-};
+use crate::loom::cell::UnsafeCell;
+use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize};
+use crate::loom::thread;
use std::mem::MaybeUninit;
use std::ops;
diff --git a/src/sync/mpsc/bounded.rs b/src/sync/mpsc/bounded.rs
index afca8c5..06b3717 100644
--- a/src/sync/mpsc/bounded.rs
+++ b/src/sync/mpsc/bounded.rs
@@ -1,6 +1,6 @@
+use crate::sync::batch_semaphore::{self as semaphore, TryAcquireError};
use crate::sync::mpsc::chan;
-use crate::sync::mpsc::error::{ClosedError, SendError, TryRecvError, TrySendError};
-use crate::sync::semaphore_ll as semaphore;
+use crate::sync::mpsc::error::{SendError, TryRecvError, TrySendError};
cfg_time! {
use crate::sync::mpsc::error::SendTimeoutError;
@@ -8,6 +8,7 @@ cfg_time! {
}
use std::fmt;
+#[cfg(any(feature = "signal", feature = "process", feature = "stream"))]
use std::task::{Context, Poll};
/// Send values to the associated `Receiver`.
@@ -17,20 +18,14 @@ pub struct Sender<T> {
chan: chan::Tx<T, Semaphore>,
}
-impl<T> Clone for Sender<T> {
- fn clone(&self) -> Self {
- Sender {
- chan: self.chan.clone(),
- }
- }
-}
-
-impl<T> fmt::Debug for Sender<T> {
- fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
- fmt.debug_struct("Sender")
- .field("chan", &self.chan)
- .finish()
- }
+/// Permit to send one value into the channel.
+///
+/// `Permit` values are returned by [`Sender::reserve()`] and are used to
+/// guarantee channel capacity before generating a message to send.
+///
+/// [`Sender::reserve()`]: Sender::reserve
+pub struct Permit<'a, T> {
+ chan: &'a chan::Tx<T, Semaphore>,
}
/// Receive values from the associated `Sender`.
@@ -41,16 +36,12 @@ pub struct Receiver<T> {
chan: chan::Rx<T, Semaphore>,
}
-impl<T> fmt::Debug for Receiver<T> {
- fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
- fmt.debug_struct("Receiver")
- .field("chan", &self.chan)
- .finish()
- }
-}
-
-/// Creates a bounded mpsc channel for communicating between asynchronous tasks,
-/// returning the sender/receiver halves.
+/// Creates a bounded mpsc channel for communicating between asynchronous tasks
+/// with backpressure.
+///
+/// The channel will buffer up to the provided number of messages. Once the
+/// buffer is full, attempts to `send` new messages will wait until a message is
+/// received from the channel. The provided buffer capacity must be at least 1.
///
/// All data sent on `Sender` will become available on `Receiver` in the same
/// order as it was sent.
@@ -62,6 +53,10 @@ impl<T> fmt::Debug for Receiver<T> {
/// will return a `SendError`. Similarly, if `Sender` is disconnected while
/// trying to `recv`, the `recv` method will return a `RecvError`.
///
+/// # Panics
+///
+/// Panics if the buffer capacity is 0.
+///
/// # Examples
///
/// ```rust
@@ -69,7 +64,7 @@ impl<T> fmt::Debug for Receiver<T> {
///
/// #[tokio::main]
/// async fn main() {
-/// let (mut tx, mut rx) = mpsc::channel(100);
+/// let (tx, mut rx) = mpsc::channel(100);
///
/// tokio::spawn(async move {
/// for i in 0..10 {
@@ -117,7 +112,7 @@ impl<T> Receiver<T> {
///
/// #[tokio::main]
/// async fn main() {
- /// let (mut tx, mut rx) = mpsc::channel(100);
+ /// let (tx, mut rx) = mpsc::channel(100);
///
/// tokio::spawn(async move {
/// tx.send("hello").await.unwrap();
@@ -135,7 +130,7 @@ impl<T> Receiver<T> {
///
/// #[tokio::main]
/// async fn main() {
- /// let (mut tx, mut rx) = mpsc::channel(100);
+ /// let (tx, mut rx) = mpsc::channel(100);
///
/// tx.send("hello").await.unwrap();
/// tx.send("world").await.unwrap();
@@ -146,15 +141,48 @@ impl<T> Receiver<T> {
/// ```
pub async fn recv(&mut self) -> Option<T> {
use crate::future::poll_fn;
-
- poll_fn(|cx| self.poll_recv(cx)).await
+ poll_fn(|cx| self.chan.recv(cx)).await
}
- #[doc(hidden)] // TODO: document
- pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
+ #[cfg(any(feature = "signal", feature = "process"))]
+ pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
self.chan.recv(cx)
}
+ /// Blocking receive to call outside of asynchronous contexts.
+ ///
+ /// # Panics
+ ///
+ /// This function panics if called within an asynchronous execution
+ /// context.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use std::thread;
+ /// use tokio::runtime::Runtime;
+ /// use tokio::sync::mpsc;
+ ///
+ /// fn main() {
+ /// let (tx, mut rx) = mpsc::channel::<u8>(10);
+ ///
+ /// let sync_code = thread::spawn(move || {
+ /// assert_eq!(Some(10), rx.blocking_recv());
+ /// });
+ ///
+ /// Runtime::new()
+ /// .unwrap()
+ /// .block_on(async move {
+ /// let _ = tx.send(10).await;
+ /// });
+ /// sync_code.join().unwrap()
+ /// }
+ /// ```
+ #[cfg(feature = "sync")]
+ pub fn blocking_recv(&mut self) -> Option<T> {
+ crate::future::block_on(self.recv())
+ }
+
/// Attempts to return a pending value on this receiver without blocking.
///
/// This method will never block the caller in order to wait for data to
@@ -173,12 +201,53 @@ impl<T> Receiver<T> {
/// Closes the receiving half of a channel, without dropping it.
///
/// This prevents any further messages from being sent on the channel while
- /// still enabling the receiver to drain messages that are buffered.
+ /// still enabling the receiver to drain messages that are buffered. Any
+ /// outstanding [`Permit`] values will still be able to send messages.
+ ///
+ /// In order to guarantee no messages are dropped, after calling `close()`,
+ /// `recv()` must be called until `None` is returned.
+ ///
+ /// [`Permit`]: Permit
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::mpsc;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, mut rx) = mpsc::channel(20);
+ ///
+ /// tokio::spawn(async move {
+ /// let mut i = 0;
+ /// while let Ok(permit) = tx.reserve().await {
+ /// permit.send(i);
+ /// i += 1;
+ /// }
+ /// });
+ ///
+ /// rx.close();
+ ///
+ /// while let Some(msg) = rx.recv().await {
+ /// println!("got {}", msg);
+ /// }
+ ///
+ /// // Channel closed and no messages are lost.
+ /// }
+ /// ```
pub fn close(&mut self) {
self.chan.close();
}
}
+impl<T> fmt::Debug for Receiver<T> {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt.debug_struct("Receiver")
+ .field("chan", &self.chan)
+ .finish()
+ }
+}
+
impl<T> Unpin for Receiver<T> {}
cfg_stream! {
@@ -186,7 +255,7 @@ cfg_stream! {
type Item = T;
fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
- self.poll_recv(cx)
+ self.chan.recv(cx)
}
}
}
@@ -225,7 +294,7 @@ impl<T> Sender<T> {
///
/// #[tokio::main]
/// async fn main() {
- /// let (mut tx, mut rx) = mpsc::channel(1);
+ /// let (tx, mut rx) = mpsc::channel(1);
///
/// tokio::spawn(async move {
/// for i in 0..10 {
@@ -241,18 +310,49 @@ impl<T> Sender<T> {
/// }
/// }
/// ```
- pub async fn send(&mut self, value: T) -> Result<(), SendError<T>> {
- use crate::future::poll_fn;
-
- if poll_fn(|cx| self.poll_ready(cx)).await.is_err() {
- return Err(SendError(value));
+ pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
+ match self.reserve().await {
+ Ok(permit) => {
+ permit.send(value);
+ Ok(())
+ }
+ Err(_) => Err(SendError(value)),
}
+ }
- match self.try_send(value) {
- Ok(()) => Ok(()),
- Err(TrySendError::Full(_)) => unreachable!(),
- Err(TrySendError::Closed(value)) => Err(SendError(value)),
- }
+    /// Completes when the receiver has been dropped.
+    ///
+    /// This allows the producers to get notified when interest in the produced
+    /// values is canceled, and to stop doing work immediately.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::mpsc;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx1, rx) = mpsc::channel::<()>(1);
+ /// let tx2 = tx1.clone();
+ /// let tx3 = tx1.clone();
+ /// let tx4 = tx1.clone();
+ /// let tx5 = tx1.clone();
+ /// tokio::spawn(async move {
+ /// drop(rx);
+ /// });
+ ///
+ /// futures::join!(
+ /// tx1.closed(),
+ /// tx2.closed(),
+ /// tx3.closed(),
+ /// tx4.closed(),
+ /// tx5.closed()
+ /// );
+    ///     println!("Receiver dropped");
+ /// }
+ /// ```
+ pub async fn closed(&self) {
+ self.chan.closed().await
}
/// Attempts to immediately send a message on this `Sender`
@@ -262,9 +362,6 @@ impl<T> Sender<T> {
/// with [`send`], this function has two failure cases instead of one (one for
/// disconnection, one for a full buffer).
///
- /// This function may be paired with [`poll_ready`] in order to wait for
- /// channel capacity before trying to send a value.
- ///
/// # Errors
///
/// If the channel capacity has been reached, i.e., the channel has `n`
@@ -276,7 +373,6 @@ impl<T> Sender<T> {
/// an error. The error includes the value passed to `send`.
///
/// [`send`]: Sender::send
- /// [`poll_ready`]: Sender::poll_ready
/// [`channel`]: channel
/// [`close`]: Receiver::close
///
@@ -288,8 +384,8 @@ impl<T> Sender<T> {
/// #[tokio::main]
/// async fn main() {
/// // Create a channel with buffer size 1
- /// let (mut tx1, mut rx) = mpsc::channel(1);
- /// let mut tx2 = tx1.clone();
+ /// let (tx1, mut rx) = mpsc::channel(1);
+ /// let tx2 = tx1.clone();
///
/// tokio::spawn(async move {
/// tx1.send(1).await.unwrap();
@@ -317,8 +413,15 @@ impl<T> Sender<T> {
/// }
/// }
/// ```
- pub fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> {
- self.chan.try_send(message)?;
+ pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
+ match self.chan.semaphore().0.try_acquire(1) {
+ Ok(_) => {}
+ Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(message)),
+ Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(message)),
+ }
+
+ // Send the message
+ self.chan.send(message);
Ok(())
}
@@ -346,11 +449,11 @@ impl<T> Sender<T> {
///
/// ```rust
/// use tokio::sync::mpsc;
- /// use tokio::time::{delay_for, Duration};
+ /// use tokio::time::{sleep, Duration};
///
/// #[tokio::main]
/// async fn main() {
- /// let (mut tx, mut rx) = mpsc::channel(1);
+ /// let (tx, mut rx) = mpsc::channel(1);
///
/// tokio::spawn(async move {
/// for i in 0..10 {
@@ -363,117 +466,213 @@ impl<T> Sender<T> {
///
/// while let Some(i) = rx.recv().await {
/// println!("got = {}", i);
- /// delay_for(Duration::from_millis(200)).await;
+ /// sleep(Duration::from_millis(200)).await;
/// }
/// }
/// ```
#[cfg(feature = "time")]
#[cfg_attr(docsrs, doc(cfg(feature = "time")))]
pub async fn send_timeout(
- &mut self,
+ &self,
value: T,
timeout: Duration,
) -> Result<(), SendTimeoutError<T>> {
- use crate::future::poll_fn;
-
- match crate::time::timeout(timeout, poll_fn(|cx| self.poll_ready(cx))).await {
+ let permit = match crate::time::timeout(timeout, self.reserve()).await {
Err(_) => {
return Err(SendTimeoutError::Timeout(value));
}
Ok(Err(_)) => {
return Err(SendTimeoutError::Closed(value));
}
- Ok(_) => {}
- }
+ Ok(Ok(permit)) => permit,
+ };
- match self.try_send(value) {
- Ok(()) => Ok(()),
- Err(TrySendError::Full(_)) => unreachable!(),
- Err(TrySendError::Closed(value)) => Err(SendTimeoutError::Closed(value)),
- }
+ permit.send(value);
+ Ok(())
}
- /// Returns `Poll::Ready(Ok(()))` when the channel is able to accept another item.
+ /// Blocking send to call outside of asynchronous contexts.
///
- /// If the channel is full, then `Poll::Pending` is returned and the task is notified when a
- /// slot becomes available.
+ /// # Panics
///
- /// Once `poll_ready` returns `Poll::Ready(Ok(()))`, a call to `try_send` will succeed unless
- /// the channel has since been closed. To provide this guarantee, the channel reserves one slot
- /// in the channel for the coming send. This reserved slot is not available to other `Sender`
- /// instances, so you need to be careful to not end up with deadlocks by blocking after calling
- /// `poll_ready` but before sending an element.
+ /// This function panics if called within an asynchronous execution
+ /// context.
///
- /// If, after `poll_ready` succeeds, you decide you do not wish to send an item after all, you
- /// can use [`disarm`](Sender::disarm) to release the reserved slot.
+ /// # Examples
///
- /// Until an item is sent or [`disarm`](Sender::disarm) is called, repeated calls to
- /// `poll_ready` will return either `Poll::Ready(Ok(()))` or `Poll::Ready(Err(_))` if channel
- /// is closed.
- pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), ClosedError>> {
- self.chan.poll_ready(cx).map_err(|_| ClosedError::new())
+ /// ```
+ /// use std::thread;
+ /// use tokio::runtime::Runtime;
+ /// use tokio::sync::mpsc;
+ ///
+ /// fn main() {
+ /// let (tx, mut rx) = mpsc::channel::<u8>(1);
+ ///
+ /// let sync_code = thread::spawn(move || {
+ /// tx.blocking_send(10).unwrap();
+ /// });
+ ///
+ /// Runtime::new().unwrap().block_on(async move {
+ /// assert_eq!(Some(10), rx.recv().await);
+ /// });
+ /// sync_code.join().unwrap()
+ /// }
+ /// ```
+ #[cfg(feature = "sync")]
+ pub fn blocking_send(&self, value: T) -> Result<(), SendError<T>> {
+ crate::future::block_on(self.send(value))
}
- /// Undo a successful call to `poll_ready`.
+ /// Checks if the channel has been closed. This happens when the
+ /// [`Receiver`] is dropped, or when the [`Receiver::close`] method is
+ /// called.
///
- /// Once a call to `poll_ready` returns `Poll::Ready(Ok(()))`, it holds up one slot in the
- /// channel to make room for the coming send. `disarm` allows you to give up that slot if you
- /// decide you do not wish to send an item after all. After calling `disarm`, you must call
- /// `poll_ready` until it returns `Poll::Ready(Ok(()))` before attempting to send again.
+ /// [`Receiver`]: crate::sync::mpsc::Receiver
+ /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close
///
- /// Returns `false` if no slot is reserved for this sender (usually because `poll_ready` was
- /// not previously called, or did not succeed).
+ /// ```
+ /// let (tx, rx) = tokio::sync::mpsc::channel::<()>(42);
+ /// assert!(!tx.is_closed());
///
- /// # Motivation
+ /// let tx2 = tx.clone();
+ /// assert!(!tx2.is_closed());
///
- /// Since `poll_ready` takes up one of the finite number of slots in a bounded channel, callers
- /// need to send an item shortly after `poll_ready` succeeds. If they do not, idle senders may
- /// take up all the slots of the channel, and prevent active senders from getting any requests
- /// through. Consider this code that forwards from one channel to another:
+ /// drop(rx);
+ /// assert!(tx.is_closed());
+ /// assert!(tx2.is_closed());
+ /// ```
+ pub fn is_closed(&self) -> bool {
+ self.chan.is_closed()
+ }
+
+ /// Wait for channel capacity. Once capacity to send one message is
+ /// available, it is reserved for the caller.
+ ///
+ /// If the channel is full, the function waits for the number of unreceived
+ /// messages to become less than the channel capacity. Capacity to send one
+ /// message is reserved for the caller. A [`Permit`] is returned to track
+ /// the reserved capacity. The [`send`] function on [`Permit`] consumes the
+ /// reserved capacity.
+ ///
+ /// Dropping [`Permit`] without sending a message releases the capacity back
+ /// to the channel.
+ ///
+ /// [`Permit`]: Permit
+ /// [`send`]: Permit::send
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::mpsc;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, mut rx) = mpsc::channel(1);
+ ///
+ /// // Reserve capacity
+ /// let permit = tx.reserve().await.unwrap();
///
- /// ```rust,ignore
- /// loop {
- /// ready!(tx.poll_ready(cx))?;
- /// if let Some(item) = ready!(rx.poll_recv(cx)) {
- /// tx.try_send(item)?;
- /// } else {
- /// break;
- /// }
+ /// // Trying to send directly on the `tx` will fail due to no
+ /// // available capacity.
+ /// assert!(tx.try_send(123).is_err());
+ ///
+ /// // Sending on the permit succeeds
+ /// permit.send(456);
+ ///
+ /// // The value sent on the permit is received
+ /// assert_eq!(rx.recv().await.unwrap(), 456);
/// }
/// ```
+ pub async fn reserve(&self) -> Result<Permit<'_, T>, SendError<()>> {
+ match self.chan.semaphore().0.acquire(1).await {
+ Ok(_) => {}
+ Err(_) => return Err(SendError(())),
+ }
+
+ Ok(Permit { chan: &self.chan })
+ }
+}
+
+impl<T> Clone for Sender<T> {
+ fn clone(&self) -> Self {
+ Sender {
+ chan: self.chan.clone(),
+ }
+ }
+}
+
+impl<T> fmt::Debug for Sender<T> {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt.debug_struct("Sender")
+ .field("chan", &self.chan)
+ .finish()
+ }
+}
+
+// ===== impl Permit =====
+
+impl<T> Permit<'_, T> {
+ /// Sends a value using the reserved capacity.
+ ///
+ /// Capacity for the message has already been reserved. The message is sent
+ /// to the receiver and the permit is consumed. The operation will succeed
+ /// even if the receiver half has been closed. See [`Receiver::close`] for
+ /// more details on performing a clean shutdown.
+ ///
+ /// [`Receiver::close`]: Receiver::close
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::mpsc;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, mut rx) = mpsc::channel(1);
+ ///
+ /// // Reserve capacity
+ /// let permit = tx.reserve().await.unwrap();
///
- /// If many such forwarders exist, and they all forward into a single (cloned) `Sender`, then
- /// any number of forwarders may be waiting for `rx.poll_recv` at the same time. While they do,
- /// they are effectively each reducing the channel's capacity by 1. If enough of these
- /// forwarders are idle, forwarders whose `rx` _do_ have elements will be unable to find a spot
- /// for them through `poll_ready`, and the system will deadlock.
- ///
- /// `disarm` solves this problem by allowing you to give up the reserved slot if you find that
- /// you have to block. We can then fix the code above by writing:
- ///
- /// ```rust,ignore
- /// loop {
- /// ready!(tx.poll_ready(cx))?;
- /// let item = rx.poll_recv(cx);
- /// if let Poll::Ready(Ok(_)) = item {
- /// // we're going to send the item below, so don't disarm
- /// } else {
- /// // give up our send slot, we won't need it for a while
- /// tx.disarm();
- /// }
- /// if let Some(item) = ready!(item) {
- /// tx.try_send(item)?;
- /// } else {
- /// break;
- /// }
+ /// // Trying to send directly on the `tx` will fail due to no
+ /// // available capacity.
+ /// assert!(tx.try_send(123).is_err());
+ ///
+ /// // Send a message on the permit
+ /// permit.send(456);
+ ///
+ /// // The value sent on the permit is received
+ /// assert_eq!(rx.recv().await.unwrap(), 456);
/// }
/// ```
- pub fn disarm(&mut self) -> bool {
- if self.chan.is_ready() {
- self.chan.disarm();
- true
- } else {
- false
+ pub fn send(self, value: T) {
+ use std::mem;
+
+ self.chan.send(value);
+
+ // Avoid the drop logic
+ mem::forget(self);
+ }
+}
+
+impl<T> Drop for Permit<'_, T> {
+ fn drop(&mut self) {
+ use chan::Semaphore;
+
+ let semaphore = self.chan.semaphore();
+
+ // Add the permit back to the semaphore
+ semaphore.add_permit();
+
+ if semaphore.is_closed() && semaphore.is_idle() {
+ self.chan.wake_rx();
}
}
}
+
+impl<T> fmt::Debug for Permit<'_, T> {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt.debug_struct("Permit")
+ .field("chan", &self.chan)
+ .finish()
+ }
+}
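
The `reserve`/`Permit` API introduced above replaces `poll_ready`/`disarm`: capacity is acquired up front and released either by sending on the permit or by dropping it. A short sketch of the capacity-first pattern, assuming the API as added in this diff:

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<String>(1);

    // Reserve capacity before doing the work of producing the message.
    // Dropping the permit without sending would hand the slot back to
    // the channel, which is what `disarm` used to do.
    let permit = tx.reserve().await.expect("channel closed");
    let msg = String::from("built after capacity was guaranteed");
    permit.send(msg);

    assert_eq!(rx.recv().await.unwrap(), "built after capacity was guaranteed");
}
```

Compared to `poll_ready`, an unused reservation cannot be leaked by mistake: dropping the `Permit` returns the slot automatically via its `Drop` impl.
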
diff --git a/src/sync/mpsc/chan.rs b/src/sync/mpsc/chan.rs
index 148ee3a..c78fb50 100644
--- a/src/sync/mpsc/chan.rs
+++ b/src/sync/mpsc/chan.rs
@@ -2,8 +2,9 @@ use crate::loom::cell::UnsafeCell;
use crate::loom::future::AtomicWaker;
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::Arc;
-use crate::sync::mpsc::error::{ClosedError, TryRecvError};
-use crate::sync::mpsc::{error, list};
+use crate::sync::mpsc::error::TryRecvError;
+use crate::sync::mpsc::list;
+use crate::sync::notify::Notify;
use std::fmt;
use std::process;
@@ -12,21 +13,13 @@ use std::task::Poll::{Pending, Ready};
use std::task::{Context, Poll};
/// Channel sender
-pub(crate) struct Tx<T, S: Semaphore> {
+pub(crate) struct Tx<T, S> {
inner: Arc<Chan<T, S>>,
- permit: S::Permit,
}
-impl<T, S: Semaphore> fmt::Debug for Tx<T, S>
-where
- S::Permit: fmt::Debug,
- S: fmt::Debug,
-{
+impl<T, S: fmt::Debug> fmt::Debug for Tx<T, S> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
- fmt.debug_struct("Tx")
- .field("inner", &self.inner)
- .field("permit", &self.permit)
- .finish()
+ fmt.debug_struct("Tx").field("inner", &self.inner).finish()
}
}
@@ -35,70 +28,26 @@ pub(crate) struct Rx<T, S: Semaphore> {
inner: Arc<Chan<T, S>>,
}
-impl<T, S: Semaphore> fmt::Debug for Rx<T, S>
-where
- S: fmt::Debug,
-{
+impl<T, S: Semaphore + fmt::Debug> fmt::Debug for Rx<T, S> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Rx").field("inner", &self.inner).finish()
}
}
-#[derive(Debug, Eq, PartialEq)]
-pub(crate) enum TrySendError {
- Closed,
- Full,
-}
-
-impl<T> From<(T, TrySendError)> for error::SendError<T> {
- fn from(src: (T, TrySendError)) -> error::SendError<T> {
- match src.1 {
- TrySendError::Closed => error::SendError(src.0),
- TrySendError::Full => unreachable!(),
- }
- }
-}
-
-impl<T> From<(T, TrySendError)> for error::TrySendError<T> {
- fn from(src: (T, TrySendError)) -> error::TrySendError<T> {
- match src.1 {
- TrySendError::Closed => error::TrySendError::Closed(src.0),
- TrySendError::Full => error::TrySendError::Full(src.0),
- }
- }
-}
-
pub(crate) trait Semaphore {
- type Permit;
-
- fn new_permit() -> Self::Permit;
-
- /// The permit is dropped without a value being sent. In this case, the
- /// permit must be returned to the semaphore.
- fn drop_permit(&self, permit: &mut Self::Permit);
-
fn is_idle(&self) -> bool;
fn add_permit(&self);
- fn poll_acquire(
- &self,
- cx: &mut Context<'_>,
- permit: &mut Self::Permit,
- ) -> Poll<Result<(), ClosedError>>;
-
- fn try_acquire(&self, permit: &mut Self::Permit) -> Result<(), TrySendError>;
-
- /// A value was sent into the channel and the permit held by `tx` is
- /// dropped. In this case, the permit should not immeditely be returned to
- /// the semaphore. Instead, the permit is returnred to the semaphore once
- /// the sent value is read by the rx handle.
- fn forget(&self, permit: &mut Self::Permit);
-
fn close(&self);
+
+ fn is_closed(&self) -> bool;
}
struct Chan<T, S> {
+ /// Notifies all tasks listening for the receiver being dropped
+ notify_rx_closed: Notify,
+
/// Handle to the push half of the lock-free list.
tx: list::Tx<T>,
@@ -153,13 +102,11 @@ impl<T> fmt::Debug for RxFields<T> {
unsafe impl<T: Send, S: Send> Send for Chan<T, S> {}
unsafe impl<T: Send, S: Sync> Sync for Chan<T, S> {}
-pub(crate) fn channel<T, S>(semaphore: S) -> (Tx<T, S>, Rx<T, S>)
-where
- S: Semaphore,
-{
+pub(crate) fn channel<T, S: Semaphore>(semaphore: S) -> (Tx<T, S>, Rx<T, S>) {
let (tx, rx) = list::channel();
let chan = Arc::new(Chan {
+ notify_rx_closed: Notify::new(),
tx,
semaphore,
rx_waker: AtomicWaker::new(),
@@ -175,48 +122,60 @@ where
// ===== impl Tx =====
-impl<T, S> Tx<T, S>
-where
- S: Semaphore,
-{
+impl<T, S> Tx<T, S> {
fn new(chan: Arc<Chan<T, S>>) -> Tx<T, S> {
- Tx {
- inner: chan,
- permit: S::new_permit(),
- }
- }
-
- pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), ClosedError>> {
- self.inner.semaphore.poll_acquire(cx, &mut self.permit)
+ Tx { inner: chan }
}
- pub(crate) fn disarm(&mut self) {
- // TODO: should this error if not acquired?
- self.inner.semaphore.drop_permit(&mut self.permit)
+ pub(super) fn semaphore(&self) -> &S {
+ &self.inner.semaphore
}
/// Send a message and notify the receiver.
- pub(crate) fn try_send(&mut self, value: T) -> Result<(), (T, TrySendError)> {
- self.inner.try_send(value, &mut self.permit)
+ pub(crate) fn send(&self, value: T) {
+ self.inner.send(value);
}
-}
-impl<T> Tx<T, (crate::sync::semaphore_ll::Semaphore, usize)> {
- pub(crate) fn is_ready(&self) -> bool {
- self.permit.is_acquired()
+ /// Wake the receive half
+ pub(crate) fn wake_rx(&self) {
+ self.inner.rx_waker.wake();
}
}
-impl<T> Tx<T, AtomicUsize> {
- pub(crate) fn send_unbounded(&self, value: T) -> Result<(), (T, TrySendError)> {
- self.inner.try_send(value, &mut ())
+impl<T, S: Semaphore> Tx<T, S> {
+ pub(crate) fn is_closed(&self) -> bool {
+ self.inner.semaphore.is_closed()
+ }
+
+ pub(crate) async fn closed(&self) {
+ use std::future::Future;
+ use std::pin::Pin;
+ use std::task::Poll;
+
+        // In order to avoid a race condition, we first request a notification,
+        // **then** check whether the channel is already closed. If it is, the
+        // notification request is unnecessary. Requesting the notification
+        // requires polling the future once.
+ let notified = self.inner.notify_rx_closed.notified();
+ pin!(notified);
+
+        // Polling the future once is guaranteed to return `Pending` as this
+        // `Notify` instance is only ever notified using `notify_waiters`.
+ crate::future::poll_fn(|cx| {
+ let res = Pin::new(&mut notified).poll(cx);
+ assert!(!res.is_ready());
+ Poll::Ready(())
+ })
+ .await;
+
+ if self.inner.semaphore.is_closed() {
+ return;
+ }
+ notified.await;
}
}
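
The `closed()` future above avoids a lost-wakeup race by registering interest before checking the closed flag; it must poll `Notified` once by hand because `notify_waiters` only wakes already-registered waiters. The same two-step idea sketched against the public API: this stand-in uses `notify_one`, whose stored permit makes the explicit first poll unnecessary, and the `AtomicBool` is an illustrative substitute for the semaphore's closed bit:

```rust
use std::sync::atomic::{AtomicBool, Ordering::SeqCst};
use tokio::sync::Notify;

/// Waits until `closed` becomes true, where the closer sets the flag
/// and then calls `notify.notify_one()`.
async fn wait_closed(closed: &AtomicBool, notify: &Notify) {
    loop {
        // 1. Create the notification future *before* checking the flag, so
        //    a notification sent between the check and the await is not
        //    lost (`notify_one` stores a permit even with no waiter).
        let notified = notify.notified();

        // 2. Check the flag; if already set there is nothing to wait for.
        if closed.load(SeqCst) {
            return;
        }

        // 3. Otherwise wait, and re-check after waking.
        notified.await;
    }
}

#[tokio::main]
async fn main() {
    let closed = AtomicBool::new(false);
    let notify = Notify::new();

    // Close before waiting: the flag check returns immediately.
    closed.store(true, SeqCst);
    notify.notify_one();
    wait_closed(&closed, &notify).await;
}
```
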
-impl<T, S> Clone for Tx<T, S>
-where
- S: Semaphore,
-{
+impl<T, S> Clone for Tx<T, S> {
fn clone(&self) -> Tx<T, S> {
// Using a Relaxed ordering here is sufficient as the caller holds a
// strong ref to `self`, preventing a concurrent decrement to zero.
@@ -224,18 +183,12 @@ where
Tx {
inner: self.inner.clone(),
- permit: S::new_permit(),
}
}
}
-impl<T, S> Drop for Tx<T, S>
-where
- S: Semaphore,
-{
+impl<T, S> Drop for Tx<T, S> {
fn drop(&mut self) {
- self.inner.semaphore.drop_permit(&mut self.permit);
-
if self.inner.tx_count.fetch_sub(1, AcqRel) != 1 {
return;
}
@@ -244,16 +197,13 @@ where
self.inner.tx.close();
// Notify the receiver
- self.inner.rx_waker.wake();
+ self.wake_rx();
}
}
// ===== impl Rx =====
-impl<T, S> Rx<T, S>
-where
- S: Semaphore,
-{
+impl<T, S: Semaphore> Rx<T, S> {
fn new(chan: Arc<Chan<T, S>>) -> Rx<T, S> {
Rx { inner: chan }
}
@@ -270,6 +220,7 @@ where
});
self.inner.semaphore.close();
+ self.inner.notify_rx_closed.notify_waiters();
}
/// Receive the next value
@@ -341,10 +292,7 @@ where
}
}
-impl<T, S> Drop for Rx<T, S>
-where
- S: Semaphore,
-{
+impl<T, S: Semaphore> Drop for Rx<T, S> {
fn drop(&mut self) {
use super::block::Read::Value;
@@ -362,25 +310,13 @@ where
// ===== impl Chan =====
-impl<T, S> Chan<T, S>
-where
- S: Semaphore,
-{
- fn try_send(&self, value: T, permit: &mut S::Permit) -> Result<(), (T, TrySendError)> {
- if let Err(e) = self.semaphore.try_acquire(permit) {
- return Err((value, e));
- }
-
+impl<T, S> Chan<T, S> {
+ fn send(&self, value: T) {
// Push the value
self.tx.push(value);
// Notify the rx task
self.rx_waker.wake();
-
- // Release the permit
- self.semaphore.forget(permit);
-
- Ok(())
}
}
@@ -399,72 +335,24 @@ impl<T, S> Drop for Chan<T, S> {
}
}
-use crate::sync::semaphore_ll::TryAcquireError;
-
-impl From<TryAcquireError> for TrySendError {
- fn from(src: TryAcquireError) -> TrySendError {
- if src.is_closed() {
- TrySendError::Closed
- } else if src.is_no_permits() {
- TrySendError::Full
- } else {
- unreachable!();
- }
- }
-}
-
// ===== impl Semaphore for (::Semaphore, capacity) =====
-use crate::sync::semaphore_ll::Permit;
-
-impl Semaphore for (crate::sync::semaphore_ll::Semaphore, usize) {
- type Permit = Permit;
-
- fn new_permit() -> Permit {
- Permit::new()
- }
-
- fn drop_permit(&self, permit: &mut Permit) {
- permit.release(1, &self.0);
- }
-
+impl Semaphore for (crate::sync::batch_semaphore::Semaphore, usize) {
fn add_permit(&self) {
- self.0.add_permits(1)
+ self.0.release(1)
}
fn is_idle(&self) -> bool {
self.0.available_permits() == self.1
}
- fn poll_acquire(
- &self,
- cx: &mut Context<'_>,
- permit: &mut Permit,
- ) -> Poll<Result<(), ClosedError>> {
- // Keep track of task budget
- let coop = ready!(crate::coop::poll_proceed(cx));
-
- permit
- .poll_acquire(cx, 1, &self.0)
- .map_err(|_| ClosedError::new())
- .map(move |r| {
- coop.made_progress();
- r
- })
- }
-
- fn try_acquire(&self, permit: &mut Permit) -> Result<(), TrySendError> {
- permit.try_acquire(1, &self.0)?;
- Ok(())
- }
-
- fn forget(&self, permit: &mut Self::Permit) {
- permit.forget(1);
- }
-
fn close(&self) {
self.0.close();
}
+
+ fn is_closed(&self) -> bool {
+ self.0.is_closed()
+ }
}
// ===== impl Semaphore for AtomicUsize =====
@@ -473,12 +361,6 @@ use std::sync::atomic::Ordering::{Acquire, Release};
use std::usize;
impl Semaphore for AtomicUsize {
- type Permit = ();
-
- fn new_permit() {}
-
- fn drop_permit(&self, _permit: &mut ()) {}
-
fn add_permit(&self) {
let prev = self.fetch_sub(2, Release);
@@ -492,40 +374,11 @@ impl Semaphore for AtomicUsize {
self.load(Acquire) >> 1 == 0
}
- fn poll_acquire(
- &self,
- _cx: &mut Context<'_>,
- permit: &mut (),
- ) -> Poll<Result<(), ClosedError>> {
- Ready(self.try_acquire(permit).map_err(|_| ClosedError::new()))
- }
-
- fn try_acquire(&self, _permit: &mut ()) -> Result<(), TrySendError> {
- let mut curr = self.load(Acquire);
-
- loop {
- if curr & 1 == 1 {
- return Err(TrySendError::Closed);
- }
-
- if curr == usize::MAX ^ 1 {
- // Overflowed the ref count. There is no safe way to recover, so
- // abort the process. In practice, this should never happen.
- process::abort()
- }
-
- match self.compare_exchange(curr, curr + 2, AcqRel, Acquire) {
- Ok(_) => return Ok(()),
- Err(actual) => {
- curr = actual;
- }
- }
- }
- }
-
- fn forget(&self, _permit: &mut ()) {}
-
fn close(&self) {
self.fetch_or(1, Release);
}
+
+ fn is_closed(&self) -> bool {
+ self.load(Acquire) & 1 == 1
+ }
}
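
In the new `Semaphore for AtomicUsize` impl above, bit 0 is a closed flag and each queued message counts as +2 in the remaining bits. A self-contained sketch of that encoding, independent of tokio's internals:

```rust
use std::sync::atomic::{AtomicUsize, Ordering::{AcqRel, Acquire, Release}};

/// Message counter with a "closed" flag packed into the least
/// significant bit; each in-flight message is counted as +2.
struct Count(AtomicUsize);

impl Count {
    fn try_inc(&self) -> bool {
        let mut curr = self.0.load(Acquire);
        loop {
            // Bit 0 set means the channel was closed.
            if curr & 1 == 1 {
                return false;
            }
            // Adding 2 here would overflow into the flag bit.
            if curr == usize::MAX ^ 1 {
                std::process::abort();
            }
            match self.0.compare_exchange(curr, curr + 2, AcqRel, Acquire) {
                Ok(_) => return true,
                Err(actual) => curr = actual,
            }
        }
    }

    fn close(&self) {
        self.0.fetch_or(1, Release);
    }

    fn is_closed(&self) -> bool {
        self.0.load(Acquire) & 1 == 1
    }
}

fn main() {
    let c = Count(AtomicUsize::new(0));
    assert!(c.try_inc());
    c.close();
    assert!(!c.try_inc());
    assert!(c.is_closed());
}
```
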
diff --git a/src/sync/mpsc/error.rs b/src/sync/mpsc/error.rs
index 72c42aa..7705452 100644
--- a/src/sync/mpsc/error.rs
+++ b/src/sync/mpsc/error.rs
@@ -94,26 +94,6 @@ impl fmt::Display for TryRecvError {
impl Error for TryRecvError {}
-// ===== ClosedError =====
-
-/// Error returned by [`Sender::poll_ready`](super::Sender::poll_ready).
-#[derive(Debug)]
-pub struct ClosedError(());
-
-impl ClosedError {
- pub(crate) fn new() -> ClosedError {
- ClosedError(())
- }
-}
-
-impl fmt::Display for ClosedError {
- fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(fmt, "channel closed")
- }
-}
-
-impl Error for ClosedError {}
-
cfg_time! {
// ===== SendTimeoutError =====
diff --git a/src/sync/mpsc/list.rs b/src/sync/mpsc/list.rs
index 53f82a2..2f4c532 100644
--- a/src/sync/mpsc/list.rs
+++ b/src/sync/mpsc/list.rs
@@ -1,9 +1,7 @@
//! A concurrent, lock-free, FIFO list.
-use crate::loom::{
- sync::atomic::{AtomicPtr, AtomicUsize},
- thread,
-};
+use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize};
+use crate::loom::thread;
use crate::sync::mpsc::block::{self, Block};
use std::fmt;
diff --git a/src/sync/mpsc/mod.rs b/src/sync/mpsc/mod.rs
index c489c9f..a2bcf83 100644
--- a/src/sync/mpsc/mod.rs
+++ b/src/sync/mpsc/mod.rs
@@ -1,23 +1,29 @@
#![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))]
-//! A multi-producer, single-consumer queue for sending values across
+//! A multi-producer, single-consumer queue for sending values between
//! asynchronous tasks.
//!
-//! Similar to `std`, channel creation provides [`Receiver`] and [`Sender`]
-//! handles. [`Receiver`] implements `Stream` and allows a task to read values
-//! out of the channel. If there is no message to read, the current task will be
-//! notified when a new value is sent. If the channel is at capacity, the send
-//! is rejected and the task will be notified when additional capacity is
-//! available. In other words, the channel provides backpressure.
-//!
//! This module provides two variants of the channel: bounded and unbounded. The
//! bounded variant has a limit on the number of messages that the channel can
//! store, and if this limit is reached, trying to send another message will
//! wait until a message is received from the channel. An unbounded channel has
-//! an infinite capacity, so the `send` method never does any kind of sleeping.
+//! an infinite capacity, so the `send` method will always complete immediately.
//! This makes the [`UnboundedSender`] usable from both synchronous and
//! asynchronous code.
//!
+//! Similar to the `mpsc` channels provided by `std`, the channel constructor
+//! functions provide separate send and receive handles, [`Sender`] and
+//! [`Receiver`] for the bounded channel, [`UnboundedSender`] and
+//! [`UnboundedReceiver`] for the unbounded channel. Both [`Receiver`] and
+//! [`UnboundedReceiver`] implement [`Stream`] and allow a task to read
+//! values out of the channel. If there is no message to read, the current task
+//! will be notified when a new value is sent. [`Sender`] and
+//! [`UnboundedSender`] allow sending values into the channel. If the bounded
+//! channel is at capacity, the send is rejected and the task will be notified
+//! when additional capacity is available. In other words, the channel provides
+//! backpressure.
+//!
+//!
//!
//! When all [`Sender`] handles have been dropped, it is no longer
@@ -43,11 +49,10 @@
//! are two situations to consider:
//!
//! **Bounded channel**: If you need a bounded channel, you should use a bounded
-//! Tokio `mpsc` channel for both directions of communication. To call the async
-//! [`send`][bounded-send] or [`recv`][bounded-recv] methods in sync code, you
-//! will need to use [`Handle::block_on`], which allow you to execute an async
-//! method in synchronous code. This is necessary because a bounded channel may
-//! need to wait for additional capacity to become available.
+//! Tokio `mpsc` channel for both directions of communication. Instead of calling
+//! the async [`send`][bounded-send] or [`recv`][bounded-recv] methods, in
+//! synchronous code you will need to use the [`blocking_send`][blocking-send] or
+//! [`blocking_recv`][blocking-recv] methods.
//!
//! **Unbounded channel**: You should use the kind of channel that matches where
//! the receiver is. So for sending a message _from async to sync_, you should
@@ -57,9 +62,13 @@
//!
//! [`Sender`]: crate::sync::mpsc::Sender
//! [`Receiver`]: crate::sync::mpsc::Receiver
+//! [`Stream`]: crate::stream::Stream
//! [bounded-send]: crate::sync::mpsc::Sender::send()
//! [bounded-recv]: crate::sync::mpsc::Receiver::recv()
+//! [blocking-send]: crate::sync::mpsc::Sender::blocking_send()
+//! [blocking-recv]: crate::sync::mpsc::Receiver::blocking_recv()
//! [`UnboundedSender`]: crate::sync::mpsc::UnboundedSender
+//! [`UnboundedReceiver`]: crate::sync::mpsc::UnboundedReceiver
//! [`Handle::block_on`]: crate::runtime::Handle::block_on()
//! [std-unbounded]: std::sync::mpsc::channel
//! [crossbeam-unbounded]: https://docs.rs/crossbeam/*/crossbeam/channel/fn.unbounded.html
@@ -67,7 +76,7 @@
pub(super) mod block;
mod bounded;
-pub use self::bounded::{channel, Receiver, Sender};
+pub use self::bounded::{channel, Permit, Receiver, Sender};
mod chan;
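
The rewritten module docs above steer synchronous code toward `blocking_send`/`blocking_recv` rather than `Handle::block_on`. A compact sketch of the sync-to-async direction, assuming the `sync` feature is enabled:

```rust
use std::thread;
use tokio::sync::mpsc;

fn main() {
    let (tx, mut rx) = mpsc::channel::<u32>(8);

    // Synchronous side: a plain thread, no runtime, using `blocking_send`.
    let producer = thread::spawn(move || {
        for i in 0..3 {
            tx.blocking_send(i).unwrap();
        }
        // Dropping `tx` here closes the channel.
    });

    // Asynchronous side receives as usual.
    tokio::runtime::Runtime::new().unwrap().block_on(async move {
        while let Some(i) = rx.recv().await {
            println!("got {}", i);
        }
    });

    producer.join().unwrap();
}
```
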
diff --git a/src/sync/mpsc/unbounded.rs b/src/sync/mpsc/unbounded.rs
index 1b2288a..fe882d5 100644
--- a/src/sync/mpsc/unbounded.rs
+++ b/src/sync/mpsc/unbounded.rs
@@ -47,7 +47,7 @@ impl<T> fmt::Debug for UnboundedReceiver<T> {
}
/// Creates an unbounded mpsc channel for communicating between asynchronous
-/// tasks.
+/// tasks without backpressure.
///
/// A `send` on this channel will always succeed as long as the receive half has
/// not been closed. If the receiver falls behind, messages will be arbitrarily
@@ -73,8 +73,7 @@ impl<T> UnboundedReceiver<T> {
UnboundedReceiver { chan }
}
- #[doc(hidden)] // TODO: doc
- pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
+ fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
self.chan.recv(cx)
}
@@ -174,7 +173,97 @@ impl<T> UnboundedSender<T> {
/// [`close`]: UnboundedReceiver::close
/// [`UnboundedReceiver`]: UnboundedReceiver
pub fn send(&self, message: T) -> Result<(), SendError<T>> {
- self.chan.send_unbounded(message)?;
+ if !self.inc_num_messages() {
+ return Err(SendError(message));
+ }
+
+ self.chan.send(message);
Ok(())
}
+
+ fn inc_num_messages(&self) -> bool {
+ use std::process;
+ use std::sync::atomic::Ordering::{AcqRel, Acquire};
+
+ let mut curr = self.chan.semaphore().load(Acquire);
+
+ loop {
+ if curr & 1 == 1 {
+ return false;
+ }
+
+ if curr == usize::MAX ^ 1 {
+ // Overflowed the ref count. There is no safe way to recover, so
+ // abort the process. In practice, this should never happen.
+ process::abort()
+ }
+
+ match self
+ .chan
+ .semaphore()
+ .compare_exchange(curr, curr + 2, AcqRel, Acquire)
+ {
+ Ok(_) => return true,
+ Err(actual) => {
+ curr = actual;
+ }
+ }
+ }
+ }
+
+    /// Completes when the receiver has been dropped.
+    ///
+    /// This allows the producers to get notified when interest in the produced
+    /// values is canceled, and to stop doing work immediately.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::mpsc;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx1, rx) = mpsc::unbounded_channel::<()>();
+ /// let tx2 = tx1.clone();
+ /// let tx3 = tx1.clone();
+ /// let tx4 = tx1.clone();
+ /// let tx5 = tx1.clone();
+ /// tokio::spawn(async move {
+ /// drop(rx);
+ /// });
+ ///
+ /// futures::join!(
+ /// tx1.closed(),
+ /// tx2.closed(),
+ /// tx3.closed(),
+ /// tx4.closed(),
+ /// tx5.closed()
+ /// );
+    ///     println!("Receiver dropped");
+ /// }
+ /// ```
+ pub async fn closed(&self) {
+ self.chan.closed().await
+ }
+ /// Checks if the channel has been closed. This happens when the
+ /// [`UnboundedReceiver`] is dropped, or when the
+ /// [`UnboundedReceiver::close`] method is called.
+ ///
+ /// [`UnboundedReceiver`]: crate::sync::mpsc::UnboundedReceiver
+ /// [`UnboundedReceiver::close`]: crate::sync::mpsc::UnboundedReceiver::close
+ ///
+ /// ```
+ /// let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<()>();
+ /// assert!(!tx.is_closed());
+ ///
+ /// let tx2 = tx.clone();
+ /// assert!(!tx2.is_closed());
+ ///
+ /// drop(rx);
+ /// assert!(tx.is_closed());
+ /// assert!(tx2.is_closed());
+ /// ```
+ pub fn is_closed(&self) -> bool {
+ self.chan.is_closed()
+ }
}
diff --git a/src/sync/mutex.rs b/src/sync/mutex.rs
index 642058b..21e44ca 100644
--- a/src/sync/mutex.rs
+++ b/src/sync/mutex.rs
@@ -1,3 +1,5 @@
+#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))]
+
use crate::sync::batch_semaphore as semaphore;
use std::cell::UnsafeCell;
@@ -115,7 +117,6 @@ use std::sync::Arc;
/// [`std::sync::Mutex`]: struct@std::sync::Mutex
/// [`Send`]: trait@std::marker::Send
/// [`lock`]: method@Mutex::lock
-#[derive(Debug)]
pub struct Mutex<T: ?Sized> {
s: semaphore::Semaphore,
c: UnsafeCell<T>,
@@ -220,6 +221,27 @@ impl<T: ?Sized> Mutex<T> {
}
}
+ /// Creates a new lock in an unlocked state ready for use.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::Mutex;
+ ///
+ /// static LOCK: Mutex<i32> = Mutex::const_new(5);
+ /// ```
+ #[cfg(all(feature = "parking_lot", not(all(loom, test)),))]
+ #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))]
+ pub const fn const_new(t: T) -> Self
+ where
+ T: Sized,
+ {
+ Self {
+ c: UnsafeCell::new(t),
+ s: semaphore::Semaphore::const_new(1),
+ }
+ }
+
/// Locks this mutex, causing the current task
/// to yield until the lock has been acquired.
/// When the lock has been acquired, function returns a [`MutexGuard`].
@@ -305,6 +327,30 @@ impl<T: ?Sized> Mutex<T> {
}
}
+ /// Returns a mutable reference to the underlying data.
+ ///
+ /// Since this call borrows the `Mutex` mutably, no actual locking needs to
+ /// take place -- the mutable borrow statically guarantees no locks exist.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::Mutex;
+ ///
+ /// fn main() {
+ /// let mut mutex = Mutex::new(1);
+ ///
+ /// let n = mutex.get_mut();
+ /// *n = 2;
+ /// }
+ /// ```
+ pub fn get_mut(&mut self) -> &mut T {
+ unsafe {
+ // Safety: This is https://github.com/rust-lang/rust/pull/76936
+ &mut *self.c.get()
+ }
+ }
+
/// Attempts to acquire the lock, and returns [`TryLockError`] if the lock
/// is currently held somewhere else.
///
@@ -373,6 +419,20 @@ where
}
}
+impl<T> std::fmt::Debug for Mutex<T>
+where
+ T: std::fmt::Debug,
+{
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ let mut d = f.debug_struct("Mutex");
+ match self.try_lock() {
+ Ok(inner) => d.field("data", &*inner),
+ Err(_) => d.field("data", &format_args!("<locked>")),
+ };
+ d.finish()
+ }
+}
+
// === impl MutexGuard ===
impl<T: ?Sized> Drop for MutexGuard<'_, T> {
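
The derived `Debug` for `Mutex` was replaced above by a manual impl that uses `try_lock`, so a held lock prints as `<locked>` instead of blocking or exposing the data. A quick illustration of the two outputs:

```rust
use tokio::sync::Mutex;

#[tokio::main]
async fn main() {
    let m = Mutex::new(1);
    println!("{:?}", m); // Mutex { data: 1 }

    let guard = m.lock().await;
    println!("{:?}", m); // Mutex { data: <locked> }
    drop(guard);
}
```
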
diff --git a/src/sync/notify.rs b/src/sync/notify.rs
index 5cb41e8..922f109 100644
--- a/src/sync/notify.rs
+++ b/src/sync/notify.rs
@@ -1,3 +1,10 @@
+// Allow `unreachable_pub` warnings when sync is not enabled
+// due to the usage of `Notify` within the `rt` feature set.
+// When this module is compiled with `sync` enabled we will warn on
+// this lint. When `rt` is enabled we use `pub(crate)` which
+// triggers this warning but it is safe to ignore in this case.
+#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))]
+
use crate::loom::sync::atomic::AtomicU8;
use crate::loom::sync::Mutex;
use crate::util::linked_list::{self, LinkedList};
@@ -10,6 +17,8 @@ use std::ptr::NonNull;
use std::sync::atomic::Ordering::SeqCst;
use std::task::{Context, Poll, Waker};
+type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>;
+
/// Notify a single task to wake up.
///
/// `Notify` provides a basic mechanism to notify a single task of an event.
@@ -17,20 +26,20 @@ use std::task::{Context, Poll, Waker};
/// another task to perform an operation.
///
/// `Notify` can be thought of as a [`Semaphore`] starting with 0 permits.
-/// [`notified().await`] waits for a permit to become available, and [`notify()`]
+/// [`notified().await`] waits for a permit to become available, and [`notify_one()`]
/// sets a permit **if there currently are no available permits**.
///
/// The synchronization details of `Notify` are similar to
/// [`thread::park`][park] and [`Thread::unpark`][unpark] from std. A [`Notify`]
/// value contains a single permit. [`notified().await`] waits for the permit to
-/// be made available, consumes the permit, and resumes. [`notify()`] sets the
+/// be made available, consumes the permit, and resumes. [`notify_one()`] sets the
/// permit, waking a pending task if there is one.
///
-/// If `notify()` is called **before** `notfied().await`, then the next call to
+/// If `notify_one()` is called **before** `notified().await`, then the next call to
/// `notified().await` will complete immediately, consuming the permit. Any
/// subsequent calls to `notified().await` will wait for a new permit.
///
-/// If `notify()` is called **multiple** times before `notified().await`, only a
+/// If `notify_one()` is called **multiple** times before `notified().await`, only a
/// **single** permit is stored. The next call to `notified().await` will
/// complete immediately, but the one after will wait for a new permit.
///
@@ -53,7 +62,7 @@ use std::task::{Context, Poll, Waker};
/// });
///
/// println!("sending notification");
-/// notify.notify();
+/// notify.notify_one();
/// }
/// ```
///
@@ -76,7 +85,7 @@ use std::task::{Context, Poll, Waker};
/// .push_back(value);
///
/// // Notify the consumer a value is available
-/// self.notify.notify();
+/// self.notify.notify_one();
/// }
///
/// pub async fn recv(&self) -> T {
@@ -96,12 +105,20 @@ use std::task::{Context, Poll, Waker};
/// [park]: std::thread::park
/// [unpark]: std::thread::Thread::unpark
/// [`notified().await`]: Notify::notified()
-/// [`notify()`]: Notify::notify()
+/// [`notify_one()`]: Notify::notify_one()
/// [`Semaphore`]: crate::sync::Semaphore
#[derive(Debug)]
pub struct Notify {
state: AtomicU8,
- waiters: Mutex<LinkedList<Waiter>>,
+ waiters: Mutex<WaitList>,
+}
+
+#[derive(Debug, Clone, Copy)]
+enum NotificationType {
+ // Notification triggered by calling `notify_waiters`
+ AllWaiters,
+ // Notification triggered by calling `notify_one`
+ OneWaiter,
}
#[derive(Debug)]
@@ -113,7 +130,7 @@ struct Waiter {
waker: Option<Waker>,
/// `true` if the notification has been assigned to this waiter.
- notified: bool,
+ notified: Option<NotificationType>,
/// Should not be `Unpin`.
_p: PhantomPinned,
@@ -121,7 +138,7 @@ struct Waiter {
/// Future returned from `notified()`
#[derive(Debug)]
-struct Notified<'a> {
+pub struct Notified<'a> {
/// The `Notify` being received on.
notify: &'a Notify,
@@ -168,14 +185,38 @@ impl Notify {
}
}
+ /// Create a new `Notify`, initialized without a permit.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::Notify;
+ ///
+ /// static NOTIFY: Notify = Notify::const_new();
+ /// ```
+ #[cfg(all(feature = "parking_lot", not(all(loom, test))))]
+ #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))]
+ pub const fn const_new() -> Notify {
+ Notify {
+ state: AtomicU8::new(0),
+ waiters: Mutex::const_new(LinkedList::new()),
+ }
+ }
+
/// Wait for a notification.
///
+ /// Equivalent to:
+ ///
+ /// ```ignore
+ /// async fn notified(&self);
+ /// ```
+ ///
/// Each `Notify` value holds a single permit. If a permit is available from
- /// an earlier call to [`notify()`], then `notified().await` will complete
+ /// an earlier call to [`notify_one()`], then `notified().await` will complete
/// immediately, consuming that permit. Otherwise, `notified().await` waits
- /// for a permit to be made available by the next call to `notify()`.
+ /// for a permit to be made available by the next call to `notify_one()`.
///
- /// [`notify()`]: Notify::notify
+ /// [`notify_one()`]: Notify::notify_one
///
/// # Examples
///
@@ -194,21 +235,20 @@ impl Notify {
/// });
///
/// println!("sending notification");
- /// notify.notify();
+ /// notify.notify_one();
/// }
/// ```
- pub async fn notified(&self) {
+ pub fn notified(&self) -> Notified<'_> {
Notified {
notify: self,
state: State::Init,
waiter: UnsafeCell::new(Waiter {
pointers: linked_list::Pointers::new(),
waker: None,
- notified: false,
+ notified: None,
_p: PhantomPinned,
}),
}
- .await
}
/// Notifies a waiting task
@@ -216,10 +256,10 @@ impl Notify {
/// If a task is currently waiting, that task is notified. Otherwise, a
/// permit is stored in this `Notify` value and the **next** call to
/// [`notified().await`] will complete immediately consuming the permit made
- /// available by this call to `notify()`.
+ /// available by this call to `notify_one()`.
///
/// At most one permit may be stored by `Notify`. Many sequential calls to
- /// `notify` will result in a single permit being stored. The next call to
+ /// `notify_one` will result in a single permit being stored. The next call to
/// `notified().await` will complete immediately, but the one after that
/// will wait.
///
@@ -242,10 +282,10 @@ impl Notify {
/// });
///
/// println!("sending notification");
- /// notify.notify();
+ /// notify.notify_one();
/// }
/// ```
- pub fn notify(&self) {
+ pub fn notify_one(&self) {
// Load the current state
let mut curr = self.state.load(SeqCst);
@@ -266,7 +306,7 @@ impl Notify {
}
// There are waiters, the lock must be acquired to notify.
- let mut waiters = self.waiters.lock().unwrap();
+ let mut waiters = self.waiters.lock();
// The state must be reloaded while the lock is held. The state may only
// transition out of WAITING while the lock is held.
@@ -277,6 +317,45 @@ impl Notify {
waker.wake();
}
}
+
+ /// Notifies all waiting tasks
+ pub(crate) fn notify_waiters(&self) {
+ // There are waiters, the lock must be acquired to notify.
+ let mut waiters = self.waiters.lock();
+
+ // The state must be reloaded while the lock is held. The state may only
+ // transition out of WAITING while the lock is held.
+ let curr = self.state.load(SeqCst);
+
+ if let EMPTY | NOTIFIED = curr {
+ // There are no waiting tasks. In this case, no synchronization is
+ // established between `notify` and `notified().await`.
+ return;
+ }
+
+ // At this point, it is guaranteed that the state will not
+ // concurrently change, as holding the lock is required to
+ // transition **out** of `WAITING`.
+ //
+ // Get pending waiters
+ while let Some(mut waiter) = waiters.pop_back() {
+ // Safety: `waiters` lock is still held.
+ let waiter = unsafe { waiter.as_mut() };
+
+ assert!(waiter.notified.is_none());
+
+ waiter.notified = Some(NotificationType::AllWaiters);
+
+ if let Some(waker) = waiter.waker.take() {
+ waker.wake();
+ }
+ }
+
+ // All waiters have been notified, the state must be transitioned to
+ // `EMPTY`. As transitioning **from** `WAITING` requires the lock to be
+ // held, a `store` is sufficient.
+ self.state.store(EMPTY, SeqCst);
+ }
}
impl Default for Notify {
@@ -285,7 +364,7 @@ impl Default for Notify {
}
}
-fn notify_locked(waiters: &mut LinkedList<Waiter>, state: &AtomicU8, curr: u8) -> Option<Waker> {
+fn notify_locked(waiters: &mut WaitList, state: &AtomicU8, curr: u8) -> Option<Waker> {
loop {
match curr {
EMPTY | NOTIFIED => {
@@ -311,9 +390,9 @@ fn notify_locked(waiters: &mut LinkedList<Waiter>, state: &AtomicU8, curr: u8) -
// Safety: `waiters` lock is still held.
let waiter = unsafe { waiter.as_mut() };
- assert!(!waiter.notified);
+ assert!(waiter.notified.is_none());
- waiter.notified = true;
+ waiter.notified = Some(NotificationType::OneWaiter);
let waker = waiter.waker.take();
if waiters.is_empty() {
@@ -373,7 +452,7 @@ impl Future for Notified<'_> {
// Acquire the lock and attempt to transition to the waiting
// state.
- let mut waiters = notify.waiters.lock().unwrap();
+ let mut waiters = notify.waiters.lock();
// Reload the state with the lock held
let mut curr = notify.state.load(SeqCst);
@@ -428,6 +507,8 @@ impl Future for Notified<'_> {
waiters.push_front(unsafe { NonNull::new_unchecked(waiter.get()) });
*state = Waiting;
+
+ return Poll::Pending;
}
Waiting => {
// Currently in the "Waiting" state, implying the caller has
@@ -435,16 +516,16 @@ impl Future for Notified<'_> {
// `notify.waiters`). In order to access the waker fields,
// we must hold the lock.
- let waiters = notify.waiters.lock().unwrap();
+ let waiters = notify.waiters.lock();
// Safety: called while locked
let w = unsafe { &mut *waiter.get() };
- if w.notified {
+ if w.notified.is_some() {
// Our waker has been notified. Reset the fields and
// remove it from the list.
w.waker = None;
- w.notified = false;
+ w.notified = None;
*state = Done;
} else {
@@ -483,12 +564,12 @@ impl Drop for Notified<'_> {
// longer stored in the linked list.
if let Waiting = *state {
let mut notify_state = WAITING;
- let mut waiters = notify.waiters.lock().unwrap();
+ let mut waiters = notify.waiters.lock();
// `Notify.state` may be in any of the three states (Empty, Waiting,
// Notified). It doesn't actually matter what the atomic is set to
// at this point. We hold the lock and will ensure the atomic is in
- // the correct state once th elock is dropped.
+ // the correct state once the lock is dropped.
//
// Because the atomic state is not checked, at first glance, it may
// seem like this routine does not handle the case where the
@@ -516,14 +597,13 @@ impl Drop for Notified<'_> {
notify.state.store(EMPTY, SeqCst);
}
- // See if the node was notified but not received. In this case, the
- // notification must be sent to another waiter.
+ // See if the node was notified but not received. In this case, if
+ // the notification was triggered via `notify_one`, it must be sent
+ // to the next waiter.
//
// Safety: with the entry removed from the linked list, there can be
// no concurrent access to the entry
- let notified = unsafe { (*waiter.get()).notified };
-
- if notified {
+ if let Some(NotificationType::OneWaiter) = unsafe { (*waiter.get()).notified } {
if let Some(waker) = notify_locked(&mut waiters, &notify.state, notify_state) {
drop(waiters);
waker.wake();
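
The rename from `notify` to `notify_one` above makes room for `notify_waiters`, which wakes all currently registered waiters without storing a permit (it is still `pub(crate)` in this diff). A sketch of the single-permit semantics of `notify_one` described in the updated docs:

```rust
use std::sync::Arc;
use tokio::sync::Notify;

#[tokio::main]
async fn main() {
    let notify = Arc::new(Notify::new());

    // Two `notify_one` calls before anyone waits still store only a
    // single permit...
    notify.notify_one();
    notify.notify_one();

    // ...so this first wait completes immediately by consuming it...
    notify.notified().await;

    // ...and a second waiter must wait for a fresh `notify_one`.
    let n2 = notify.clone();
    let waiter = tokio::spawn(async move { n2.notified().await });
    notify.notify_one();
    waiter.await.unwrap();
}
```
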
diff --git a/src/sync/oneshot.rs b/src/sync/oneshot.rs
index 17767e7..951ab71 100644
--- a/src/sync/oneshot.rs
+++ b/src/sync/oneshot.rs
@@ -124,7 +124,6 @@ struct State(usize);
/// }
/// ```
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
- #[allow(deprecated)]
let inner = Arc::new(Inner {
state: AtomicUsize::new(State::new().as_usize()),
value: UnsafeCell::new(None),
@@ -197,8 +196,7 @@ impl<T> Sender<T> {
Ok(())
}
- #[doc(hidden)] // TODO: remove
- pub fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> {
+ fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> {
// Keep track of task budget
let coop = ready!(crate::coop::poll_proceed(cx));
diff --git a/src/sync/rwlock.rs b/src/sync/rwlock.rs
index 3d2a2f7..a84c4c1 100644
--- a/src/sync/rwlock.rs
+++ b/src/sync/rwlock.rs
@@ -1,5 +1,8 @@
-use crate::sync::batch_semaphore::{AcquireError, Semaphore};
+use crate::sync::batch_semaphore::Semaphore;
use std::cell::UnsafeCell;
+use std::fmt;
+use std::marker;
+use std::mem;
use std::ops;
#[cfg(not(loom))]
@@ -8,7 +11,7 @@ const MAX_READS: usize = 32;
#[cfg(loom)]
const MAX_READS: usize = 10;
-/// An asynchronous reader-writer lock
+/// An asynchronous reader-writer lock.
///
/// This type of lock allows a number of readers or at most one writer at any
/// point in time. The write portion of this lock typically allows modification
@@ -83,10 +86,140 @@ pub struct RwLock<T: ?Sized> {
/// [`RwLock`].
///
/// [`read`]: method@RwLock::read
-#[derive(Debug)]
+/// [`RwLock`]: struct@RwLock
pub struct RwLockReadGuard<'a, T: ?Sized> {
- permit: ReleasingPermit<'a, T>,
- lock: &'a RwLock<T>,
+ s: &'a Semaphore,
+ data: *const T,
+ marker: marker::PhantomData<&'a T>,
+}
+
+impl<'a, T> RwLockReadGuard<'a, T> {
+ /// Make a new `RwLockReadGuard` for a component of the locked data.
+ ///
+ /// This operation cannot fail as the `RwLockReadGuard` passed in already
+ /// locked the data.
+ ///
+ /// This is an associated function that needs to be
+ /// used as `RwLockReadGuard::map(...)`. A method would interfere with
+ /// methods of the same name on the contents of the locked data.
+ ///
+ /// This is an asynchronous version of [`RwLockReadGuard::map`] from the
+ /// [`parking_lot` crate].
+ ///
+ /// [`RwLockReadGuard::map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockReadGuard.html#method.map
+ /// [`parking_lot` crate]: https://crates.io/crates/parking_lot
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::{RwLock, RwLockReadGuard};
+ ///
+ /// #[derive(Debug, Clone, Copy, PartialEq, Eq)]
+ /// struct Foo(u32);
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let lock = RwLock::new(Foo(1));
+ ///
+ /// let guard = lock.read().await;
+ /// let guard = RwLockReadGuard::map(guard, |f| &f.0);
+ ///
+ /// assert_eq!(1, *guard);
+ /// # }
+ /// ```
+ #[inline]
+ pub fn map<F, U: ?Sized>(this: Self, f: F) -> RwLockReadGuard<'a, U>
+ where
+ F: FnOnce(&T) -> &U,
+ {
+ let data = f(&*this) as *const U;
+ let s = this.s;
+ // NB: Forget to avoid drop impl from being called.
+ mem::forget(this);
+ RwLockReadGuard {
+ s,
+ data,
+ marker: marker::PhantomData,
+ }
+ }
+
+ /// Attempts to make a new [`RwLockReadGuard`] for a component of the
+ /// locked data. The original guard is returned if the closure returns
+ /// `None`.
+ ///
+ /// This operation cannot fail as the `RwLockReadGuard` passed in already
+ /// locked the data.
+ ///
+ /// This is an associated function that needs to be used as
+ /// `RwLockReadGuard::try_map(..)`. A method would interfere with methods of the
+ /// same name on the contents of the locked data.
+ ///
+ /// This is an asynchronous version of [`RwLockReadGuard::try_map`] from the
+ /// [`parking_lot` crate].
+ ///
+ /// [`RwLockReadGuard::try_map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockReadGuard.html#method.try_map
+ /// [`parking_lot` crate]: https://crates.io/crates/parking_lot
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::{RwLock, RwLockReadGuard};
+ ///
+ /// #[derive(Debug, Clone, Copy, PartialEq, Eq)]
+ /// struct Foo(u32);
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let lock = RwLock::new(Foo(1));
+ ///
+ /// let guard = lock.read().await;
+ /// let guard = RwLockReadGuard::try_map(guard, |f| Some(&f.0)).expect("should not fail");
+ ///
+ /// assert_eq!(1, *guard);
+ /// # }
+ /// ```
+ #[inline]
+ pub fn try_map<F, U: ?Sized>(this: Self, f: F) -> Result<RwLockReadGuard<'a, U>, Self>
+ where
+ F: FnOnce(&T) -> Option<&U>,
+ {
+ let data = match f(&*this) {
+ Some(data) => data as *const U,
+ None => return Err(this),
+ };
+ let s = this.s;
+ // NB: Forget to avoid drop impl from being called.
+ mem::forget(this);
+ Ok(RwLockReadGuard {
+ s,
+ data,
+ marker: marker::PhantomData,
+ })
+ }
+}
+
+impl<'a, T: ?Sized> fmt::Debug for RwLockReadGuard<'a, T>
+where
+ T: fmt::Debug,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt::Debug::fmt(&**self, f)
+ }
+}
+
+impl<'a, T: ?Sized> fmt::Display for RwLockReadGuard<'a, T>
+where
+ T: fmt::Display,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt::Display::fmt(&**self, f)
+ }
+}
+
+impl<'a, T: ?Sized> Drop for RwLockReadGuard<'a, T> {
+ fn drop(&mut self) {
+ self.s.release(1);
+ }
}
/// RAII structure used to release the exclusive write access of a lock when
@@ -97,32 +230,195 @@ pub struct RwLockReadGuard<'a, T: ?Sized> {
///
/// [`write`]: method@RwLock::write
/// [`RwLock`]: struct@RwLock
-#[derive(Debug)]
pub struct RwLockWriteGuard<'a, T: ?Sized> {
- permit: ReleasingPermit<'a, T>,
- lock: &'a RwLock<T>,
+ s: &'a Semaphore,
+ data: *mut T,
+ marker: marker::PhantomData<&'a mut T>,
}
-// Wrapper around Permit that releases on Drop
-#[derive(Debug)]
-struct ReleasingPermit<'a, T: ?Sized> {
- num_permits: u16,
- lock: &'a RwLock<T>,
+impl<'a, T: ?Sized> RwLockWriteGuard<'a, T> {
+ /// Make a new `RwLockWriteGuard` for a component of the locked data.
+ ///
+ /// This operation cannot fail as the `RwLockWriteGuard` passed in already
+ /// locked the data.
+ ///
+ /// This is an associated function that needs to be used as
+ /// `RwLockWriteGuard::map(..)`. A method would interfere with methods of
+ /// the same name on the contents of the locked data.
+ ///
+ /// This is an asynchronous version of [`RwLockWriteGuard::map`] from the
+ /// [`parking_lot` crate].
+ ///
+ /// [`RwLockWriteGuard::map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.map
+ /// [`parking_lot` crate]: https://crates.io/crates/parking_lot
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::{RwLock, RwLockWriteGuard};
+ ///
+ /// #[derive(Debug, Clone, Copy, PartialEq, Eq)]
+ /// struct Foo(u32);
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let lock = RwLock::new(Foo(1));
+ ///
+ /// {
+ /// let mut mapped = RwLockWriteGuard::map(lock.write().await, |f| &mut f.0);
+ /// *mapped = 2;
+ /// }
+ ///
+ /// assert_eq!(Foo(2), *lock.read().await);
+ /// # }
+ /// ```
+ #[inline]
+ pub fn map<F, U: ?Sized>(mut this: Self, f: F) -> RwLockWriteGuard<'a, U>
+ where
+ F: FnOnce(&mut T) -> &mut U,
+ {
+ let data = f(&mut *this) as *mut U;
+ let s = this.s;
+ // NB: Forget to avoid drop impl from being called.
+ mem::forget(this);
+ RwLockWriteGuard {
+ s,
+ data,
+ marker: marker::PhantomData,
+ }
+ }
+
+ /// Attempts to make a new [`RwLockWriteGuard`] for a component of
+ /// the locked data. The original guard is returned if the closure returns
+ /// `None`.
+ ///
+ /// This operation cannot fail as the `RwLockWriteGuard` passed in already
+ /// locked the data.
+ ///
+ /// This is an associated function that needs to be
+ /// used as `RwLockWriteGuard::try_map(...)`. A method would interfere with
+ /// methods of the same name on the contents of the locked data.
+ ///
+ /// This is an asynchronous version of [`RwLockWriteGuard::try_map`] from
+ /// the [`parking_lot` crate].
+ ///
+ /// [`RwLockWriteGuard::try_map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.try_map
+ /// [`parking_lot` crate]: https://crates.io/crates/parking_lot
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::{RwLock, RwLockWriteGuard};
+ ///
+ /// #[derive(Debug, Clone, Copy, PartialEq, Eq)]
+ /// struct Foo(u32);
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let lock = RwLock::new(Foo(1));
+ ///
+ /// {
+ /// let guard = lock.write().await;
+ /// let mut guard = RwLockWriteGuard::try_map(guard, |f| Some(&mut f.0)).expect("should not fail");
+ /// *guard = 2;
+ /// }
+ ///
+ /// assert_eq!(Foo(2), *lock.read().await);
+ /// # }
+ /// ```
+ #[inline]
+ pub fn try_map<F, U: ?Sized>(mut this: Self, f: F) -> Result<RwLockWriteGuard<'a, U>, Self>
+ where
+ F: FnOnce(&mut T) -> Option<&mut U>,
+ {
+ let data = match f(&mut *this) {
+ Some(data) => data as *mut U,
+ None => return Err(this),
+ };
+ let s = this.s;
+ // NB: Forget to avoid drop impl from being called.
+ mem::forget(this);
+ Ok(RwLockWriteGuard {
+ s,
+ data,
+ marker: marker::PhantomData,
+ })
+ }
+
+ /// Atomically downgrades a write lock into a read lock without allowing
+ /// any writers to take exclusive access of the lock in the meantime.
+ ///
+ /// **Note:** This won't *necessarily* allow any additional readers to acquire
+ /// locks, since [`RwLock`] is fair and it is possible that a writer is next
+ /// in line.
+ ///
+ /// Returns an RAII guard which will drop the read access of this rwlock
+ /// when dropped.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::sync::RwLock;
+ /// # use std::sync::Arc;
+ /// #
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let lock = Arc::new(RwLock::new(1));
+ ///
+ /// let n = lock.write().await;
+ ///
+ /// let cloned_lock = lock.clone();
+ /// let handle = tokio::spawn(async move {
+ /// *cloned_lock.write().await = 2;
+ /// });
+ ///
+ /// let n = n.downgrade();
+ /// assert_eq!(*n, 1, "downgrade is atomic");
+ ///
+ /// assert_eq!(*lock.read().await, 1, "additional readers can obtain locks");
+ ///
+ /// drop(n);
+ /// handle.await.unwrap();
+ /// assert_eq!(*lock.read().await, 2, "second writer obtained write lock");
+ /// # }
+ /// ```
+ ///
+ /// [`RwLock`]: struct@RwLock
+ pub fn downgrade(self) -> RwLockReadGuard<'a, T> {
+ let RwLockWriteGuard { s, data, .. } = self;
+
+ // Release all but one of the permits held by the write guard
+ s.release(MAX_READS - 1);
+
+ RwLockReadGuard {
+ s,
+ data,
+ marker: marker::PhantomData,
+ }
+ }
}
-impl<'a, T: ?Sized> ReleasingPermit<'a, T> {
- async fn acquire(
- lock: &'a RwLock<T>,
- num_permits: u16,
- ) -> Result<ReleasingPermit<'a, T>, AcquireError> {
- lock.s.acquire(num_permits.into()).await?;
- Ok(Self { num_permits, lock })
+impl<'a, T: ?Sized> fmt::Debug for RwLockWriteGuard<'a, T>
+where
+ T: fmt::Debug,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt::Debug::fmt(&**self, f)
}
}
-impl<T: ?Sized> Drop for ReleasingPermit<'_, T> {
+impl<'a, T: ?Sized> fmt::Display for RwLockWriteGuard<'a, T>
+where
+ T: fmt::Display,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt::Display::fmt(&**self, f)
+ }
+}
+
+impl<'a, T: ?Sized> Drop for RwLockWriteGuard<'a, T> {
fn drop(&mut self) {
- self.lock.s.release(self.num_permits as usize);
+ self.s.release(MAX_READS);
}
}
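
The `mem::forget` calls in the `map`/`try_map` implementations above are what keep the permit from being released twice: the original guard and the mapped guard share a single semaphore permit, so exactly one of them may run its `Drop`. A minimal model of that hand-off (the `Guard` type here is illustrative, not part of the crate):

```rust
use std::cell::Cell;
use std::mem;

struct Guard<'a> {
    releases: &'a Cell<usize>,
}

impl Drop for Guard<'_> {
    fn drop(&mut self) {
        // Stands in for `self.s.release(..)` in the real guards.
        self.releases.set(self.releases.get() + 1);
    }
}

fn map<'a>(this: Guard<'a>) -> Guard<'a> {
    let releases = this.releases;
    // Without this forget, `this` would release the permit right here and
    // the returned guard would release it a second time on its own drop.
    mem::forget(this);
    Guard { releases }
}

fn main() {
    let releases = Cell::new(0);
    let mapped = map(Guard { releases: &releases });
    drop(mapped);
    assert_eq!(releases.get(), 1); // released exactly once
}
```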
@@ -139,9 +435,11 @@ fn bounds() {
check_sync::<RwLock<u32>>();
check_unpin::<RwLock<u32>>();
+ check_send::<RwLockReadGuard<'_, u32>>();
check_sync::<RwLockReadGuard<'_, u32>>();
check_unpin::<RwLockReadGuard<'_, u32>>();
+ check_send::<RwLockWriteGuard<'_, u32>>();
check_sync::<RwLockWriteGuard<'_, u32>>();
check_unpin::<RwLockWriteGuard<'_, u32>>();
@@ -155,8 +453,17 @@ fn bounds() {
// RwLock<T>.
unsafe impl<T> Send for RwLock<T> where T: ?Sized + Send {}
unsafe impl<T> Sync for RwLock<T> where T: ?Sized + Send + Sync {}
+// NB: These impls need to be explicit since we're storing a raw pointer.
+// Safety: Stores a raw pointer to `T`, so if `T` is `Sync`, the lock guard over
+// `T` is `Send`.
+unsafe impl<T> Send for RwLockReadGuard<'_, T> where T: ?Sized + Sync {}
unsafe impl<T> Sync for RwLockReadGuard<'_, T> where T: ?Sized + Send + Sync {}
unsafe impl<T> Sync for RwLockWriteGuard<'_, T> where T: ?Sized + Send + Sync {}
+// Safety: Stores a raw pointer to `T`, so if `T` is `Sync`, the lock guard over
+// `T` is `Send` - but since this also provides mutable access, we need to
+// make sure that `T` is `Send` since its value can be sent across thread
+// boundaries.
+unsafe impl<T> Send for RwLockWriteGuard<'_, T> where T: ?Sized + Send + Sync {}
impl<T: ?Sized> RwLock<T> {
/// Creates a new instance of an `RwLock<T>` which is unlocked.
@@ -178,6 +485,27 @@ impl<T: ?Sized> RwLock<T> {
}
}
+ /// Creates a new instance of an `RwLock<T>` which is unlocked.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::RwLock;
+ ///
+ /// static LOCK: RwLock<i32> = RwLock::const_new(5);
+ /// ```
+ #[cfg(all(feature = "parking_lot", not(all(loom, test))))]
+ #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))]
+ pub const fn const_new(value: T) -> RwLock<T>
+ where
+ T: Sized,
+ {
+ RwLock {
+ c: UnsafeCell::new(value),
+ s: Semaphore::const_new(MAX_READS),
+ }
+ }
+
/// Locks this rwlock with shared read access, causing the current task
/// to yield until the lock has been acquired.
///
@@ -210,12 +538,16 @@ impl<T: ?Sized> RwLock<T> {
///}
/// ```
pub async fn read(&self) -> RwLockReadGuard<'_, T> {
- let permit = ReleasingPermit::acquire(self, 1).await.unwrap_or_else(|_| {
+ self.s.acquire(1).await.unwrap_or_else(|_| {
// The semaphore was closed, but we never explicitly close it, and we have a
// handle to it through the Arc, which means that this can never happen.
unreachable!()
});
- RwLockReadGuard { lock: self, permit }
+ RwLockReadGuard {
+ s: &self.s,
+ data: self.c.get(),
+ marker: marker::PhantomData,
+ }
}
/// Locks this rwlock with exclusive write access, causing the current task
@@ -241,15 +573,40 @@ impl<T: ?Sized> RwLock<T> {
///}
/// ```
pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
- let permit = ReleasingPermit::acquire(self, MAX_READS as u16)
- .await
- .unwrap_or_else(|_| {
- // The semaphore was closed. but, we never explicitly close it, and we have a
- // handle to it through the Arc, which means that this can never happen.
- unreachable!()
- });
-
- RwLockWriteGuard { lock: self, permit }
+ self.s.acquire(MAX_READS as u32).await.unwrap_or_else(|_| {
+ // The semaphore was closed, but we never explicitly close it, and we have a
+ // handle to it through the Arc, which means that this can never happen.
+ unreachable!()
+ });
+ RwLockWriteGuard {
+ s: &self.s,
+ data: self.c.get(),
+ marker: marker::PhantomData,
+ }
+ }
+
+ /// Returns a mutable reference to the underlying data.
+ ///
+ /// Since this call borrows the `RwLock` mutably, no actual locking needs to
+ /// take place -- the mutable borrow statically guarantees no locks exist.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::RwLock;
+ ///
+ /// fn main() {
+ /// let mut lock = RwLock::new(1);
+ ///
+ /// let n = lock.get_mut();
+ /// *n = 2;
+ /// }
+ /// ```
+ pub fn get_mut(&mut self) -> &mut T {
+ unsafe {
+ // Safety: This is https://github.com/rust-lang/rust/pull/76936
+ &mut *self.c.get()
+ }
}
/// Consumes the lock, returning the underlying data.
@@ -265,7 +622,7 @@ impl<T: ?Sized> ops::Deref for RwLockReadGuard<'_, T> {
type Target = T;
fn deref(&self) -> &T {
- unsafe { &*self.lock.c.get() }
+ unsafe { &*self.data }
}
}
@@ -273,13 +630,13 @@ impl<T: ?Sized> ops::Deref for RwLockWriteGuard<'_, T> {
type Target = T;
fn deref(&self) -> &T {
- unsafe { &*self.lock.c.get() }
+ unsafe { &*self.data }
}
}
impl<T: ?Sized> ops::DerefMut for RwLockWriteGuard<'_, T> {
fn deref_mut(&mut self) -> &mut T {
- unsafe { &mut *self.lock.c.get() }
+ unsafe { &mut *self.data }
}
}
@@ -289,7 +646,7 @@ impl<T> From<T> for RwLock<T> {
}
}
-impl<T> Default for RwLock<T>
+impl<T: ?Sized> Default for RwLock<T>
where
T: Default,
{
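
The permit accounting threaded through `read`, `write`, `downgrade`, and the guard `Drop` impls above reduces to: a reader holds one permit, a writer holds all `MAX_READS`, and `downgrade` returns all but one. A minimal model using the public `Semaphore` (this assumes `try_acquire_many`, `forget`, and `add_permits` from recent tokio releases, with `MAX_READS = 32` as in the non-loom build):

```rust
use tokio::sync::Semaphore;

const MAX_READS: usize = 32;

fn main() {
    let s = Semaphore::new(MAX_READS);

    // A writer takes every permit, excluding readers and other writers.
    let write = s.try_acquire_many(MAX_READS as u32).unwrap();
    assert_eq!(s.available_permits(), 0);

    // `downgrade`: keep one permit as the read lock, return the rest.
    write.forget();
    s.add_permits(MAX_READS - 1);
    assert_eq!(s.available_permits(), MAX_READS - 1);

    // Dropping the read guard releases the final permit.
    s.add_permits(1);
    assert_eq!(s.available_permits(), MAX_READS);
}
```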
diff --git a/src/sync/semaphore.rs b/src/sync/semaphore.rs
index 2489d34..43dd976 100644
--- a/src/sync/semaphore.rs
+++ b/src/sync/semaphore.rs
@@ -1,7 +1,7 @@
use super::batch_semaphore as ll; // low level implementation
use std::sync::Arc;
-/// Counting semaphore performing asynchronous permit aquisition.
+/// Counting semaphore performing asynchronous permit acquisition.
///
/// A semaphore maintains a set of permits. Permits are used to synchronize
/// access to a shared resource. A semaphore differs from a mutex in that it
@@ -74,6 +74,15 @@ impl Semaphore {
}
}
+ /// Creates a new semaphore with the initial number of permits.
+ #[cfg(all(feature = "parking_lot", not(all(loom, test))))]
+ #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))]
+ pub const fn const_new(permits: usize) -> Self {
+ Self {
+ ll_sem: ll::Semaphore::const_new(permits),
+ }
+ }
+
/// Returns the current number of available permits.
pub fn available_permits(&self) -> usize {
self.ll_sem.available_permits()
@@ -114,7 +123,7 @@ impl Semaphore {
pub async fn acquire_owned(self: Arc<Self>) -> OwnedSemaphorePermit {
self.ll_sem.acquire(1).await.unwrap();
OwnedSemaphorePermit {
- sem: self.clone(),
+ sem: self,
permits: 1,
}
}
@@ -127,7 +136,7 @@ impl Semaphore {
pub fn try_acquire_owned(self: Arc<Self>) -> Result<OwnedSemaphorePermit, TryAcquireError> {
match self.ll_sem.try_acquire(1) {
Ok(_) => Ok(OwnedSemaphorePermit {
- sem: self.clone(),
+ sem: self,
permits: 1,
}),
Err(_) => Err(TryAcquireError(())),
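
Dropping the redundant `self.clone()` calls above works because `acquire_owned` and `try_acquire_owned` already consume the `Arc<Semaphore>` by value, so the permit can keep that very `Arc`. A usage sketch, assuming the signature shown in this diff (the permit is returned directly, without a `Result`):

```rust
use std::sync::Arc;
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    let sem = Arc::new(Semaphore::new(2));

    for i in 0..4 {
        // Each iteration clones the Arc once; `acquire_owned` then moves
        // that clone into the returned `OwnedSemaphorePermit`.
        let permit = sem.clone().acquire_owned().await;
        tokio::spawn(async move {
            println!("task {} holds a permit", i);
            // The permit is released back to the semaphore on drop.
            drop(permit);
        });
    }
}
```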
diff --git a/src/sync/semaphore_ll.rs b/src/sync/semaphore_ll.rs
deleted file mode 100644
index 25d25ac..0000000
--- a/src/sync/semaphore_ll.rs
+++ /dev/null
@@ -1,1221 +0,0 @@
-#![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))]
-
-//! Thread-safe, asynchronous counting semaphore.
-//!
-//! A `Semaphore` instance holds a set of permits. Permits are used to
-//! synchronize access to a shared resource.
-//!
-//! Before accessing the shared resource, callers acquire a permit from the
-//! semaphore. Once the permit is acquired, the caller then enters the critical
-//! section. If no permits are available, then acquiring the semaphore returns
-//! `Pending`. The task is woken once a permit becomes available.
-
-use crate::loom::cell::UnsafeCell;
-use crate::loom::future::AtomicWaker;
-use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize};
-use crate::loom::thread;
-
-use std::cmp;
-use std::fmt;
-use std::ptr::{self, NonNull};
-use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Relaxed, Release};
-use std::task::Poll::{Pending, Ready};
-use std::task::{Context, Poll};
-use std::usize;
-
-/// Futures-aware semaphore.
-pub(crate) struct Semaphore {
- /// Tracks both the waiter queue tail pointer and the number of remaining
- /// permits.
- state: AtomicUsize,
-
- /// waiter queue head pointer.
- head: UnsafeCell<NonNull<Waiter>>,
-
- /// Coordinates access to the queue head.
- rx_lock: AtomicUsize,
-
- /// Stub waiter node used as part of the MPSC channel algorithm.
- stub: Box<Waiter>,
-}
-
-/// A semaphore permit
-///
-/// Tracks the lifecycle of a semaphore permit.
-///
-/// An instance of `Permit` is intended to be used with a **single** instance of
-/// `Semaphore`. Using a single instance of `Permit` with multiple semaphore
-/// instances will result in unexpected behavior.
-///
-/// `Permit` does **not** release the permit back to the semaphore on drop. It
-/// is the user's responsibility to ensure that `Permit::release` is called
-/// before dropping the permit.
-#[derive(Debug)]
-pub(crate) struct Permit {
- waiter: Option<Box<Waiter>>,
- state: PermitState,
-}
-
-/// Error returned by `Permit::poll_acquire`.
-#[derive(Debug)]
-pub(crate) struct AcquireError(());
-
-/// Error returned by `Permit::try_acquire`.
-#[derive(Debug)]
-pub(crate) enum TryAcquireError {
- Closed,
- NoPermits,
-}
-
-/// Node used to notify the semaphore waiter when permit is available.
-#[derive(Debug)]
-struct Waiter {
- /// Stores waiter state.
- ///
- /// See `WaiterState` for more details.
- state: AtomicUsize,
-
- /// Task to wake when a permit is made available.
- waker: AtomicWaker,
-
- /// Next pointer in the queue of waiting senders.
- next: AtomicPtr<Waiter>,
-}
-
-/// Semaphore state
-///
-/// The 2 low bits track the modes.
-///
-/// - Closed
-/// - Full
-///
-/// When not full, the rest of the `usize` tracks the total number of messages
-/// in the channel. When full, the rest of the `usize` is a pointer to the tail
-/// of the "waiting senders" queue.
-#[derive(Copy, Clone)]
-struct SemState(usize);
-
-/// Permit state
-#[derive(Debug, Copy, Clone)]
-enum PermitState {
- /// Currently waiting for permits to be made available and assigned to the
- /// waiter.
- Waiting(u16),
-
- /// The number of acquired permits
- Acquired(u16),
-}
-
-/// State for an individual waker node
-#[derive(Debug, Copy, Clone)]
-struct WaiterState(usize);
-
-/// Waiter node is in the semaphore queue
-const QUEUED: usize = 0b001;
-
-/// Semaphore has been closed, no more permits will be issued.
-const CLOSED: usize = 0b10;
-
-/// The permit that owns the `Waiter` dropped.
-const DROPPED: usize = 0b100;
-
-/// Represents "one requested permit" in the waiter state
-const PERMIT_ONE: usize = 0b1000;
-
-/// Masks the waiter state to only contain bits tracking number of requested
-/// permits.
-const PERMIT_MASK: usize = usize::MAX - (PERMIT_ONE - 1);
-
-/// How much to shift a permit count to pack it into the waker state
-const PERMIT_SHIFT: u32 = PERMIT_ONE.trailing_zeros();
-
-/// Flag differentiating between available permits and waiter pointers.
-///
-/// If we assume pointers are properly aligned, then the least significant bit
-/// will always be zero. So, we use that bit to track if the value represents a
-/// number.
-const NUM_FLAG: usize = 0b01;
-
-/// Signal the semaphore is closed
-const CLOSED_FLAG: usize = 0b10;
-
-/// Maximum number of permits a semaphore can manage
-const MAX_PERMITS: usize = usize::MAX >> NUM_SHIFT;
-
-/// When representing "numbers", the state has to be shifted this much (to get
-/// rid of the flag bit).
-const NUM_SHIFT: usize = 2;
-
-// ===== impl Semaphore =====
-
-impl Semaphore {
- /// Creates a new semaphore with the initial number of permits
- ///
- /// # Panics
- ///
- /// Panics if `permits` is zero.
- pub(crate) fn new(permits: usize) -> Semaphore {
- let stub = Box::new(Waiter::new());
- let ptr = NonNull::from(&*stub);
-
- // Allocations are aligned
- debug_assert!(ptr.as_ptr() as usize & NUM_FLAG == 0);
-
- let state = SemState::new(permits, &stub);
-
- Semaphore {
- state: AtomicUsize::new(state.to_usize()),
- head: UnsafeCell::new(ptr),
- rx_lock: AtomicUsize::new(0),
- stub,
- }
- }
-
- /// Returns the current number of available permits
- pub(crate) fn available_permits(&self) -> usize {
- let curr = SemState(self.state.load(Acquire));
- curr.available_permits()
- }
-
- /// Tries to acquire the requested number of permits, registering the waiter
- /// if not enough permits are available.
- fn poll_acquire(
- &self,
- cx: &mut Context<'_>,
- num_permits: u16,
- permit: &mut Permit,
- ) -> Poll<Result<(), AcquireError>> {
- self.poll_acquire2(num_permits, || {
- let waiter = permit.waiter.get_or_insert_with(|| Box::new(Waiter::new()));
-
- waiter.waker.register_by_ref(cx.waker());
-
- Some(NonNull::from(&**waiter))
- })
- }
-
- fn try_acquire(&self, num_permits: u16) -> Result<(), TryAcquireError> {
- match self.poll_acquire2(num_permits, || None) {
- Poll::Ready(res) => res.map_err(to_try_acquire),
- Poll::Pending => Err(TryAcquireError::NoPermits),
- }
- }
-
- /// Polls for a permit
- ///
- /// Tries to acquire available permits first. If unable to acquire a
- /// sufficient number of permits, the caller's waiter is pushed onto the
- /// semaphore's wait queue.
- fn poll_acquire2<F>(
- &self,
- num_permits: u16,
- mut get_waiter: F,
- ) -> Poll<Result<(), AcquireError>>
- where
- F: FnMut() -> Option<NonNull<Waiter>>,
- {
- let num_permits = num_permits as usize;
-
- // Load the current state
- let mut curr = SemState(self.state.load(Acquire));
-
- // Saves a ref to the waiter node
- let mut maybe_waiter: Option<NonNull<Waiter>> = None;
-
- /// Used in branches where we attempt to push the waiter into the wait
- /// queue but fail due to permits becoming available or the wait queue
- /// transitioning to "closed". In this case, the waiter must be
- /// transitioned back to the "idle" state.
- macro_rules! revert_to_idle {
- () => {
- if let Some(waiter) = maybe_waiter {
- unsafe { waiter.as_ref() }.revert_to_idle();
- }
- };
- }
-
- loop {
- let mut next = curr;
-
- if curr.is_closed() {
- revert_to_idle!();
- return Ready(Err(AcquireError::closed()));
- }
-
- let acquired = next.acquire_permits(num_permits, &self.stub);
-
- if !acquired {
- // There are not enough available permits to satisfy the
- // request. The permit transitions to a waiting state.
- debug_assert!(curr.waiter().is_some() || curr.available_permits() < num_permits);
-
- if let Some(waiter) = maybe_waiter.as_ref() {
- // Safety: the caller owns the waiter.
- let w = unsafe { waiter.as_ref() };
- w.set_permits_to_acquire(num_permits - curr.available_permits());
- } else {
- // Get the waiter for the permit.
- if let Some(waiter) = get_waiter() {
- // Safety: the caller owns the waiter.
- let w = unsafe { waiter.as_ref() };
-
- // If there are any currently available permits, the
- // waiter acquires those immediately and waits for the
- // remaining permits to become available.
- if !w.to_queued(num_permits - curr.available_permits()) {
-                    // The node is already queued, there is no further work
- // to do.
- return Pending;
- }
-
- maybe_waiter = Some(waiter);
- } else {
- // No waiter, this indicates the caller does not wish to
- // "wait", so there is nothing left to do.
- return Pending;
- }
- }
-
- next.set_waiter(maybe_waiter.unwrap());
- }
-
- debug_assert_ne!(curr.0, 0);
- debug_assert_ne!(next.0, 0);
-
- match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
- Ok(_) => {
- if acquired {
- // Successfully acquire permits **without** queuing the
- // waiter node. The waiter node is not currently in the
- // queue.
- revert_to_idle!();
- return Ready(Ok(()));
- } else {
- // The node is pushed into the queue, the final step is
- // to set the node's "next" pointer to return the wait
- // queue into a consistent state.
-
- let prev_waiter =
- curr.waiter().unwrap_or_else(|| NonNull::from(&*self.stub));
-
- let waiter = maybe_waiter.unwrap();
-
- // Link the nodes.
- //
- // Safety: the mpsc algorithm guarantees the old tail of
- // the queue is not removed from the queue during the
- // push process.
- unsafe {
- prev_waiter.as_ref().store_next(waiter);
- }
-
- return Pending;
- }
- }
- Err(actual) => {
- curr = SemState(actual);
- }
- }
- }
- }
-
- /// Closes the semaphore. This prevents the semaphore from issuing new
- /// permits and notifies all pending waiters.
- pub(crate) fn close(&self) {
- // Acquire the `rx_lock`, setting the "closed" flag on the lock.
- let prev = self.rx_lock.fetch_or(1, AcqRel);
-
- if prev != 0 {
- // Another thread has the lock and will be responsible for notifying
- // pending waiters.
- return;
- }
-
- self.add_permits_locked(0, true);
- }
- /// Adds `n` new permits to the semaphore.
- ///
- /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded.
- pub(crate) fn add_permits(&self, n: usize) {
- if n == 0 {
- return;
- }
-
- // TODO: Handle overflow. A panic is not sufficient, the process must
- // abort.
- let prev = self.rx_lock.fetch_add(n << 1, AcqRel);
-
- if prev != 0 {
- // Another thread has the lock and will be responsible for notifying
- // pending waiters.
- return;
- }
-
- self.add_permits_locked(n, false);
- }
-
- fn add_permits_locked(&self, mut rem: usize, mut closed: bool) {
- while rem > 0 || closed {
- if closed {
- SemState::fetch_set_closed(&self.state, AcqRel);
- }
-
- // Release the permits and notify
- self.add_permits_locked2(rem, closed);
-
- let n = rem << 1;
-
- let actual = if closed {
- let actual = self.rx_lock.fetch_sub(n | 1, AcqRel);
- closed = false;
- actual
- } else {
- let actual = self.rx_lock.fetch_sub(n, AcqRel);
- closed = actual & 1 == 1;
- actual
- };
-
- rem = (actual >> 1) - rem;
- }
- }
-
- /// Releases a specific amount of permits to the semaphore
- ///
- /// This function is called by `add_permits` after the add lock has been
- /// acquired.
- fn add_permits_locked2(&self, mut n: usize, closed: bool) {
- // If closing the semaphore, we want to drain the entire queue. The
- // number of permits being assigned doesn't matter.
- if closed {
- n = usize::MAX;
- }
-
- 'outer: while n > 0 {
- unsafe {
- let mut head = self.head.with(|head| *head);
- let mut next_ptr = head.as_ref().next.load(Acquire);
-
- let stub = self.stub();
-
- if head == stub {
- // The stub node indicates an empty queue. Any remaining
- // permits get assigned back to the semaphore.
- let next = match NonNull::new(next_ptr) {
- Some(next) => next,
- None => {
- // This loop is not part of the standard intrusive mpsc
- // channel algorithm. This is where we atomically pop
- // the last task and add `n` to the remaining capacity.
- //
- // This modification to the pop algorithm works because,
- // at this point, we have not done any work (only done
- // reading). We have a *pretty* good idea that there is
- // no concurrent pusher.
- //
- // The capacity is then atomically added by doing an
- // AcqRel CAS on `state`. The `state` cell is the
- // linchpin of the algorithm.
- //
- // By successfully CASing `head` w/ AcqRel, we ensure
- // that, if any thread was racing and entered a push, we
- // see that and abort pop, retrying as it is
- // "inconsistent".
- let mut curr = SemState::load(&self.state, Acquire);
-
- loop {
- if curr.has_waiter(&self.stub) {
- // A waiter is being added concurrently.
- // This is the MPSC queue's "inconsistent"
- // state and we must loop and try again.
- thread::yield_now();
- continue 'outer;
- }
-
- // If closing, nothing more to do.
- if closed {
- debug_assert!(curr.is_closed(), "state = {:?}", curr);
- return;
- }
-
- let mut next = curr;
- next.release_permits(n, &self.stub);
-
- match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
- Ok(_) => return,
- Err(actual) => {
- curr = SemState(actual);
- }
- }
- }
- }
- };
-
- self.head.with_mut(|head| *head = next);
- head = next;
- next_ptr = next.as_ref().next.load(Acquire);
- }
-
-                // `head` points to a waiter; assign permits to the waiter. If
- // all requested permits are satisfied, then we can continue,
- // otherwise the node stays in the wait queue.
- if !head.as_ref().assign_permits(&mut n, closed) {
- assert_eq!(n, 0);
- return;
- }
-
- if let Some(next) = NonNull::new(next_ptr) {
- self.head.with_mut(|head| *head = next);
-
- self.remove_queued(head, closed);
- continue 'outer;
- }
-
- let state = SemState::load(&self.state, Acquire);
-
- // This must always be a pointer as the wait list is not empty.
- let tail = state.waiter().unwrap();
-
- if tail != head {
- // Inconsistent
- thread::yield_now();
- continue 'outer;
- }
-
- self.push_stub(closed);
-
- next_ptr = head.as_ref().next.load(Acquire);
-
- if let Some(next) = NonNull::new(next_ptr) {
- self.head.with_mut(|head| *head = next);
-
- self.remove_queued(head, closed);
- continue 'outer;
- }
-
- // Inconsistent state, loop
- thread::yield_now();
- }
- }
- }
-
- /// The wait node has had all of its permits assigned and has been removed
- /// from the wait queue.
- ///
- /// Attempt to remove the QUEUED bit from the node. If additional permits
- /// are concurrently requested, the node must be pushed back into the wait
-    /// queue.
- fn remove_queued(&self, waiter: NonNull<Waiter>, closed: bool) {
- let mut curr = WaiterState(unsafe { waiter.as_ref() }.state.load(Acquire));
-
- loop {
- if curr.is_dropped() {
- // The Permit dropped, it is on us to release the memory
- let _ = unsafe { Box::from_raw(waiter.as_ptr()) };
- return;
- }
-
- // The node is removed from the queue. We attempt to unset the
- // queued bit, but concurrently the waiter has requested more
- // permits. When the waiter requested more permits, it saw the
- // queued bit set so took no further action. This requires us to
- // push the node back into the queue.
- if curr.permits_to_acquire() > 0 {
- // More permits are requested. The waiter must be re-queued
- unsafe {
- self.push_waiter(waiter, closed);
- }
- return;
- }
-
- let mut next = curr;
- next.unset_queued();
-
- let w = unsafe { waiter.as_ref() };
-
- match w.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
- Ok(_) => return,
- Err(actual) => {
- curr = WaiterState(actual);
- }
- }
- }
- }
-
- unsafe fn push_stub(&self, closed: bool) {
- self.push_waiter(self.stub(), closed);
- }
-
- unsafe fn push_waiter(&self, waiter: NonNull<Waiter>, closed: bool) {
- // Set the next pointer. This does not require an atomic operation as
- // this node is not accessible. The write will be flushed with the next
- // operation
- waiter.as_ref().next.store(ptr::null_mut(), Relaxed);
-
- // Update the tail to point to the new node. We need to see the previous
- // node in order to update the next pointer as well as release `task`
- // to any other threads calling `push`.
- let next = SemState::new_ptr(waiter, closed);
- let prev = SemState(self.state.swap(next.0, AcqRel));
-
- debug_assert_eq!(closed, prev.is_closed());
-
- // This function is only called when there are pending tasks. Because of
- // this, the state must *always* be in pointer mode.
- let prev = prev.waiter().unwrap();
-
- // No cycles plz
- debug_assert_ne!(prev, waiter);
-
- // Release `task` to the consume end.
- prev.as_ref().next.store(waiter.as_ptr(), Release);
- }
-
- fn stub(&self) -> NonNull<Waiter> {
- unsafe { NonNull::new_unchecked(&*self.stub as *const _ as *mut _) }
- }
-}
-
-impl Drop for Semaphore {
- fn drop(&mut self) {
- self.close();
- }
-}
-
-impl fmt::Debug for Semaphore {
- fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
- fmt.debug_struct("Semaphore")
- .field("state", &SemState::load(&self.state, Relaxed))
- .field("head", &self.head.with(|ptr| ptr))
- .field("rx_lock", &self.rx_lock.load(Relaxed))
- .field("stub", &self.stub)
- .finish()
- }
-}
-
-unsafe impl Send for Semaphore {}
-unsafe impl Sync for Semaphore {}
-
-// ===== impl Permit =====
-
-impl Permit {
- /// Creates a new `Permit`.
- ///
- /// The permit begins in the "unacquired" state.
- pub(crate) fn new() -> Permit {
- use PermitState::Acquired;
-
- Permit {
- waiter: None,
- state: Acquired(0),
- }
- }
-
- /// Returns `true` if the permit has been acquired
- #[allow(dead_code)] // may be used later
- pub(crate) fn is_acquired(&self) -> bool {
- match self.state {
- PermitState::Acquired(num) if num > 0 => true,
- _ => false,
- }
- }
-
- /// Tries to acquire the permit. If no permits are available, the current task
- /// is notified once a new permit becomes available.
- pub(crate) fn poll_acquire(
- &mut self,
- cx: &mut Context<'_>,
- num_permits: u16,
- semaphore: &Semaphore,
- ) -> Poll<Result<(), AcquireError>> {
- use std::cmp::Ordering::*;
- use PermitState::*;
-
- match self.state {
- Waiting(requested) => {
- // There must be a waiter
- let waiter = self.waiter.as_ref().unwrap();
-
- match requested.cmp(&num_permits) {
- Less => {
- let delta = num_permits - requested;
-
- // Request additional permits. If the waiter has been
- // dequeued, it must be re-queued.
- if !waiter.try_inc_permits_to_acquire(delta as usize) {
- let waiter = NonNull::from(&**waiter);
-
- // Ignore the result. The check for
- // `permits_to_acquire()` will converge the state as
- // needed
- let _ = semaphore.poll_acquire2(delta, || Some(waiter))?;
- }
-
- self.state = Waiting(num_permits);
- }
- Greater => {
- let delta = requested - num_permits;
- let to_release = waiter.try_dec_permits_to_acquire(delta as usize);
-
- semaphore.add_permits(to_release);
- self.state = Waiting(num_permits);
- }
- Equal => {}
- }
-
- if waiter.permits_to_acquire()? == 0 {
- self.state = Acquired(requested);
- return Ready(Ok(()));
- }
-
- waiter.waker.register_by_ref(cx.waker());
-
- if waiter.permits_to_acquire()? == 0 {
- self.state = Acquired(requested);
- return Ready(Ok(()));
- }
-
- Pending
- }
- Acquired(acquired) => {
- if acquired >= num_permits {
- Ready(Ok(()))
- } else {
- match semaphore.poll_acquire(cx, num_permits - acquired, self)? {
- Ready(()) => {
- self.state = Acquired(num_permits);
- Ready(Ok(()))
- }
- Pending => {
- self.state = Waiting(num_permits);
- Pending
- }
- }
- }
- }
- }
- }
-
- /// Tries to acquire the permit.
- pub(crate) fn try_acquire(
- &mut self,
- num_permits: u16,
- semaphore: &Semaphore,
- ) -> Result<(), TryAcquireError> {
- use PermitState::*;
-
- match self.state {
- Waiting(requested) => {
- // There must be a waiter
- let waiter = self.waiter.as_ref().unwrap();
-
- if requested > num_permits {
- let delta = requested - num_permits;
- let to_release = waiter.try_dec_permits_to_acquire(delta as usize);
-
- semaphore.add_permits(to_release);
- self.state = Waiting(num_permits);
- }
-
- let res = waiter.permits_to_acquire().map_err(to_try_acquire)?;
-
- if res == 0 {
- if requested < num_permits {
- // Try to acquire the additional permits
- semaphore.try_acquire(num_permits - requested)?;
- }
-
- self.state = Acquired(num_permits);
- Ok(())
- } else {
- Err(TryAcquireError::NoPermits)
- }
- }
- Acquired(acquired) => {
- if acquired < num_permits {
- semaphore.try_acquire(num_permits - acquired)?;
- self.state = Acquired(num_permits);
- }
-
- Ok(())
- }
- }
- }
-
- /// Releases a permit back to the semaphore
- pub(crate) fn release(&mut self, n: u16, semaphore: &Semaphore) {
- let n = self.forget(n);
- semaphore.add_permits(n as usize);
- }
-
- /// Forgets the permit **without** releasing it back to the semaphore.
- ///
-    /// After calling `forget`, `poll_acquire` is able to acquire a new permit
- /// from the semaphore.
- ///
- /// Repeatedly calling `forget` without associated calls to `add_permit`
- /// will result in the semaphore losing all permits.
- ///
- /// Will forget **at most** the number of acquired permits. This number is
- /// returned.
- pub(crate) fn forget(&mut self, n: u16) -> u16 {
- use PermitState::*;
-
- match self.state {
- Waiting(requested) => {
- let n = cmp::min(n, requested);
-
- // Decrement
- let acquired = self
- .waiter
- .as_ref()
- .unwrap()
- .try_dec_permits_to_acquire(n as usize) as u16;
-
- if n == requested {
- self.state = Acquired(0);
- } else if acquired == requested - n {
- self.state = Waiting(acquired);
- } else {
- self.state = Waiting(requested - n);
- }
-
- acquired
- }
- Acquired(acquired) => {
- let n = cmp::min(n, acquired);
- self.state = Acquired(acquired - n);
- n
- }
- }
- }
-}
-
-impl Default for Permit {
- fn default() -> Self {
- Self::new()
- }
-}
-
-impl Drop for Permit {
- fn drop(&mut self) {
- if let Some(waiter) = self.waiter.take() {
- // Set the dropped flag
- let state = WaiterState(waiter.state.fetch_or(DROPPED, AcqRel));
-
- if state.is_queued() {
- // The waiter is stored in the queue. The semaphore will drop it
- std::mem::forget(waiter);
- }
- }
- }
-}
-
-// ===== impl AcquireError ====
-
-impl AcquireError {
- fn closed() -> AcquireError {
- AcquireError(())
- }
-}
-
-fn to_try_acquire(_: AcquireError) -> TryAcquireError {
- TryAcquireError::Closed
-}
-
-impl fmt::Display for AcquireError {
- fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(fmt, "semaphore closed")
- }
-}
-
-impl std::error::Error for AcquireError {}
-
-// ===== impl TryAcquireError =====
-
-impl TryAcquireError {
- /// Returns `true` if the error was caused by a closed semaphore.
- pub(crate) fn is_closed(&self) -> bool {
- match self {
- TryAcquireError::Closed => true,
- _ => false,
- }
- }
-
- /// Returns `true` if the error was caused by calling `try_acquire` on a
- /// semaphore with no available permits.
- pub(crate) fn is_no_permits(&self) -> bool {
- match self {
- TryAcquireError::NoPermits => true,
- _ => false,
- }
- }
-}
-
-impl fmt::Display for TryAcquireError {
- fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
- match self {
- TryAcquireError::Closed => write!(fmt, "semaphore closed"),
- TryAcquireError::NoPermits => write!(fmt, "no permits available"),
- }
- }
-}
-
-impl std::error::Error for TryAcquireError {}
-
-// ===== impl Waiter =====
-
-impl Waiter {
- fn new() -> Waiter {
- Waiter {
- state: AtomicUsize::new(0),
- waker: AtomicWaker::new(),
- next: AtomicPtr::new(ptr::null_mut()),
- }
- }
-
- fn permits_to_acquire(&self) -> Result<usize, AcquireError> {
- let state = WaiterState(self.state.load(Acquire));
-
- if state.is_closed() {
- Err(AcquireError(()))
- } else {
- Ok(state.permits_to_acquire())
- }
- }
-
- /// Only increments the number of permits *if* the waiter is currently
- /// queued.
- ///
- /// # Returns
- ///
- /// `true` if the number of permits to acquire has been incremented. `false`
- /// otherwise. On `false`, the caller should use `Semaphore::poll_acquire`.
- fn try_inc_permits_to_acquire(&self, n: usize) -> bool {
- let mut curr = WaiterState(self.state.load(Acquire));
-
- loop {
- if !curr.is_queued() {
- assert_eq!(0, curr.permits_to_acquire());
- return false;
- }
-
- let mut next = curr;
- next.set_permits_to_acquire(n + curr.permits_to_acquire());
-
- match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
- Ok(_) => return true,
- Err(actual) => curr = WaiterState(actual),
- }
- }
- }
-
- /// Try to decrement the number of permits to acquire. This returns the
-    /// actual number of permits that were decremented. The delta between `n`
- /// and the return has been assigned to the permit and the caller must
- /// assign these back to the semaphore.
- fn try_dec_permits_to_acquire(&self, n: usize) -> usize {
- let mut curr = WaiterState(self.state.load(Acquire));
-
- loop {
- if !curr.is_queued() {
- assert_eq!(0, curr.permits_to_acquire());
- }
-
- let delta = cmp::min(n, curr.permits_to_acquire());
- let rem = curr.permits_to_acquire() - delta;
-
- let mut next = curr;
- next.set_permits_to_acquire(rem);
-
- match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
- Ok(_) => return n - delta,
- Err(actual) => curr = WaiterState(actual),
- }
- }
- }
-
- /// Store the number of remaining permits needed to satisfy the waiter and
- /// transition to the "QUEUED" state.
- ///
- /// # Returns
- ///
- /// `true` if the `QUEUED` bit was set as part of the transition.
- fn to_queued(&self, num_permits: usize) -> bool {
- let mut curr = WaiterState(self.state.load(Acquire));
-
- // The waiter should **not** be waiting for any permits.
- debug_assert_eq!(curr.permits_to_acquire(), 0);
-
- loop {
- let mut next = curr;
- next.set_permits_to_acquire(num_permits);
- next.set_queued();
-
- match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
- Ok(_) => {
- if curr.is_queued() {
- return false;
- } else {
- // Make sure the next pointer is null
- self.next.store(ptr::null_mut(), Relaxed);
- return true;
- }
- }
- Err(actual) => curr = WaiterState(actual),
- }
- }
- }
-
- /// Set the number of permits to acquire.
- ///
- /// This function is only called when the waiter is being inserted into the
- /// wait queue. Because of this, there are no concurrent threads that can
- /// modify the state and using `store` is safe.
- fn set_permits_to_acquire(&self, num_permits: usize) {
- debug_assert!(WaiterState(self.state.load(Acquire)).is_queued());
-
- let mut state = WaiterState(QUEUED);
- state.set_permits_to_acquire(num_permits);
-
- self.state.store(state.0, Release);
- }
-
- /// Assign permits to the waiter.
- ///
- /// Returns `true` if the waiter should be removed from the queue
- fn assign_permits(&self, n: &mut usize, closed: bool) -> bool {
- let mut curr = WaiterState(self.state.load(Acquire));
-
- loop {
- let mut next = curr;
-
- // Number of permits to assign to this waiter
- let assign = cmp::min(curr.permits_to_acquire(), *n);
-
- // Assign the permits
- next.set_permits_to_acquire(curr.permits_to_acquire() - assign);
-
- if closed {
- next.set_closed();
- }
-
- match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
- Ok(_) => {
- // Update `n`
- *n -= assign;
-
- if next.permits_to_acquire() == 0 {
- if curr.permits_to_acquire() > 0 {
- self.waker.wake();
- }
-
- return true;
- } else {
- return false;
- }
- }
- Err(actual) => curr = WaiterState(actual),
- }
- }
- }
-
- fn revert_to_idle(&self) {
- // An idle node is not waiting on any permits
- self.state.store(0, Relaxed);
- }
-
- fn store_next(&self, next: NonNull<Waiter>) {
- self.next.store(next.as_ptr(), Release);
- }
-}
-
-// ===== impl SemState =====
-
-impl SemState {
- /// Returns a new default `State` value.
- fn new(permits: usize, stub: &Waiter) -> SemState {
- assert!(permits <= MAX_PERMITS);
-
- if permits > 0 {
- SemState((permits << NUM_SHIFT) | NUM_FLAG)
- } else {
- SemState(stub as *const _ as usize)
- }
- }
-
- /// Returns a `State` tracking `ptr` as the tail of the queue.
- fn new_ptr(tail: NonNull<Waiter>, closed: bool) -> SemState {
- let mut val = tail.as_ptr() as usize;
-
- if closed {
- val |= CLOSED_FLAG;
- }
-
- SemState(val)
- }
-
- /// Returns the amount of remaining capacity
- fn available_permits(self) -> usize {
- if !self.has_available_permits() {
- return 0;
- }
-
- self.0 >> NUM_SHIFT
- }
-
- /// Returns `true` if the state has permits that can be claimed by a waiter.
- fn has_available_permits(self) -> bool {
- self.0 & NUM_FLAG == NUM_FLAG
- }
-
- fn has_waiter(self, stub: &Waiter) -> bool {
- !self.has_available_permits() && !self.is_stub(stub)
- }
-
- /// Tries to atomically acquire specified number of permits.
- ///
- /// # Return
- ///
- /// Returns `true` if the specified number of permits were acquired, `false`
- /// otherwise. Returning false does not mean that there are no more
- /// available permits.
- fn acquire_permits(&mut self, num: usize, stub: &Waiter) -> bool {
- debug_assert!(num > 0);
-
- if self.available_permits() < num {
- return false;
- }
-
- debug_assert!(self.waiter().is_none());
-
- self.0 -= num << NUM_SHIFT;
-
- if self.0 == NUM_FLAG {
- // Set the state to the stub pointer.
- self.0 = stub as *const _ as usize;
- }
-
- true
- }
-
- /// Releases permits
- ///
- /// Returns `true` if the permits were accepted.
- fn release_permits(&mut self, permits: usize, stub: &Waiter) {
- debug_assert!(permits > 0);
-
- if self.is_stub(stub) {
- self.0 = (permits << NUM_SHIFT) | NUM_FLAG | (self.0 & CLOSED_FLAG);
- return;
- }
-
- debug_assert!(self.has_available_permits());
-
- self.0 += permits << NUM_SHIFT;
- }
-
- fn is_waiter(self) -> bool {
- self.0 & NUM_FLAG == 0
- }
-
- /// Returns the waiter, if one is set.
- fn waiter(self) -> Option<NonNull<Waiter>> {
- if self.is_waiter() {
- let waiter = NonNull::new(self.as_ptr()).expect("null pointer stored");
-
- Some(waiter)
- } else {
- None
- }
- }
-
- /// Assumes `self` represents a pointer
- fn as_ptr(self) -> *mut Waiter {
- (self.0 & !CLOSED_FLAG) as *mut Waiter
- }
-
- /// Sets to a pointer to a waiter.
- ///
- /// This can only be done from the full state.
- fn set_waiter(&mut self, waiter: NonNull<Waiter>) {
- let waiter = waiter.as_ptr() as usize;
- debug_assert!(!self.is_closed());
-
- self.0 = waiter;
- }
-
- fn is_stub(self, stub: &Waiter) -> bool {
- self.as_ptr() as usize == stub as *const _ as usize
- }
-
- /// Loads the state from an AtomicUsize.
- fn load(cell: &AtomicUsize, ordering: Ordering) -> SemState {
- let value = cell.load(ordering);
- SemState(value)
- }
-
- fn fetch_set_closed(cell: &AtomicUsize, ordering: Ordering) -> SemState {
- let value = cell.fetch_or(CLOSED_FLAG, ordering);
- SemState(value)
- }
-
- fn is_closed(self) -> bool {
- self.0 & CLOSED_FLAG == CLOSED_FLAG
- }
-
- /// Converts the state into a `usize` representation.
- fn to_usize(self) -> usize {
- self.0
- }
-}
-
-impl fmt::Debug for SemState {
- fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
- let mut fmt = fmt.debug_struct("SemState");
-
- if self.is_waiter() {
- fmt.field("state", &"<waiter>");
- } else {
- fmt.field("permits", &self.available_permits());
- }
-
- fmt.finish()
- }
-}
-
-// ===== impl WaiterState =====
-
-impl WaiterState {
- fn permits_to_acquire(self) -> usize {
- self.0 >> PERMIT_SHIFT
- }
-
- fn set_permits_to_acquire(&mut self, val: usize) {
- self.0 = (val << PERMIT_SHIFT) | (self.0 & !PERMIT_MASK)
- }
-
- fn is_queued(self) -> bool {
- self.0 & QUEUED == QUEUED
- }
-
- fn set_queued(&mut self) {
- self.0 |= QUEUED;
- }
-
- fn is_closed(self) -> bool {
- self.0 & CLOSED == CLOSED
- }
-
- fn set_closed(&mut self) {
- self.0 |= CLOSED;
- }
-
- fn unset_queued(&mut self) {
- assert!(self.is_queued());
- self.0 -= QUEUED;
- }
-
- fn is_dropped(self) -> bool {
- self.0 & DROPPED == DROPPED
- }
-}
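
For reference, the deleted implementation packed its entire state into a single `usize`: the lowest bit tagged the word as a permit count, the next bit carried the closed flag, and pointer-mode states relied on waiter allocations being aligned so those low bits were free. A self-contained sketch of that encoding, reusing the deleted file's own constants:

```rust
const NUM_FLAG: usize = 0b01;
const CLOSED_FLAG: usize = 0b10;
const NUM_SHIFT: usize = 2;

// Encode `permits` in "number mode" by shifting past the two flag bits.
fn pack_permits(permits: usize) -> usize {
    (permits << NUM_SHIFT) | NUM_FLAG
}

// Number mode reports its count; pointer mode (NUM_FLAG clear) reports zero.
fn available_permits(state: usize) -> usize {
    if state & NUM_FLAG == NUM_FLAG {
        state >> NUM_SHIFT
    } else {
        0
    }
}

fn main() {
    let state = pack_permits(5);
    assert_eq!(available_permits(state), 5);
    assert_eq!(state & CLOSED_FLAG, 0);

    // An aligned waiter address has its low bits clear, so the same word is
    // read as a pointer-mode state with no available permits.
    let waiter_ptr = 0x1000_usize;
    assert_eq!(available_permits(waiter_ptr), 0);
}
```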
diff --git a/src/sync/task/atomic_waker.rs b/src/sync/task/atomic_waker.rs
index 73b1745..ae4cac7 100644
--- a/src/sync/task/atomic_waker.rs
+++ b/src/sync/task/atomic_waker.rs
@@ -141,13 +141,12 @@ impl AtomicWaker {
}
}
+ /*
/// Registers the current waker to be notified on calls to `wake`.
- ///
- /// This is the same as calling `register_task` with `task::current()`.
- #[cfg(feature = "io-driver")]
pub(crate) fn register(&self, waker: Waker) {
self.do_register(waker);
}
+ */
/// Registers the provided waker to be notified on calls to `wake`.
///
diff --git a/src/sync/tests/loom_broadcast.rs b/src/sync/tests/loom_broadcast.rs
index da12fb9..4b1f034 100644
--- a/src/sync/tests/loom_broadcast.rs
+++ b/src/sync/tests/loom_broadcast.rs
@@ -1,5 +1,5 @@
use crate::sync::broadcast;
-use crate::sync::broadcast::RecvError::{Closed, Lagged};
+use crate::sync::broadcast::error::RecvError::{Closed, Lagged};
use loom::future::block_on;
use loom::sync::Arc;
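
The import change reflects the broadcast error types moving into a dedicated `error` module. A minimal sketch of the new path in use:

```rust
use tokio::sync::broadcast;
use tokio::sync::broadcast::error::RecvError;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = broadcast::channel::<i32>(2);

    tx.send(1).unwrap();
    assert_eq!(rx.recv().await.unwrap(), 1);

    // Once every sender is gone, recv reports Closed from the new module.
    drop(tx);
    assert!(matches!(rx.recv().await, Err(RecvError::Closed)));
}
```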
diff --git a/src/sync/tests/loom_cancellation_token.rs b/src/sync/tests/loom_cancellation_token.rs
deleted file mode 100644
index e9c9f3d..0000000
--- a/src/sync/tests/loom_cancellation_token.rs
+++ /dev/null
@@ -1,155 +0,0 @@
-use crate::sync::CancellationToken;
-
-use loom::{future::block_on, thread};
-use tokio_test::assert_ok;
-
-#[test]
-fn cancel_token() {
- loom::model(|| {
- let token = CancellationToken::new();
- let token1 = token.clone();
-
- let th1 = thread::spawn(move || {
- block_on(async {
- token1.cancelled().await;
- });
- });
-
- let th2 = thread::spawn(move || {
- token.cancel();
- });
-
- assert_ok!(th1.join());
- assert_ok!(th2.join());
- });
-}
-
-#[test]
-fn cancel_with_child() {
- loom::model(|| {
- let token = CancellationToken::new();
- let token1 = token.clone();
- let token2 = token.clone();
- let child_token = token.child_token();
-
- let th1 = thread::spawn(move || {
- block_on(async {
- token1.cancelled().await;
- });
- });
-
- let th2 = thread::spawn(move || {
- token2.cancel();
- });
-
- let th3 = thread::spawn(move || {
- block_on(async {
- child_token.cancelled().await;
- });
- });
-
- assert_ok!(th1.join());
- assert_ok!(th2.join());
- assert_ok!(th3.join());
- });
-}
-
-#[test]
-fn drop_token_no_child() {
- loom::model(|| {
- let token = CancellationToken::new();
- let token1 = token.clone();
- let token2 = token.clone();
-
- let th1 = thread::spawn(move || {
- drop(token1);
- });
-
- let th2 = thread::spawn(move || {
- drop(token2);
- });
-
- let th3 = thread::spawn(move || {
- drop(token);
- });
-
- assert_ok!(th1.join());
- assert_ok!(th2.join());
- assert_ok!(th3.join());
- });
-}
-
-#[test]
-fn drop_token_with_childs() {
- loom::model(|| {
- let token1 = CancellationToken::new();
- let child_token1 = token1.child_token();
- let child_token2 = token1.child_token();
-
- let th1 = thread::spawn(move || {
- drop(token1);
- });
-
- let th2 = thread::spawn(move || {
- drop(child_token1);
- });
-
- let th3 = thread::spawn(move || {
- drop(child_token2);
- });
-
- assert_ok!(th1.join());
- assert_ok!(th2.join());
- assert_ok!(th3.join());
- });
-}
-
-#[test]
-fn drop_and_cancel_token() {
- loom::model(|| {
- let token1 = CancellationToken::new();
- let token2 = token1.clone();
- let child_token = token1.child_token();
-
- let th1 = thread::spawn(move || {
- drop(token1);
- });
-
- let th2 = thread::spawn(move || {
- token2.cancel();
- });
-
- let th3 = thread::spawn(move || {
- drop(child_token);
- });
-
- assert_ok!(th1.join());
- assert_ok!(th2.join());
- assert_ok!(th3.join());
- });
-}
-
-#[test]
-fn cancel_parent_and_child() {
- loom::model(|| {
- let token1 = CancellationToken::new();
- let token2 = token1.clone();
- let child_token = token1.child_token();
-
- let th1 = thread::spawn(move || {
- drop(token1);
- });
-
- let th2 = thread::spawn(move || {
- token2.cancel();
- });
-
- let th3 = thread::spawn(move || {
- child_token.cancel();
- });
-
- assert_ok!(th1.join());
- assert_ok!(th2.join());
- assert_ok!(th3.join());
- });
-}
diff --git a/src/sync/tests/loom_mpsc.rs b/src/sync/tests/loom_mpsc.rs
index 6a1a6ab..c12313b 100644
--- a/src/sync/tests/loom_mpsc.rs
+++ b/src/sync/tests/loom_mpsc.rs
@@ -2,22 +2,24 @@ use crate::sync::mpsc;
use futures::future::poll_fn;
use loom::future::block_on;
+use loom::sync::Arc;
use loom::thread;
+use tokio_test::assert_ok;
#[test]
fn closing_tx() {
loom::model(|| {
- let (mut tx, mut rx) = mpsc::channel(16);
+ let (tx, mut rx) = mpsc::channel(16);
thread::spawn(move || {
tx.try_send(()).unwrap();
drop(tx);
});
- let v = block_on(poll_fn(|cx| rx.poll_recv(cx)));
+ let v = block_on(rx.recv());
assert!(v.is_some());
- let v = block_on(poll_fn(|cx| rx.poll_recv(cx)));
+ let v = block_on(rx.recv());
assert!(v.is_none());
});
}
@@ -32,15 +34,70 @@ fn closing_unbounded_tx() {
drop(tx);
});
- let v = block_on(poll_fn(|cx| rx.poll_recv(cx)));
+ let v = block_on(rx.recv());
assert!(v.is_some());
- let v = block_on(poll_fn(|cx| rx.poll_recv(cx)));
+ let v = block_on(rx.recv());
assert!(v.is_none());
});
}
#[test]
+fn closing_bounded_rx() {
+ loom::model(|| {
+ let (tx1, rx) = mpsc::channel::<()>(16);
+ let tx2 = tx1.clone();
+ thread::spawn(move || {
+ drop(rx);
+ });
+
+ block_on(tx1.closed());
+ block_on(tx2.closed());
+ });
+}
+
+#[test]
+fn closing_and_sending() {
+ loom::model(|| {
+ let (tx1, mut rx) = mpsc::channel::<()>(16);
+ let tx1 = Arc::new(tx1);
+ let tx2 = tx1.clone();
+
+ let th1 = thread::spawn(move || {
+ tx1.try_send(()).unwrap();
+ });
+
+ let th2 = thread::spawn(move || {
+ block_on(tx2.closed());
+ });
+
+ let th3 = thread::spawn(move || {
+ let v = block_on(rx.recv());
+ assert!(v.is_some());
+ drop(rx);
+ });
+
+ assert_ok!(th1.join());
+ assert_ok!(th2.join());
+ assert_ok!(th3.join());
+ });
+}
+
+#[test]
+fn closing_unbounded_rx() {
+ loom::model(|| {
+ let (tx1, rx) = mpsc::unbounded_channel::<()>();
+ let tx2 = tx1.clone();
+ thread::spawn(move || {
+ drop(rx);
+ });
+
+ block_on(tx1.closed());
+ block_on(tx2.closed());
+ });
+}
+
+#[test]
fn dropping_tx() {
loom::model(|| {
let (tx, mut rx) = mpsc::channel::<()>(16);
@@ -53,7 +110,7 @@ fn dropping_tx() {
}
drop(tx);
- let v = block_on(poll_fn(|cx| rx.poll_recv(cx)));
+ let v = block_on(rx.recv());
assert!(v.is_none());
});
}
@@ -71,7 +128,7 @@ fn dropping_unbounded_tx() {
}
drop(tx);
- let v = block_on(poll_fn(|cx| rx.poll_recv(cx)));
+ let v = block_on(rx.recv());
assert!(v.is_none());
});
}
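
The tests above now await `recv()` directly instead of wrapping `poll_recv` in `poll_fn`. A minimal sketch of the equivalent public usage:

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel(16);

    tx.try_send("hello").unwrap();
    drop(tx);

    // `recv` yields queued values, then None once all senders are dropped.
    assert_eq!(rx.recv().await, Some("hello"));
    assert_eq!(rx.recv().await, None);
}
```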
diff --git a/src/sync/tests/loom_notify.rs b/src/sync/tests/loom_notify.rs
index 60981d4..79a5bf8 100644
--- a/src/sync/tests/loom_notify.rs
+++ b/src/sync/tests/loom_notify.rs
@@ -16,7 +16,7 @@ fn notify_one() {
});
});
- tx.notify();
+ tx.notify_one();
th.join().unwrap();
});
}
@@ -34,12 +34,12 @@ fn notify_multi() {
ths.push(thread::spawn(move || {
block_on(async {
notify.notified().await;
- notify.notify();
+ notify.notify_one();
})
}));
}
- notify.notify();
+ notify.notify_one();
for th in ths.drain(..) {
th.join().unwrap();
@@ -67,7 +67,7 @@ fn notify_drop() {
block_on(poll_fn(|cx| {
if recv.as_mut().poll(cx).is_ready() {
- rx1.notify();
+ rx1.notify_one();
}
Poll::Ready(())
}));
@@ -77,12 +77,12 @@ fn notify_drop() {
block_on(async {
rx2.notified().await;
// Trigger second notification
- rx2.notify();
+ rx2.notify_one();
rx2.notified().await;
});
});
- notify.notify();
+ notify.notify_one();
th1.join().unwrap();
th2.join().unwrap();
diff --git a/src/sync/tests/loom_oneshot.rs b/src/sync/tests/loom_oneshot.rs
index dfa7459..9729cfb 100644
--- a/src/sync/tests/loom_oneshot.rs
+++ b/src/sync/tests/loom_oneshot.rs
@@ -75,8 +75,10 @@ impl Future for OnClose<'_> {
type Output = bool;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<bool> {
- let res = self.get_mut().tx.poll_closed(cx);
- Ready(res.is_ready())
+ let fut = self.get_mut().tx.closed();
+ crate::pin!(fut);
+
+ Ready(fut.poll(cx).is_ready())
}
}
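
The updated `OnClose` test demonstrates the general trick for polling an async method from a manual `Future` impl: construct the future, pin it on the stack, and poll it once. A generic sketch of that pattern, using `tokio::pin!` as the public counterpart of the `crate::pin!` above:

```rust
use std::future::Future;
use std::task::{Context, Poll};

// Polls `fut` exactly once. Any progress is discarded afterwards, which is
// fine for "is it ready right now?" checks like the OnClose future above.
fn poll_once<F: Future>(fut: F, cx: &mut Context<'_>) -> Poll<F::Output> {
    tokio::pin!(fut);
    fut.poll(cx)
}
```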
diff --git a/src/sync/tests/loom_semaphore_ll.rs b/src/sync/tests/loom_semaphore_ll.rs
deleted file mode 100644
index b5e5efb..0000000
--- a/src/sync/tests/loom_semaphore_ll.rs
+++ /dev/null
@@ -1,192 +0,0 @@
-use crate::sync::semaphore_ll::*;
-
-use futures::future::poll_fn;
-use loom::future::block_on;
-use loom::thread;
-use std::future::Future;
-use std::pin::Pin;
-use std::sync::atomic::AtomicUsize;
-use std::sync::atomic::Ordering::SeqCst;
-use std::sync::Arc;
-use std::task::Poll::Ready;
-use std::task::{Context, Poll};
-
-#[test]
-fn basic_usage() {
- const NUM: usize = 2;
-
- struct Actor {
- waiter: Permit,
- shared: Arc<Shared>,
- }
-
- struct Shared {
- semaphore: Semaphore,
- active: AtomicUsize,
- }
-
- impl Future for Actor {
- type Output = ();
-
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
- let me = &mut *self;
-
- ready!(me.waiter.poll_acquire(cx, 1, &me.shared.semaphore)).unwrap();
-
- let actual = me.shared.active.fetch_add(1, SeqCst);
- assert!(actual <= NUM - 1);
-
- let actual = me.shared.active.fetch_sub(1, SeqCst);
- assert!(actual <= NUM);
-
- me.waiter.release(1, &me.shared.semaphore);
-
- Ready(())
- }
- }
-
- loom::model(|| {
- let shared = Arc::new(Shared {
- semaphore: Semaphore::new(NUM),
- active: AtomicUsize::new(0),
- });
-
- for _ in 0..NUM {
- let shared = shared.clone();
-
- thread::spawn(move || {
- block_on(Actor {
- waiter: Permit::new(),
- shared,
- });
- });
- }
-
- block_on(Actor {
- waiter: Permit::new(),
- shared,
- });
- });
-}
-
-#[test]
-fn release() {
- loom::model(|| {
- let semaphore = Arc::new(Semaphore::new(1));
-
- {
- let semaphore = semaphore.clone();
- thread::spawn(move || {
- let mut permit = Permit::new();
-
- block_on(poll_fn(|cx| permit.poll_acquire(cx, 1, &semaphore))).unwrap();
-
- permit.release(1, &semaphore);
- });
- }
-
- let mut permit = Permit::new();
-
- block_on(poll_fn(|cx| permit.poll_acquire(cx, 1, &semaphore))).unwrap();
-
- permit.release(1, &semaphore);
- });
-}
-
-#[test]
-fn basic_closing() {
- const NUM: usize = 2;
-
- loom::model(|| {
- let semaphore = Arc::new(Semaphore::new(1));
-
- for _ in 0..NUM {
- let semaphore = semaphore.clone();
-
- thread::spawn(move || {
- let mut permit = Permit::new();
-
- for _ in 0..2 {
- block_on(poll_fn(|cx| {
- permit.poll_acquire(cx, 1, &semaphore).map_err(|_| ())
- }))?;
-
- permit.release(1, &semaphore);
- }
-
- Ok::<(), ()>(())
- });
- }
-
- semaphore.close();
- });
-}
-
-#[test]
-fn concurrent_close() {
- const NUM: usize = 3;
-
- loom::model(|| {
- let semaphore = Arc::new(Semaphore::new(1));
-
- for _ in 0..NUM {
- let semaphore = semaphore.clone();
-
- thread::spawn(move || {
- let mut permit = Permit::new();
-
- block_on(poll_fn(|cx| {
- permit.poll_acquire(cx, 1, &semaphore).map_err(|_| ())
- }))?;
-
- permit.release(1, &semaphore);
-
- semaphore.close();
-
- Ok::<(), ()>(())
- });
- }
- });
-}
-
-#[test]
-fn batch() {
- let mut b = loom::model::Builder::new();
- b.preemption_bound = Some(1);
-
- b.check(|| {
- let semaphore = Arc::new(Semaphore::new(10));
- let active = Arc::new(AtomicUsize::new(0));
- let mut ths = vec![];
-
- for _ in 0..2 {
- let semaphore = semaphore.clone();
- let active = active.clone();
-
- ths.push(thread::spawn(move || {
- let mut permit = Permit::new();
-
- for n in &[4, 10, 8] {
- block_on(poll_fn(|cx| permit.poll_acquire(cx, *n, &semaphore))).unwrap();
-
- active.fetch_add(*n as usize, SeqCst);
-
- let num_active = active.load(SeqCst);
- assert!(num_active <= 10);
-
- thread::yield_now();
-
- active.fetch_sub(*n as usize, SeqCst);
-
- permit.release(*n, &semaphore);
- }
- }));
- }
-
- for th in ths.into_iter() {
- th.join().unwrap();
- }
-
- assert_eq!(10, semaphore.available_permits());
- });
-}
diff --git a/src/sync/tests/loom_watch.rs b/src/sync/tests/loom_watch.rs
new file mode 100644
index 0000000..c575b5b
--- /dev/null
+++ b/src/sync/tests/loom_watch.rs
@@ -0,0 +1,36 @@
+use crate::sync::watch;
+
+use loom::future::block_on;
+use loom::thread;
+
+#[test]
+fn smoke() {
+ loom::model(|| {
+ let (tx, mut rx1) = watch::channel(1);
+ let mut rx2 = rx1.clone();
+ let mut rx3 = rx1.clone();
+ let mut rx4 = rx1.clone();
+ let mut rx5 = rx1.clone();
+
+ let th = thread::spawn(move || {
+ tx.send(2).unwrap();
+ });
+
+ block_on(rx1.changed()).unwrap();
+ assert_eq!(*rx1.borrow(), 2);
+
+ block_on(rx2.changed()).unwrap();
+ assert_eq!(*rx2.borrow(), 2);
+
+ block_on(rx3.changed()).unwrap();
+ assert_eq!(*rx3.borrow(), 2);
+
+ block_on(rx4.changed()).unwrap();
+ assert_eq!(*rx4.borrow(), 2);
+
+ block_on(rx5.changed()).unwrap();
+ assert_eq!(*rx5.borrow(), 2);
+
+ th.join().unwrap();
+ })
+}
diff --git a/src/sync/tests/mod.rs b/src/sync/tests/mod.rs
index 6ba8c1f..a78be6f 100644
--- a/src/sync/tests/mod.rs
+++ b/src/sync/tests/mod.rs
@@ -1,18 +1,15 @@
cfg_not_loom! {
mod atomic_waker;
- mod semaphore_ll;
mod semaphore_batch;
}
cfg_loom! {
mod loom_atomic_waker;
mod loom_broadcast;
- #[cfg(tokio_unstable)]
- mod loom_cancellation_token;
mod loom_list;
mod loom_mpsc;
mod loom_notify;
mod loom_oneshot;
mod loom_semaphore_batch;
- mod loom_semaphore_ll;
+ mod loom_watch;
}
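
`cfg_not_loom!` and `cfg_loom!` are tokio-internal helper macros that gate the test modules on the `loom` cfg flag. A hedged sketch of roughly what they expand to (tokio's actual definitions live in src/macros/cfg.rs and may differ in detail):

```
macro_rules! cfg_loom {
    ($($item:item)*) => {
        $( #[cfg(loom)] $item )*
    }
}

macro_rules! cfg_not_loom {
    ($($item:item)*) => {
        $( #[cfg(not(loom))] $item )*
    }
}
```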
diff --git a/src/sync/tests/semaphore_ll.rs b/src/sync/tests/semaphore_ll.rs
deleted file mode 100644
index bfb0757..0000000
--- a/src/sync/tests/semaphore_ll.rs
+++ /dev/null
@@ -1,470 +0,0 @@
-use crate::sync::semaphore_ll::{Permit, Semaphore};
-use tokio_test::*;
-
-#[test]
-fn poll_acquire_one_available() {
- let s = Semaphore::new(100);
- assert_eq!(s.available_permits(), 100);
-
- // Polling for a permit succeeds immediately
- let mut permit = task::spawn(Permit::new());
- assert!(!permit.is_acquired());
-
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
- assert_eq!(s.available_permits(), 99);
- assert!(permit.is_acquired());
-
- // Polling again on the same waiter does not claim a new permit
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
- assert_eq!(s.available_permits(), 99);
- assert!(permit.is_acquired());
-}
-
-#[test]
-fn poll_acquire_many_available() {
- let s = Semaphore::new(100);
- assert_eq!(s.available_permits(), 100);
-
- // Polling for a permit succeeds immediately
- let mut permit = task::spawn(Permit::new());
- assert!(!permit.is_acquired());
-
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 5, &s)));
- assert_eq!(s.available_permits(), 95);
- assert!(permit.is_acquired());
-
- // Polling again on the same waiter does not claim a new permit
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
- assert_eq!(s.available_permits(), 95);
- assert!(permit.is_acquired());
-
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 5, &s)));
- assert_eq!(s.available_permits(), 95);
- assert!(permit.is_acquired());
-
- // Polling for a larger number of permits acquires more
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 8, &s)));
- assert_eq!(s.available_permits(), 92);
- assert!(permit.is_acquired());
-}
-
-#[test]
-fn try_acquire_one_available() {
- let s = Semaphore::new(100);
- assert_eq!(s.available_permits(), 100);
-
- // Polling for a permit succeeds immediately
- let mut permit = Permit::new();
- assert!(!permit.is_acquired());
-
- assert_ok!(permit.try_acquire(1, &s));
- assert_eq!(s.available_permits(), 99);
- assert!(permit.is_acquired());
-
- // Polling again on the same waiter does not claim a new permit
- assert_ok!(permit.try_acquire(1, &s));
- assert_eq!(s.available_permits(), 99);
- assert!(permit.is_acquired());
-}
-
-#[test]
-fn try_acquire_many_available() {
- let s = Semaphore::new(100);
- assert_eq!(s.available_permits(), 100);
-
- // Polling for a permit succeeds immediately
- let mut permit = Permit::new();
- assert!(!permit.is_acquired());
-
- assert_ok!(permit.try_acquire(5, &s));
- assert_eq!(s.available_permits(), 95);
- assert!(permit.is_acquired());
-
- // Polling again on the same waiter does not claim a new permit
- assert_ok!(permit.try_acquire(5, &s));
- assert_eq!(s.available_permits(), 95);
- assert!(permit.is_acquired());
-}
-
-#[test]
-fn poll_acquire_one_unavailable() {
- let s = Semaphore::new(1);
-
- let mut permit_1 = task::spawn(Permit::new());
- let mut permit_2 = task::spawn(Permit::new());
-
- // Acquire the first permit
- assert_ready_ok!(permit_1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
- assert_eq!(s.available_permits(), 0);
-
- permit_2.enter(|cx, mut p| {
- // Try to acquire the second permit
- assert_pending!(p.poll_acquire(cx, 1, &s));
- });
-
- permit_1.release(1, &s);
-
- assert_eq!(s.available_permits(), 0);
- assert!(permit_2.is_woken());
- assert_ready_ok!(permit_2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-
- permit_2.release(1, &s);
- assert_eq!(s.available_permits(), 1);
-}
-
-#[test]
-fn forget_acquired() {
- let s = Semaphore::new(1);
-
- // Polling for a permit succeeds immediately
- let mut permit = task::spawn(Permit::new());
-
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-
- assert_eq!(s.available_permits(), 0);
-
- permit.forget(1);
- assert_eq!(s.available_permits(), 0);
-}
-
-#[test]
-fn forget_waiting() {
- let s = Semaphore::new(0);
-
- // Polling for a permit succeeds immediately
- let mut permit = task::spawn(Permit::new());
-
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-
- assert_eq!(s.available_permits(), 0);
-
- permit.forget(1);
-
- s.add_permits(1);
-
- assert!(!permit.is_woken());
- assert_eq!(s.available_permits(), 1);
-}
-
-#[test]
-fn poll_acquire_many_unavailable() {
- let s = Semaphore::new(5);
-
- let mut permit_1 = task::spawn(Permit::new());
- let mut permit_2 = task::spawn(Permit::new());
- let mut permit_3 = task::spawn(Permit::new());
-
- // Acquire the first permit
- assert_ready_ok!(permit_1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
- assert_eq!(s.available_permits(), 4);
-
- permit_2.enter(|cx, mut p| {
- // Try to acquire the second permit
- assert_pending!(p.poll_acquire(cx, 5, &s));
- });
-
- assert_eq!(s.available_permits(), 0);
-
- permit_3.enter(|cx, mut p| {
- // Try to acquire the third permit
- assert_pending!(p.poll_acquire(cx, 3, &s));
- });
-
- permit_1.release(1, &s);
-
- assert_eq!(s.available_permits(), 0);
- assert!(permit_2.is_woken());
- assert_ready_ok!(permit_2.enter(|cx, mut p| p.poll_acquire(cx, 5, &s)));
-
- assert!(!permit_3.is_woken());
- assert_eq!(s.available_permits(), 0);
-
- permit_2.release(1, &s);
- assert!(!permit_3.is_woken());
- assert_eq!(s.available_permits(), 0);
-
- permit_2.release(2, &s);
- assert!(permit_3.is_woken());
-
- assert_ready_ok!(permit_3.enter(|cx, mut p| p.poll_acquire(cx, 3, &s)));
-}
-
-#[test]
-fn try_acquire_one_unavailable() {
- let s = Semaphore::new(1);
-
- let mut permit_1 = Permit::new();
- let mut permit_2 = Permit::new();
-
- // Acquire the first permit
- assert_ok!(permit_1.try_acquire(1, &s));
- assert_eq!(s.available_permits(), 0);
-
- assert_err!(permit_2.try_acquire(1, &s));
-
- permit_1.release(1, &s);
-
- assert_eq!(s.available_permits(), 1);
- assert_ok!(permit_2.try_acquire(1, &s));
-
- permit_2.release(1, &s);
- assert_eq!(s.available_permits(), 1);
-}
-
-#[test]
-fn try_acquire_many_unavailable() {
- let s = Semaphore::new(5);
-
- let mut permit_1 = Permit::new();
- let mut permit_2 = Permit::new();
-
- // Acquire the first permit
- assert_ok!(permit_1.try_acquire(1, &s));
- assert_eq!(s.available_permits(), 4);
-
- assert_err!(permit_2.try_acquire(5, &s));
-
- permit_1.release(1, &s);
- assert_eq!(s.available_permits(), 5);
-
- assert_ok!(permit_2.try_acquire(5, &s));
-
- permit_2.release(1, &s);
- assert_eq!(s.available_permits(), 1);
-
- permit_2.release(1, &s);
- assert_eq!(s.available_permits(), 2);
-}
-
-#[test]
-fn poll_acquire_one_zero_permits() {
- let s = Semaphore::new(0);
- assert_eq!(s.available_permits(), 0);
-
- let mut permit = task::spawn(Permit::new());
-
- // Try to acquire the permit
- permit.enter(|cx, mut p| {
- assert_pending!(p.poll_acquire(cx, 1, &s));
- });
-
- s.add_permits(1);
-
- assert!(permit.is_woken());
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-}
-
-#[test]
-#[should_panic]
-fn validates_max_permits() {
- use std::usize;
- Semaphore::new((usize::MAX >> 2) + 1);
-}
-
-#[test]
-fn close_semaphore_prevents_acquire() {
- let s = Semaphore::new(5);
- s.close();
-
- assert_eq!(5, s.available_permits());
-
- let mut permit_1 = task::spawn(Permit::new());
- let mut permit_2 = task::spawn(Permit::new());
-
- assert_ready_err!(permit_1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
- assert_eq!(5, s.available_permits());
-
- assert_ready_err!(permit_2.enter(|cx, mut p| p.poll_acquire(cx, 2, &s)));
- assert_eq!(5, s.available_permits());
-}
-
-#[test]
-fn close_semaphore_notifies_permit1() {
- let s = Semaphore::new(0);
- let mut permit = task::spawn(Permit::new());
-
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-
- s.close();
-
- assert!(permit.is_woken());
- assert_ready_err!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-}
-
-#[test]
-fn close_semaphore_notifies_permit2() {
- let s = Semaphore::new(2);
-
- let mut permit1 = task::spawn(Permit::new());
- let mut permit2 = task::spawn(Permit::new());
- let mut permit3 = task::spawn(Permit::new());
- let mut permit4 = task::spawn(Permit::new());
-
- // Acquire a couple of permits
- assert_ready_ok!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
- assert_ready_ok!(permit2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-
- assert_pending!(permit3.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
- assert_pending!(permit4.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-
- s.close();
-
- assert!(permit3.is_woken());
- assert!(permit4.is_woken());
-
- assert_ready_err!(permit3.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
- assert_ready_err!(permit4.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-
- assert_eq!(0, s.available_permits());
-
- permit1.release(1, &s);
-
- assert_eq!(1, s.available_permits());
-
- assert_ready_err!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-
- permit2.release(1, &s);
-
- assert_eq!(2, s.available_permits());
-}
-
-#[test]
-fn poll_acquire_additional_permits_while_waiting_before_assigned() {
- let s = Semaphore::new(1);
-
- let mut permit = task::spawn(Permit::new());
-
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s)));
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 3, &s)));
-
- s.add_permits(1);
- assert!(!permit.is_woken());
-
- s.add_permits(1);
- assert!(permit.is_woken());
-
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 3, &s)));
-}
-
-#[test]
-fn try_acquire_additional_permits_while_waiting_before_assigned() {
- let s = Semaphore::new(1);
-
- let mut permit = task::spawn(Permit::new());
-
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s)));
-
- assert_err!(permit.enter(|_, mut p| p.try_acquire(3, &s)));
-
- s.add_permits(1);
- assert!(permit.is_woken());
-
- assert_ok!(permit.enter(|_, mut p| p.try_acquire(2, &s)));
-}
-
-#[test]
-fn poll_acquire_additional_permits_while_waiting_after_assigned_success() {
- let s = Semaphore::new(1);
-
- let mut permit = task::spawn(Permit::new());
-
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s)));
-
- s.add_permits(2);
-
- assert!(permit.is_woken());
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 3, &s)));
-}
-
-#[test]
-fn poll_acquire_additional_permits_while_waiting_after_assigned_requeue() {
- let s = Semaphore::new(1);
-
- let mut permit = task::spawn(Permit::new());
-
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s)));
-
- s.add_permits(2);
-
- assert!(permit.is_woken());
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 4, &s)));
-
- s.add_permits(1);
-
- assert!(permit.is_woken());
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 4, &s)));
-}
-
-#[test]
-fn poll_acquire_fewer_permits_while_waiting() {
- let s = Semaphore::new(1);
-
- let mut permit = task::spawn(Permit::new());
-
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s)));
- assert_eq!(s.available_permits(), 0);
-
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
- assert_eq!(s.available_permits(), 0);
-}
-
-#[test]
-fn poll_acquire_fewer_permits_after_assigned() {
- let s = Semaphore::new(1);
-
- let mut permit1 = task::spawn(Permit::new());
- let mut permit2 = task::spawn(Permit::new());
-
- assert_pending!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 5, &s)));
- assert_eq!(s.available_permits(), 0);
-
- assert_pending!(permit2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-
- s.add_permits(4);
- assert!(permit1.is_woken());
- assert!(!permit2.is_woken());
-
- assert_ready_ok!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 3, &s)));
-
- assert!(permit2.is_woken());
- assert_eq!(s.available_permits(), 1);
-
- assert_ready_ok!(permit2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-}
-
-#[test]
-fn forget_partial_1() {
- let s = Semaphore::new(0);
-
- let mut permit = task::spawn(Permit::new());
-
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s)));
- s.add_permits(1);
-
- assert_eq!(0, s.available_permits());
-
- permit.release(1, &s);
-
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-
- assert_eq!(s.available_permits(), 0);
-}
-
-#[test]
-fn forget_partial_2() {
- let s = Semaphore::new(0);
-
- let mut permit = task::spawn(Permit::new());
-
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s)));
- s.add_permits(1);
-
- assert_eq!(0, s.available_permits());
-
- permit.release(1, &s);
-
- s.add_permits(1);
-
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s)));
- assert_eq!(s.available_permits(), 0);
-}
diff --git a/src/sync/watch.rs b/src/sync/watch.rs
index 13033d9..ec73832 100644
--- a/src/sync/watch.rs
+++ b/src/sync/watch.rs
@@ -6,13 +6,11 @@
//!
//! # Usage
//!
-//! [`channel`] returns a [`Sender`] / [`Receiver`] pair. These are
-//! the producer and sender halves of the channel. The channel is
-//! created with an initial value. [`Receiver::recv`] will always
-//! be ready upon creation and will yield either this initial value or
-//! the latest value that has been sent by `Sender`.
-//!
-//! Calls to [`Receiver::recv`] will always yield the latest value.
+//! [`channel`] returns a [`Sender`] / [`Receiver`] pair. These are the producer
+//! and consumer halves of the channel. The channel is created with an initial
+//! value. The **latest** value stored in the channel is accessed with
+//! [`Receiver::borrow()`]. Awaiting [`Receiver::changed()`] waits for a new
+//! value to be sent by the [`Sender`] half.
//!
//! # Examples
//!
@@ -23,21 +21,21 @@
//! let (tx, mut rx) = watch::channel("hello");
//!
//! tokio::spawn(async move {
-//! while let Some(value) = rx.recv().await {
-//! println!("received = {:?}", value);
+//! while rx.changed().await.is_ok() {
+//! println!("received = {:?}", *rx.borrow());
//! }
//! });
//!
-//! tx.broadcast("world")?;
+//! tx.send("world")?;
//! # Ok(())
//! # }
//! ```
//!
//! # Closing
//!
-//! [`Sender::closed`] allows the producer to detect when all [`Receiver`]
-//! handles have been dropped. This indicates that there is no further interest
-//! in the values being produced and work can be stopped.
+//! [`Sender::is_closed`] and [`Sender::closed`] allow the producer to detect
+//! when all [`Receiver`] handles have been dropped. This indicates that there
+//! is no further interest in the values being produced and work can be stopped.
//!
//! # Thread safety
//!
@@ -47,20 +45,18 @@
//!
//! [`Sender`]: crate::sync::watch::Sender
//! [`Receiver`]: crate::sync::watch::Receiver
-//! [`Receiver::recv`]: crate::sync::watch::Receiver::recv
+//! [`Receiver::changed()`]: crate::sync::watch::Receiver::changed
+//! [`Receiver::borrow()`]: crate::sync::watch::Receiver::borrow
//! [`channel`]: crate::sync::watch::channel
+//! [`Sender::is_closed`]: crate::sync::watch::Sender::is_closed
//! [`Sender::closed`]: crate::sync::watch::Sender::closed
-use crate::future::poll_fn;
-use crate::sync::task::AtomicWaker;
+use crate::sync::Notify;
-use fnv::FnvHashSet;
use std::ops;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::{Relaxed, SeqCst};
-use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard, Weak};
-use std::task::Poll::{Pending, Ready};
-use std::task::{Context, Poll};
+use std::sync::{Arc, RwLock, RwLockReadGuard};
/// Receives values from the associated [`Sender`](struct@Sender).
///
@@ -70,8 +66,8 @@ pub struct Receiver<T> {
/// Pointer to the shared state
shared: Arc<Shared<T>>,
- /// Pointer to the watcher's internal state
- inner: Watcher,
+ /// Last observed version
+ version: usize,
}
/// Sends values to the associated [`Receiver`](struct@Receiver).
@@ -79,7 +75,7 @@ pub struct Receiver<T> {
/// Instances are created by the [`channel`](fn@channel) function.
#[derive(Debug)]
pub struct Sender<T> {
- shared: Weak<Shared<T>>,
+ shared: Arc<Shared<T>>,
}
/// Returns a reference to the inner value
@@ -92,6 +88,27 @@ pub struct Ref<'a, T> {
inner: RwLockReadGuard<'a, T>,
}
+#[derive(Debug)]
+struct Shared<T> {
+ /// The most recent value
+ value: RwLock<T>,
+
+ /// The current version
+ ///
+ /// The lowest bit represents a "closed" state. The rest of the bits
+ /// represent the current version.
+ version: AtomicUsize,
+
+ /// Tracks the number of `Receiver` instances
+ ref_count_rx: AtomicUsize,
+
+ /// Notifies waiting receivers that the value changed.
+ notify_rx: Notify,
+
+ /// Notifies any task listening for `Receiver` dropped events
+ notify_tx: Notify,
+}
+
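
The `version` word packs the closed flag into its lowest bit: `send` advances the version by 2, and dropping the `Sender` sets `CLOSED = 1`. A small illustrative decoder (not part of the patch) to make the encoding concrete:

```
const CLOSED: usize = 1;

// Split a raw state word into (version, closed). Illustrative helper only.
fn decode(state: usize) -> (usize, bool) {
    (state & !CLOSED, state & CLOSED == CLOSED)
}

fn main() {
    let mut state = 0usize;
    state += 2;          // `send` bumps the version by 2
    assert_eq!(decode(state), (2, false));
    state |= CLOSED;     // `Sender::drop` sets the closed bit
    assert_eq!(decode(state), (2, true));
}
```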
pub mod error {
//! Watch error types
@@ -112,37 +129,20 @@ pub mod error {
}
impl<T: fmt::Debug> std::error::Error for SendError<T> {}
-}
-
-#[derive(Debug)]
-struct Shared<T> {
- /// The most recent value
- value: RwLock<T>,
- /// The current version
- ///
- /// The lowest bit represents a "closed" state. The rest of the bits
- /// represent the current version.
- version: AtomicUsize,
-
- /// All watchers
- watchers: Mutex<Watchers>,
-
- /// Task to notify when all watchers drop
- cancel: AtomicWaker,
-}
+ /// Error produced when receiving a change notification.
+ #[derive(Debug)]
+ pub struct RecvError(pub(super) ());
-type Watchers = FnvHashSet<Watcher>;
+ // ===== impl RecvError =====
-/// The watcher's ID is based on the Arc's pointer.
-#[derive(Clone, Debug)]
-struct Watcher(Arc<WatchInner>);
+ impl fmt::Display for RecvError {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(fmt, "channel closed")
+ }
+ }
-#[derive(Debug)]
-struct WatchInner {
- /// Last observed version
- version: AtomicUsize,
- waker: AtomicWaker,
+ impl std::error::Error for RecvError {}
}
const CLOSED: usize = 1;
@@ -162,41 +162,32 @@ const CLOSED: usize = 1;
/// let (tx, mut rx) = watch::channel("hello");
///
/// tokio::spawn(async move {
-/// while let Some(value) = rx.recv().await {
-/// println!("received = {:?}", value);
+/// while rx.changed().await.is_ok() {
+/// println!("received = {:?}", *rx.borrow());
/// }
/// });
///
-/// tx.broadcast("world")?;
+/// tx.send("world")?;
/// # Ok(())
/// # }
/// ```
///
/// [`Sender`]: struct@Sender
/// [`Receiver`]: struct@Receiver
-pub fn channel<T: Clone>(init: T) -> (Sender<T>, Receiver<T>) {
- const VERSION_0: usize = 0;
- const VERSION_1: usize = 2;
-
- // We don't start knowing VERSION_1
- let inner = Watcher::new_version(VERSION_0);
-
- // Insert the watcher
- let mut watchers = FnvHashSet::with_capacity_and_hasher(0, Default::default());
- watchers.insert(inner.clone());
-
+pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) {
let shared = Arc::new(Shared {
value: RwLock::new(init),
- version: AtomicUsize::new(VERSION_1),
- watchers: Mutex::new(watchers),
- cancel: AtomicWaker::new(),
+ version: AtomicUsize::new(0),
+ ref_count_rx: AtomicUsize::new(1),
+ notify_rx: Notify::new(),
+ notify_tx: Notify::new(),
});
let tx = Sender {
- shared: Arc::downgrade(&shared),
+ shared: shared.clone(),
};
- let rx = Receiver { shared, inner };
+ let rx = Receiver { shared, version: 0 };
(tx, rx)
}
@@ -221,39 +212,13 @@ impl<T> Receiver<T> {
Ref { inner }
}
- // TODO: document
- #[doc(hidden)]
- pub fn poll_recv_ref<'a>(&'a mut self, cx: &mut Context<'_>) -> Poll<Option<Ref<'a, T>>> {
- // Make sure the task is up to date
- self.inner.waker.register_by_ref(cx.waker());
-
- let state = self.shared.version.load(SeqCst);
- let version = state & !CLOSED;
-
- if self.inner.version.swap(version, Relaxed) != version {
- let inner = self.shared.value.read().unwrap();
-
- return Ready(Some(Ref { inner }));
- }
-
- if CLOSED == state & CLOSED {
- // The `Store` handle has been dropped.
- return Ready(None);
- }
-
- Pending
- }
-}
-
-impl<T: Clone> Receiver<T> {
- /// Attempts to clone the latest value sent via the channel.
+ /// Waits for a change notification.
///
- /// If this is the first time the function is called on a `Receiver`
- /// instance, then the function completes immediately with the **current**
- /// value held by the channel. On the next call, the function waits until
- /// a new value is sent in the channel.
+ /// Returns when a new value has been sent by the [`Sender`] since the last
+ /// time `changed()` was called. When the `Sender` half is dropped, `Err` is
+ /// returned.
///
- /// `None` is returned if the `Sender` half is dropped.
+ /// [`Sender`]: struct@Sender
///
/// # Examples
///
@@ -264,118 +229,170 @@ impl<T: Clone> Receiver<T> {
/// async fn main() {
/// let (tx, mut rx) = watch::channel("hello");
///
- /// let v = rx.recv().await.unwrap();
- /// assert_eq!(v, "hello");
- ///
/// tokio::spawn(async move {
- /// tx.broadcast("goodbye").unwrap();
+ /// tx.send("goodbye").unwrap();
/// });
///
- /// // Waits for the new task to spawn and send the value.
- /// let v = rx.recv().await.unwrap();
- /// assert_eq!(v, "goodbye");
+ /// assert!(rx.changed().await.is_ok());
+ /// assert_eq!(*rx.borrow(), "goodbye");
///
- /// let v = rx.recv().await;
- /// assert!(v.is_none());
+ /// // The `tx` handle has been dropped
+ /// assert!(rx.changed().await.is_err());
/// }
/// ```
- pub async fn recv(&mut self) -> Option<T> {
- poll_fn(|cx| {
- let v_ref = ready!(self.poll_recv_ref(cx));
- Poll::Ready(v_ref.map(|v_ref| (*v_ref).clone()))
+ pub async fn changed(&mut self) -> Result<(), error::RecvError> {
+ use std::future::Future;
+ use std::pin::Pin;
+ use std::task::Poll;
+
+ // In order to avoid a race condition, we first request a notification,
+ // **then** check the current value's version. If a new version exists,
+ // the notification request is dropped. Requesting the notification
+ // requires polling the future once.
+ let notified = self.shared.notify_rx.notified();
+ pin!(notified);
+
+ // Polling the future once is guaranteed to return `Pending` as `watch`
+ // only notifies using `notify_waiters`.
+ crate::future::poll_fn(|cx| {
+ let res = Pin::new(&mut notified).poll(cx);
+ assert!(!res.is_ready());
+ Poll::Ready(())
})
- .await
+ .await;
+
+ if let Some(ret) = maybe_changed(&self.shared, &mut self.version) {
+ return ret;
+ }
+
+ notified.await;
+
+ maybe_changed(&self.shared, &mut self.version)
+ .expect("[bug] failed to observe change after notificaton.")
}
}
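
The subscribe-then-check ordering in `changed()` is the load-bearing detail: the waiter must be registered with the `Notify` before the version is read, because `notify_waiters()` only wakes futures that have already been polled. A standalone sketch of the same dance (hypothetical names; `seen` plays the role of `Receiver::version`):

```
use std::future::Future;
use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};
use std::task::Poll;
use tokio::sync::Notify;

// Assumes the notifier only ever uses `notify_waiters()`, as `watch` does.
async fn wait_for_change(seen: &mut usize, version: &AtomicUsize, notify: &Notify) {
    let notified = notify.notified();
    tokio::pin!(notified);

    // Register the waiter: one poll is required before `notify_waiters()`
    // can see it, and it is guaranteed to be `Pending` here.
    std::future::poll_fn(|cx| {
        assert!(notified.as_mut().poll(cx).is_pending());
        Poll::Ready(())
    })
    .await;

    // A send racing with this check either bumped the version (observed
    // here) or happens after registration (so the notification fires).
    let current = version.load(SeqCst);
    if current != *seen {
        *seen = current;
        return;
    }

    notified.await;
    *seen = version.load(SeqCst);
}
```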
-#[cfg(feature = "stream")]
-impl<T: Clone> crate::stream::Stream for Receiver<T> {
- type Item = T;
-
- fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
- let v_ref = ready!(self.poll_recv_ref(cx));
+fn maybe_changed<T>(
+ shared: &Shared<T>,
+ version: &mut usize,
+) -> Option<Result<(), error::RecvError>> {
+ // Load the version from the state
+ let state = shared.version.load(SeqCst);
+ let new_version = state & !CLOSED;
+
+ if *version != new_version {
+ // Observe the new version and return
+ *version = new_version;
+ return Some(Ok(()));
+ }
- Poll::Ready(v_ref.map(|v_ref| (*v_ref).clone()))
+ if CLOSED == state & CLOSED {
+ // The `Sender` half has been dropped; the channel is closed.
+ return Some(Err(error::RecvError(())));
}
+
+ None
}
impl<T> Clone for Receiver<T> {
fn clone(&self) -> Self {
- let ver = self.inner.version.load(Relaxed);
- let inner = Watcher::new_version(ver);
+ let version = self.version;
let shared = self.shared.clone();
- shared.watchers.lock().unwrap().insert(inner.clone());
+ // No synchronization necessary, as the count is used only as a counter
+ // and does not guard any memory access.
+ shared.ref_count_rx.fetch_add(1, Relaxed);
- Receiver { shared, inner }
+ Receiver { version, shared }
}
}
impl<T> Drop for Receiver<T> {
fn drop(&mut self) {
- self.shared.watchers.lock().unwrap().remove(&self.inner);
+ // No synchronization necessary, as the count is used only as a counter
+ // and does not guard any memory access.
+ if 1 == self.shared.ref_count_rx.fetch_sub(1, Relaxed) {
+ // This is the last `Receiver` handle; notify tasks waiting on `Sender::closed()`.
+ self.shared.notify_tx.notify_waiters();
+ }
}
}
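
Unlike `Arc`'s reference count, `ref_count_rx` guards no memory, so plain `Relaxed` operations suffice. A compact standalone illustration of the clone/drop bookkeeping above (sketch only):

```
use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};

fn main() {
    let ref_count_rx = AtomicUsize::new(1); // channel starts with one receiver

    // clone(): bump the count.
    ref_count_rx.fetch_add(1, Relaxed);

    // drop(): `fetch_sub` returns the *previous* value, so a return of 1
    // means this was the last handle and the sender should be notified.
    for _ in 0..2 {
        if ref_count_rx.fetch_sub(1, Relaxed) == 1 {
            println!("last receiver dropped; notify `Sender::closed()` waiters");
        }
    }
}
```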
impl<T> Sender<T> {
- /// Broadcasts a new value via the channel, notifying all receivers.
- pub fn broadcast(&self, value: T) -> Result<(), error::SendError<T>> {
- let shared = match self.shared.upgrade() {
- Some(shared) => shared,
- // All `Watch` handles have been canceled
- None => return Err(error::SendError { inner: value }),
- };
-
- // Replace the value
- {
- let mut lock = shared.value.write().unwrap();
- *lock = value;
+ /// Sends a new value via the channel, notifying all receivers.
+ pub fn send(&self, value: T) -> Result<(), error::SendError<T>> {
+ // This is pretty much only useful as a hint anyway, so synchronization isn't critical.
+ if 0 == self.shared.ref_count_rx.load(Relaxed) {
+ return Err(error::SendError { inner: value });
}
+ *self.shared.value.write().unwrap() = value;
+
// Update the version. 2 is used so that the CLOSED bit is not set.
- shared.version.fetch_add(2, SeqCst);
+ self.shared.version.fetch_add(2, SeqCst);
// Notify all watchers
- notify_all(&*shared);
+ self.shared.notify_rx.notify_waiters();
Ok(())
}
+ /// Checks if the channel has been closed. This happens when all receivers
+ /// have dropped.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// let (tx, rx) = tokio::sync::watch::channel(());
+ /// assert!(!tx.is_closed());
+ ///
+ /// drop(rx);
+ /// assert!(tx.is_closed());
+ /// ```
+ pub fn is_closed(&self) -> bool {
+ self.shared.ref_count_rx.load(Relaxed) == 0
+ }
+
/// Completes when all receivers have dropped.
///
/// This allows the producer to get notified when interest in the produced
/// values is canceled and immediately stop doing work.
- pub async fn closed(&mut self) {
- poll_fn(|cx| self.poll_close(cx)).await
- }
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::watch;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, rx) = watch::channel("hello");
+ ///
+ /// tokio::spawn(async move {
+ /// // use `rx`
+ /// drop(rx);
+ /// });
+ ///
+ /// // Waits for `rx` to drop
+ /// tx.closed().await;
+ /// println!("the `rx` handles dropped")
+ /// }
+ /// ```
+ pub async fn closed(&self) {
+ let notified = self.shared.notify_tx.notified();
- fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<()> {
- match self.shared.upgrade() {
- Some(shared) => {
- shared.cancel.register_by_ref(cx.waker());
- Pending
- }
- None => Ready(()),
+ if self.shared.ref_count_rx.load(Relaxed) == 0 {
+ return;
}
- }
-}
-
-/// Notifies all watchers of a change
-fn notify_all<T>(shared: &Shared<T>) {
- let watchers = shared.watchers.lock().unwrap();
- for watcher in watchers.iter() {
- // Notify the task
- watcher.waker.wake();
+ notified.await;
+ debug_assert_eq!(0, self.shared.ref_count_rx.load(Relaxed));
}
}
impl<T> Drop for Sender<T> {
fn drop(&mut self) {
- if let Some(shared) = self.shared.upgrade() {
- shared.version.fetch_or(CLOSED, SeqCst);
- notify_all(&*shared);
- }
+ self.shared.version.fetch_or(CLOSED, SeqCst);
+ self.shared.notify_rx.notify_waiters();
}
}
@@ -388,44 +405,3 @@ impl<T> ops::Deref for Ref<'_, T> {
self.inner.deref()
}
}
-
-// ===== impl Shared =====
-
-impl<T> Drop for Shared<T> {
- fn drop(&mut self) {
- self.cancel.wake();
- }
-}
-
-// ===== impl Watcher =====
-
-impl Watcher {
- fn new_version(version: usize) -> Self {
- Watcher(Arc::new(WatchInner {
- version: AtomicUsize::new(version),
- waker: AtomicWaker::new(),
- }))
- }
-}
-
-impl std::cmp::PartialEq for Watcher {
- fn eq(&self, other: &Watcher) -> bool {
- Arc::ptr_eq(&self.0, &other.0)
- }
-}
-
-impl std::cmp::Eq for Watcher {}
-
-impl std::hash::Hash for Watcher {
- fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
- (&*self.0 as *const WatchInner).hash(state)
- }
-}
-
-impl std::ops::Deref for Watcher {
- type Target = WatchInner;
-
- fn deref(&self) -> &Self::Target {
- &self.0
- }
-}