Diffstat (limited to 'src/sync')
-rw-r--r--  src/sync/batch_semaphore.rs                  |   8
-rw-r--r--  src/sync/broadcast.rs                        |  32
-rw-r--r--  src/sync/mpsc/block.rs                       |   4
-rw-r--r--  src/sync/mpsc/bounded.rs                     |  20
-rw-r--r--  src/sync/mpsc/chan.rs                        |   6
-rw-r--r--  src/sync/mpsc/error.rs                       |   2
-rw-r--r--  src/sync/mpsc/list.rs                        |   4
-rw-r--r--  src/sync/mpsc/unbounded.rs                   |   2
-rw-r--r--  src/sync/mutex.rs                            |  38
-rw-r--r--  src/sync/notify.rs                           |  25
-rw-r--r--  src/sync/once_cell.rs                        |   8
-rw-r--r--  src/sync/oneshot.rs                          | 286
-rw-r--r--  src/sync/rwlock/owned_read_guard.rs          |   2
-rw-r--r--  src/sync/rwlock/owned_write_guard.rs         |   2
-rw-r--r--  src/sync/rwlock/owned_write_guard_mapped.rs  |   2
-rw-r--r--  src/sync/rwlock/read_guard.rs                |   2
-rw-r--r--  src/sync/rwlock/write_guard.rs               |   2
-rw-r--r--  src/sync/rwlock/write_guard_mapped.rs        |   2
-rw-r--r--  src/sync/task/atomic_waker.rs                |  80
-rw-r--r--  src/sync/tests/atomic_waker.rs               |  39
-rw-r--r--  src/sync/tests/loom_atomic_waker.rs          |  55
-rw-r--r--  src/sync/tests/loom_oneshot.rs               |  29
-rw-r--r--  src/sync/tests/mod.rs                        |   1
-rw-r--r--  src/sync/tests/notify.rs                     |  44
-rw-r--r--  src/sync/watch.rs                            |  48
25 files changed, 651 insertions, 92 deletions
diff --git a/src/sync/batch_semaphore.rs b/src/sync/batch_semaphore.rs
index 9b43404..b5c39d2 100644
--- a/src/sync/batch_semaphore.rs
+++ b/src/sync/batch_semaphore.rs
@@ -1,5 +1,5 @@
#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))]
-//! # Implementation Details
+//! # Implementation Details.
//!
//! The semaphore is implemented using an intrusive linked list of waiters. An
//! atomic counter tracks the number of available permits. If the semaphore does
@@ -138,7 +138,7 @@ impl Semaphore {
}
}
- /// Creates a new semaphore with the initial number of permits
+ /// Creates a new semaphore with the initial number of permits.
///
/// Maximum number of permits on 32-bit platforms is `1<<29`.
///
@@ -159,7 +159,7 @@ impl Semaphore {
}
}
- /// Returns the current number of available permits
+ /// Returns the current number of available permits.
pub(crate) fn available_permits(&self) -> usize {
self.permits.load(Acquire) >> Self::PERMIT_SHIFT
}
@@ -197,7 +197,7 @@ impl Semaphore {
}
}
- /// Returns true if the semaphore is closed
+ /// Returns true if the semaphore is closed.
pub(crate) fn is_closed(&self) -> bool {
self.permits.load(Acquire) & Self::CLOSED == Self::CLOSED
}
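The accessors in the hunks above read both the permit count and the closed flag out of a single atomic word. A minimal standalone sketch of that encoding, assuming `PERMIT_SHIFT = 1` and `CLOSED = 1` (illustrative values; the real constants live in `batch_semaphore.rs`):

```
// Hypothetical constants for illustration only; see batch_semaphore.rs for
// the actual definitions.
const PERMIT_SHIFT: usize = 1;
const CLOSED: usize = 1;

fn available_permits(state: usize) -> usize {
    // The permit count lives in the high bits; the low bit is the closed flag.
    state >> PERMIT_SHIFT
}

fn is_closed(state: usize) -> bool {
    state & CLOSED == CLOSED
}

fn main() {
    let state = (5 << PERMIT_SHIFT) | CLOSED; // five permits, semaphore closed
    assert_eq!(available_permits(state), 5);
    assert!(is_closed(state));
}
```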
diff --git a/src/sync/broadcast.rs b/src/sync/broadcast.rs
index a2ca445..0d9cd3b 100644
--- a/src/sync/broadcast.rs
+++ b/src/sync/broadcast.rs
@@ -293,37 +293,37 @@ pub mod error {
use self::error::*;
-/// Data shared between senders and receivers
+/// Data shared between senders and receivers.
struct Shared<T> {
- /// slots in the channel
+ /// slots in the channel.
buffer: Box<[RwLock<Slot<T>>]>,
- /// Mask a position -> index
+ /// Mask a position -> index.
mask: usize,
/// Tail of the queue. Includes the rx wait list.
tail: Mutex<Tail>,
- /// Number of outstanding Sender handles
+ /// Number of outstanding Sender handles.
num_tx: AtomicUsize,
}
-/// Next position to write a value
+/// Next position to write a value.
struct Tail {
- /// Next position to write to
+ /// Next position to write to.
pos: u64,
- /// Number of active receivers
+ /// Number of active receivers.
rx_cnt: usize,
- /// True if the channel is closed
+ /// True if the channel is closed.
closed: bool,
- /// Receivers waiting for a value
+ /// Receivers waiting for a value.
waiters: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
}
-/// Slot in the buffer
+/// Slot in the buffer.
struct Slot<T> {
/// Remaining number of receivers that are expected to see this value.
///
@@ -333,7 +333,7 @@ struct Slot<T> {
/// acquired.
rem: AtomicUsize,
- /// Uniquely identifies the `send` stored in the slot
+ /// Uniquely identifies the `send` stored in the slot.
pos: u64,
/// True signals the channel is closed.
@@ -346,9 +346,9 @@ struct Slot<T> {
val: UnsafeCell<Option<T>>,
}
-/// An entry in the wait queue
+/// An entry in the wait queue.
struct Waiter {
- /// True if queued
+ /// True if queued.
queued: bool,
/// Task waiting on the broadcast channel.
@@ -365,12 +365,12 @@ struct RecvGuard<'a, T> {
slot: RwLockReadGuard<'a, Slot<T>>,
}
-/// Receive a value future
+/// Receive a value future.
struct Recv<'a, T> {
- /// Receiver being waited on
+ /// Receiver being waited on.
receiver: &'a mut Receiver<T>,
- /// Entry in the waiter `LinkedList`
+ /// Entry in the waiter `LinkedList`.
waiter: UnsafeCell<Waiter>,
}
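The `mask` field documented above is what turns a monotonically increasing write position into a slot index. A small sketch of that mapping, assuming the buffer length is rounded up to a power of two (which is what makes a simple mask work):

```
fn main() {
    let capacity: usize = 16;          // assumed power-of-two buffer length
    let mask = capacity - 1;           // "Mask a position -> index"
    for pos in [0u64, 1, 15, 16, 17, 100] {
        let index = (pos & mask as u64) as usize;
        println!("pos {:3} -> slot {}", pos, index);
    }
}
```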
diff --git a/src/sync/mpsc/block.rs b/src/sync/mpsc/block.rs
index 6e7b700..58f4a9f 100644
--- a/src/sync/mpsc/block.rs
+++ b/src/sync/mpsc/block.rs
@@ -40,7 +40,7 @@ struct Values<T>([UnsafeCell<MaybeUninit<T>>; BLOCK_CAP]);
use super::BLOCK_CAP;
-/// Masks an index to get the block identifier
+/// Masks an index to get the block identifier.
const BLOCK_MASK: usize = !(BLOCK_CAP - 1);
/// Masks an index to get the value offset in a block.
@@ -89,7 +89,7 @@ impl<T> Block<T> {
}
}
- /// Returns `true` if the block matches the given index
+ /// Returns `true` if the block matches the given index.
pub(crate) fn is_at_index(&self, index: usize) -> bool {
debug_assert!(offset(index) == 0);
self.start_index == index
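`BLOCK_MASK` and the `offset` helper referenced by the `debug_assert!` above split a channel index into a block start and a position within that block. A hypothetical sketch assuming `BLOCK_CAP = 32` (a power of two; the real constant is defined elsewhere in the `mpsc` module):

```
const BLOCK_CAP: usize = 32;                // assumed value for illustration
const BLOCK_MASK: usize = !(BLOCK_CAP - 1); // block identifier (start index)

fn offset(index: usize) -> usize {
    index & (BLOCK_CAP - 1)                 // position of the value in a block
}

fn main() {
    let index = 70;
    let start_index = index & BLOCK_MASK;   // the block holding this index
    assert_eq!(start_index, 64);
    assert_eq!(offset(index), 6);
    // `is_at_index` is only called with block-aligned indices, hence the
    // `debug_assert!(offset(index) == 0)` in the hunk above.
}
```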
diff --git a/src/sync/mpsc/bounded.rs b/src/sync/mpsc/bounded.rs
index bcad84d..5a2bfa6 100644
--- a/src/sync/mpsc/bounded.rs
+++ b/src/sync/mpsc/bounded.rs
@@ -10,7 +10,7 @@ cfg_time! {
use std::fmt;
use std::task::{Context, Poll};
-/// Send values to the associated `Receiver`.
+/// Sends values to the associated `Receiver`.
///
/// Instances are created by the [`channel`](channel) function.
///
@@ -22,7 +22,7 @@ pub struct Sender<T> {
chan: chan::Tx<T, Semaphore>,
}
-/// Permit to send one value into the channel.
+/// Permits to send one value into the channel.
///
/// `Permit` values are returned by [`Sender::reserve()`] and [`Sender::try_reserve()`]
/// and are used to guarantee channel capacity before generating a message to send.
@@ -49,7 +49,7 @@ pub struct OwnedPermit<T> {
chan: Option<chan::Tx<T, Semaphore>>,
}
-/// Receive values from the associated `Sender`.
+/// Receives values from the associated `Sender`.
///
/// Instances are created by the [`channel`](channel) function.
///
@@ -57,7 +57,7 @@ pub struct OwnedPermit<T> {
///
/// [`ReceiverStream`]: https://docs.rs/tokio-stream/0.1/tokio_stream/wrappers/struct.ReceiverStream.html
pub struct Receiver<T> {
- /// The channel receiver
+ /// The channel receiver.
chan: chan::Rx<T, Semaphore>,
}
@@ -187,7 +187,7 @@ impl<T> Receiver<T> {
poll_fn(|cx| self.chan.recv(cx)).await
}
- /// Try to receive the next value for this receiver.
+ /// Tries to receive the next value for this receiver.
///
/// This method returns the [`Empty`] error if the channel is currently
/// empty, but there are still outstanding [senders] or [permits].
@@ -672,7 +672,7 @@ impl<T> Sender<T> {
self.chan.is_closed()
}
- /// Wait for channel capacity. Once capacity to send one message is
+ /// Waits for channel capacity. Once capacity to send one message is
/// available, it is reserved for the caller.
///
/// If the channel is full, the function waits for the number of unreceived
@@ -721,7 +721,7 @@ impl<T> Sender<T> {
Ok(Permit { chan: &self.chan })
}
- /// Wait for channel capacity, moving the `Sender` and returning an owned
+ /// Waits for channel capacity, moving the `Sender` and returning an owned
/// permit. Once capacity to send one message is available, it is reserved
/// for the caller.
///
@@ -815,7 +815,7 @@ impl<T> Sender<T> {
}
}
- /// Try to acquire a slot in the channel without waiting for the slot to become
+ /// Tries to acquire a slot in the channel without waiting for the slot to become
/// available.
///
/// If the channel is full this function will return [`TrySendError`], otherwise
@@ -868,7 +868,7 @@ impl<T> Sender<T> {
Ok(Permit { chan: &self.chan })
}
- /// Try to acquire a slot in the channel without waiting for the slot to become
+ /// Tries to acquire a slot in the channel without waiting for the slot to become
/// available, returning an owned permit.
///
/// This moves the sender _by value_, and returns an owned permit that can
@@ -1117,7 +1117,7 @@ impl<T> OwnedPermit<T> {
Sender { chan }
}
- /// Release the reserved capacity *without* sending a message, returning the
+ /// Releases the reserved capacity *without* sending a message, returning the
/// [`Sender`].
///
/// # Examples
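Before the next file, a brief usage sketch of the reserve-then-send flow that the `Permit` and `reserve` docs above describe: capacity is claimed up front, so the later send can neither fail nor wait. Error handling is elided.

```
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel(1);

    // Wait until a slot is available and hold on to it.
    let permit = tx.reserve().await.unwrap();

    // Sending through the permit consumes the reserved slot immediately.
    permit.send(42);

    assert_eq!(rx.recv().await, Some(42));
}
```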
diff --git a/src/sync/mpsc/chan.rs b/src/sync/mpsc/chan.rs
index 637ae1f..c3007de 100644
--- a/src/sync/mpsc/chan.rs
+++ b/src/sync/mpsc/chan.rs
@@ -14,7 +14,7 @@ use std::sync::atomic::Ordering::{AcqRel, Relaxed};
use std::task::Poll::{Pending, Ready};
use std::task::{Context, Poll};
-/// Channel sender
+/// Channel sender.
pub(crate) struct Tx<T, S> {
inner: Arc<Chan<T, S>>,
}
@@ -25,7 +25,7 @@ impl<T, S: fmt::Debug> fmt::Debug for Tx<T, S> {
}
}
-/// Channel receiver
+/// Channel receiver.
pub(crate) struct Rx<T, S: Semaphore> {
inner: Arc<Chan<T, S>>,
}
@@ -47,7 +47,7 @@ pub(crate) trait Semaphore {
}
struct Chan<T, S> {
- /// Notifies all tasks listening for the receiver being dropped
+ /// Notifies all tasks listening for the receiver being dropped.
notify_rx_closed: Notify,
/// Handle to the push half of the lock-free list.
diff --git a/src/sync/mpsc/error.rs b/src/sync/mpsc/error.rs
index 48ca379..b7b9cf7 100644
--- a/src/sync/mpsc/error.rs
+++ b/src/sync/mpsc/error.rs
@@ -1,4 +1,4 @@
-//! Channel error types
+//! Channel error types.
use std::error::Error;
use std::fmt;
diff --git a/src/sync/mpsc/list.rs b/src/sync/mpsc/list.rs
index 53c34d2..e4eeb45 100644
--- a/src/sync/mpsc/list.rs
+++ b/src/sync/mpsc/list.rs
@@ -8,7 +8,7 @@ use std::fmt;
use std::ptr::NonNull;
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};
-/// List queue transmit handle
+/// List queue transmit handle.
pub(crate) struct Tx<T> {
/// Tail in the `Block` mpmc list.
block_tail: AtomicPtr<Block<T>>,
@@ -79,7 +79,7 @@ impl<T> Tx<T> {
}
}
- /// Closes the send half of the list
+ /// Closes the send half of the list.
///
/// Similar process as pushing a value, but instead of writing the value &
/// setting the ready flag, the TX_CLOSED flag is set on the block.
diff --git a/src/sync/mpsc/unbounded.rs b/src/sync/mpsc/unbounded.rs
index 8961930..b133f9f 100644
--- a/src/sync/mpsc/unbounded.rs
+++ b/src/sync/mpsc/unbounded.rs
@@ -129,7 +129,7 @@ impl<T> UnboundedReceiver<T> {
poll_fn(|cx| self.poll_recv(cx)).await
}
- /// Try to receive the next value for this receiver.
+ /// Tries to receive the next value for this receiver.
///
/// This method returns the [`Empty`] error if the channel is currently
/// empty, but there are still outstanding [senders] or [permits].
diff --git a/src/sync/mutex.rs b/src/sync/mutex.rs
index 6acd28b..4d9f988 100644
--- a/src/sync/mutex.rs
+++ b/src/sync/mutex.rs
@@ -301,6 +301,40 @@ impl<T: ?Sized> Mutex<T> {
MutexGuard { lock: self }
}
+ /// Blockingly locks this mutex. When the lock has been acquired, this function
+ /// returns a [`MutexGuard`].
+ ///
+ /// This method is intended for use cases where you
+ /// need to use this mutex in asynchronous code as well as in synchronous code.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use std::sync::Arc;
+ /// use tokio::sync::Mutex;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let mutex = Arc::new(Mutex::new(1));
+ ///
+ /// let mutex1 = Arc::clone(&mutex);
+ /// let sync_code = tokio::task::spawn_blocking(move || {
+ /// let mut n = mutex1.blocking_lock();
+ /// *n = 2;
+ /// });
+ ///
+ /// sync_code.await.unwrap();
+ ///
+ /// let n = mutex.lock().await;
+ /// assert_eq!(*n, 2);
+ /// }
+ ///
+ /// ```
+ #[cfg(feature = "sync")]
+ pub fn blocking_lock(&self) -> MutexGuard<'_, T> {
+ crate::future::block_on(self.lock())
+ }
+
/// Locks this mutex, causing the current task to yield until the lock has
/// been acquired. When the lock has been acquired, this returns an
/// [`OwnedMutexGuard`].
@@ -462,14 +496,14 @@ where
}
}
-impl<T> std::fmt::Debug for Mutex<T>
+impl<T: ?Sized> std::fmt::Debug for Mutex<T>
where
T: std::fmt::Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut d = f.debug_struct("Mutex");
match self.try_lock() {
- Ok(inner) => d.field("data", &*inner),
+ Ok(inner) => d.field("data", &&*inner),
Err(_) => d.field("data", &format_args!("<locked>")),
};
d.finish()
diff --git a/src/sync/notify.rs b/src/sync/notify.rs
index 74b97cc..c93ce3b 100644
--- a/src/sync/notify.rs
+++ b/src/sync/notify.rs
@@ -20,7 +20,7 @@ use std::task::{Context, Poll, Waker};
type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>;
-/// Notify a single task to wake up.
+/// Notifies a single task to wake up.
///
/// `Notify` provides a basic mechanism to notify a single task of an event.
/// `Notify` itself does not carry any data. Instead, it is to be used to signal
@@ -57,13 +57,16 @@ type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>;
/// let notify = Arc::new(Notify::new());
/// let notify2 = notify.clone();
///
-/// tokio::spawn(async move {
+/// let handle = tokio::spawn(async move {
/// notify2.notified().await;
/// println!("received notification");
/// });
///
/// println!("sending notification");
/// notify.notify_one();
+///
+/// // Wait for task to receive notification.
+/// handle.await.unwrap();
/// }
/// ```
///
@@ -128,10 +131,10 @@ enum NotificationType {
#[derive(Debug)]
struct Waiter {
- /// Intrusive linked-list pointers
+ /// Intrusive linked-list pointers.
pointers: linked_list::Pointers<Waiter>,
- /// Waiting task's waker
+ /// Waiting task's waker.
waker: Option<Waker>,
/// `true` if the notification has been assigned to this waiter.
@@ -168,13 +171,13 @@ const NOTIFY_WAITERS_SHIFT: usize = 2;
const STATE_MASK: usize = (1 << NOTIFY_WAITERS_SHIFT) - 1;
const NOTIFY_WAITERS_CALLS_MASK: usize = !STATE_MASK;
-/// Initial "idle" state
+/// Initial "idle" state.
const EMPTY: usize = 0;
/// One or more threads are currently waiting to be notified.
const WAITING: usize = 1;
-/// Pending notification
+/// Pending notification.
const NOTIFIED: usize = 2;
fn set_state(data: usize, state: usize) -> usize {
@@ -289,7 +292,7 @@ impl Notify {
}
}
- /// Notifies a waiting task
+ /// Notifies a waiting task.
///
/// If a task is currently waiting, that task is notified. Otherwise, a
/// permit is stored in this `Notify` value and the **next** call to
@@ -359,7 +362,7 @@ impl Notify {
}
}
- /// Notifies all waiting tasks
+ /// Notifies all waiting tasks.
///
/// If a task is currently waiting, that task is notified. Unlike with
/// `notify_one()`, no permit is stored to be used by the next call to
@@ -551,6 +554,10 @@ impl Future for Notified<'_> {
return Poll::Ready(());
}
+ // Clone the waker before locking, since a waker clone can
+ // trigger arbitrary code.
+ let waker = cx.waker().clone();
+
// Acquire the lock and attempt to transition to the waiting
// state.
let mut waiters = notify.waiters.lock();
@@ -612,7 +619,7 @@ impl Future for Notified<'_> {
// Safety: called while locked.
unsafe {
- (*waiter.get()).waker = Some(cx.waker().clone());
+ (*waiter.get()).waker = Some(waker);
}
// Insert the waiter into the linked list
diff --git a/src/sync/once_cell.rs b/src/sync/once_cell.rs
index 91705a5..d31a40e 100644
--- a/src/sync/once_cell.rs
+++ b/src/sync/once_cell.rs
@@ -245,7 +245,7 @@ impl<T> OnceCell<T> {
}
}
- /// Set the value of the `OnceCell` to the given value if the `OnceCell` is
+ /// Sets the value of the `OnceCell` to the given value if the `OnceCell` is
/// empty.
///
/// If the `OnceCell` already has a value, this call will fail with an
@@ -283,7 +283,7 @@ impl<T> OnceCell<T> {
}
}
- /// Get the value currently in the `OnceCell`, or initialize it with the
+ /// Gets the value currently in the `OnceCell`, or initializes it with the
/// given asynchronous operation.
///
/// If some other task is currently working on initializing the `OnceCell`,
@@ -331,7 +331,7 @@ impl<T> OnceCell<T> {
}
}
- /// Get the value currently in the `OnceCell`, or initialize it with the
+ /// Gets the value currently in the `OnceCell`, or initializes it with the
/// given asynchronous operation.
///
/// If some other task is currently working on initializing the `OnceCell`,
@@ -382,7 +382,7 @@ impl<T> OnceCell<T> {
}
}
- /// Take the value from the cell, destroying the cell in the process.
+ /// Takes the value from the cell, destroying the cell in the process.
/// Returns `None` if the cell is empty.
pub fn into_inner(mut self) -> Option<T> {
if self.initialized_mut() {
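A short usage sketch of the `get_or_init` behaviour documented above: the initializer runs at most once, and every caller observes the same value. The `load_config` function here is a made-up stand-in for a real asynchronous initializer.

```
use tokio::sync::OnceCell;

async fn load_config() -> u32 {
    // Stand-in for an expensive asynchronous initialization.
    42
}

#[tokio::main]
async fn main() {
    let cell = OnceCell::new();

    let a = cell.get_or_init(load_config).await;
    let b = cell.get_or_init(load_config).await; // does not initialize again
    assert_eq!((*a, *b), (42, 42));
}
```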
diff --git a/src/sync/oneshot.rs b/src/sync/oneshot.rs
index 0df6037..4fb22ec 100644
--- a/src/sync/oneshot.rs
+++ b/src/sync/oneshot.rs
@@ -51,6 +51,70 @@
//! }
//! }
//! ```
+//!
+//! To use a oneshot channel in a `tokio::select!` loop, add `&mut` in front of
+//! the channel.
+//!
+//! ```
+//! use tokio::sync::oneshot;
+//! use tokio::time::{interval, sleep, Duration};
+//!
+//! #[tokio::main]
+//! # async fn _doc() {}
+//! # #[tokio::main(flavor = "current_thread", start_paused = true)]
+//! async fn main() {
+//! let (send, mut recv) = oneshot::channel();
+//! let mut interval = interval(Duration::from_millis(100));
+//!
+//! # let handle =
+//! tokio::spawn(async move {
+//! sleep(Duration::from_secs(1)).await;
+//! send.send("shut down").unwrap();
+//! });
+//!
+//! loop {
+//! tokio::select! {
+//! _ = interval.tick() => println!("Another 100ms"),
+//! msg = &mut recv => {
+//! println!("Got message: {}", msg.unwrap());
+//! break;
+//! }
+//! }
+//! }
+//! # handle.await.unwrap();
+//! }
+//! ```
+//!
+//! To use a `Sender` from a destructor, put it in an [`Option`] and call
+//! [`Option::take`].
+//!
+//! ```
+//! use tokio::sync::oneshot;
+//!
+//! struct SendOnDrop {
+//! sender: Option<oneshot::Sender<&'static str>>,
+//! }
+//! impl Drop for SendOnDrop {
+//! fn drop(&mut self) {
+//! if let Some(sender) = self.sender.take() {
+//! // Using `let _ =` to ignore send errors.
+//! let _ = sender.send("I got dropped!");
+//! }
+//! }
+//! }
+//!
+//! #[tokio::main]
+//! # async fn _doc() {}
+//! # #[tokio::main(flavor = "current_thread")]
+//! async fn main() {
+//! let (send, recv) = oneshot::channel();
+//!
+//! let send_on_drop = SendOnDrop { sender: Some(send) };
+//! drop(send_on_drop);
+//!
+//! assert_eq!(recv.await, Ok("I got dropped!"));
+//! }
+//! ```
use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::AtomicUsize;
@@ -68,16 +132,98 @@ use std::task::{Context, Poll, Waker};
///
/// A pair of both a [`Sender`] and a [`Receiver`] are created by the
/// [`channel`](fn@channel) function.
+///
+/// # Examples
+///
+/// ```
+/// use tokio::sync::oneshot;
+///
+/// #[tokio::main]
+/// async fn main() {
+/// let (tx, rx) = oneshot::channel();
+///
+/// tokio::spawn(async move {
+/// if let Err(_) = tx.send(3) {
+/// println!("the receiver dropped");
+/// }
+/// });
+///
+/// match rx.await {
+/// Ok(v) => println!("got = {:?}", v),
+/// Err(_) => println!("the sender dropped"),
+/// }
+/// }
+/// ```
+///
+/// If the sender is dropped without sending, the receiver will fail with
+/// [`error::RecvError`]:
+///
+/// ```
+/// use tokio::sync::oneshot;
+///
+/// #[tokio::main]
+/// async fn main() {
+/// let (tx, rx) = oneshot::channel::<u32>();
+///
+/// tokio::spawn(async move {
+/// drop(tx);
+/// });
+///
+/// match rx.await {
+/// Ok(_) => panic!("This doesn't happen"),
+/// Err(_) => println!("the sender dropped"),
+/// }
+/// }
+/// ```
+///
+/// To use a `Sender` from a destructor, put it in an [`Option`] and call
+/// [`Option::take`].
+///
+/// ```
+/// use tokio::sync::oneshot;
+///
+/// struct SendOnDrop {
+/// sender: Option<oneshot::Sender<&'static str>>,
+/// }
+/// impl Drop for SendOnDrop {
+/// fn drop(&mut self) {
+/// if let Some(sender) = self.sender.take() {
+/// // Using `let _ =` to ignore send errors.
+/// let _ = sender.send("I got dropped!");
+/// }
+/// }
+/// }
+///
+/// #[tokio::main]
+/// # async fn _doc() {}
+/// # #[tokio::main(flavor = "current_thread")]
+/// async fn main() {
+/// let (send, recv) = oneshot::channel();
+///
+/// let send_on_drop = SendOnDrop { sender: Some(send) };
+/// drop(send_on_drop);
+///
+/// assert_eq!(recv.await, Ok("I got dropped!"));
+/// }
+/// ```
+///
+/// [`Option`]: std::option::Option
+/// [`Option::take`]: std::option::Option::take
#[derive(Debug)]
pub struct Sender<T> {
inner: Option<Arc<Inner<T>>>,
}
-/// Receive a value from the associated [`Sender`].
+/// Receives a value from the associated [`Sender`].
///
/// A pair of both a [`Sender`] and a [`Receiver`] are created by the
/// [`channel`](fn@channel) function.
///
+/// This channel has no `recv` method because the receiver itself implements the
+/// [`Future`] trait. To receive a value, `.await` the `Receiver` object directly.
+///
+/// [`Future`]: trait@std::future::Future
+///
/// # Examples
///
/// ```
@@ -120,13 +266,46 @@ pub struct Sender<T> {
/// }
/// }
/// ```
+///
+/// To use a `Receiver` in a `tokio::select!` loop, add `&mut` in front of the
+/// channel.
+///
+/// ```
+/// use tokio::sync::oneshot;
+/// use tokio::time::{interval, sleep, Duration};
+///
+/// #[tokio::main]
+/// # async fn _doc() {}
+/// # #[tokio::main(flavor = "current_thread", start_paused = true)]
+/// async fn main() {
+/// let (send, mut recv) = oneshot::channel();
+/// let mut interval = interval(Duration::from_millis(100));
+///
+/// # let handle =
+/// tokio::spawn(async move {
+/// sleep(Duration::from_secs(1)).await;
+/// send.send("shut down").unwrap();
+/// });
+///
+/// loop {
+/// tokio::select! {
+/// _ = interval.tick() => println!("Another 100ms"),
+/// msg = &mut recv => {
+/// println!("Got message: {}", msg.unwrap());
+/// break;
+/// }
+/// }
+/// }
+/// # handle.await.unwrap();
+/// }
+/// ```
#[derive(Debug)]
pub struct Receiver<T> {
inner: Option<Arc<Inner<T>>>,
}
pub mod error {
- //! Oneshot error types
+ //! Oneshot error types.
use std::fmt;
@@ -171,7 +350,7 @@ pub mod error {
use self::error::*;
struct Inner<T> {
- /// Manages the state of the inner cell
+ /// Manages the state of the inner cell.
state: AtomicUsize,
/// The value. This is set by `Sender` and read by `Receiver`. The state of
@@ -179,9 +358,19 @@ struct Inner<T> {
value: UnsafeCell<Option<T>>,
/// The task to notify when the receiver drops without consuming the value.
+ ///
+ /// ## Safety
+ ///
+ /// The `TX_TASK_SET` bit in the `state` field is set if this field is
+ /// initialized. If that bit is unset, this field may be uninitialized.
tx_task: Task,
/// The task to notify when the value is sent.
+ ///
+ /// ## Safety
+ ///
+ /// The `RX_TASK_SET` bit in the `state` field is set if this field is
+ /// initialized. If that bit is unset, this field may be uninitialized.
rx_task: Task,
}
@@ -220,7 +409,7 @@ impl Task {
#[derive(Clone, Copy)]
struct State(usize);
-/// Create a new one-shot channel for sending single values across asynchronous
+/// Creates a new one-shot channel for sending single values across asynchronous
/// tasks.
///
/// The function returns separate "send" and "receive" handles. The `Sender`
@@ -311,11 +500,24 @@ impl<T> Sender<T> {
let inner = self.inner.take().unwrap();
inner.value.with_mut(|ptr| unsafe {
+ // SAFETY: The receiver will not access the `UnsafeCell` unless the
+ // channel has been marked as "complete" (the `VALUE_SENT` state bit
+ // is set).
+ // That bit is only set by the sender later on in this method, and
+ // calling this method consumes `self`. Therefore, if it was possible to
+ // call this method, we know that the `VALUE_SENT` bit is unset, and
+ // the receiver is not currently accessing the `UnsafeCell`.
*ptr = Some(t);
});
if !inner.complete() {
unsafe {
+ // SAFETY: The receiver will not access the `UnsafeCell` unless
+ // the channel has been marked as "complete". Calling
+ // `complete()` will return true if this bit is set, and false
+ // if it is not set. Thus, if `complete()` returned false, it is
+ // safe for us to access the value, because we know that the
+ // receiver will not.
return Err(inner.consume_value().unwrap());
}
}
@@ -430,7 +632,7 @@ impl<T> Sender<T> {
state.is_closed()
}
- /// Check whether the oneshot channel has been closed, and if not, schedules the
+ /// Checks whether the oneshot channel has been closed, and if not, schedules the
/// `Waker` in the provided `Context` to receive a notification when the channel is
/// closed.
///
@@ -661,6 +863,11 @@ impl<T> Receiver<T> {
let state = State::load(&inner.state, Acquire);
if state.is_complete() {
+ // SAFETY: If `state.is_complete()` returns true, then the
+ // `VALUE_SENT` bit has been set and the sender side of the
+ // channel will no longer attempt to access the inner
+ // `UnsafeCell`. Therefore, it is now safe for us to access the
+ // cell.
match unsafe { inner.consume_value() } {
Some(value) => Ok(value),
None => Err(TryRecvError::Closed),
@@ -751,6 +958,11 @@ impl<T> Inner<T> {
State::set_rx_task(&self.state);
coop.made_progress();
+ // SAFETY: If `state.is_complete()` returns true, then the
+ // `VALUE_SENT` bit has been set and the sender side of the
+ // channel will no longer attempt to access the inner
+ // `UnsafeCell`. Therefore, it is now safe for us to access the
+ // cell.
return match unsafe { self.consume_value() } {
Some(value) => Ready(Ok(value)),
None => Ready(Err(RecvError(()))),
@@ -797,6 +1009,14 @@ impl<T> Inner<T> {
}
/// Consumes the value. This function does not check `state`.
+ ///
+ /// # Safety
+ ///
+ /// Calling this method concurrently on multiple threads will result in a
+ /// data race. The `VALUE_SENT` state bit is used to ensure that only the
+ /// sender *or* the receiver will call this method at a given point in time.
+ /// If `VALUE_SENT` is not set, then only the sender may call this method;
+ /// if it is set, then only the receiver may call this method.
unsafe fn consume_value(&self) -> Option<T> {
self.value.with_mut(|ptr| (*ptr).take())
}
@@ -837,9 +1057,28 @@ impl<T: fmt::Debug> fmt::Debug for Inner<T> {
}
}
+/// Indicates that a waker for the receiving task has been set.
+///
+/// # Safety
+///
+/// If this bit is not set, the `rx_task` field may be uninitialized.
const RX_TASK_SET: usize = 0b00001;
+/// Indicates that a value has been stored in the channel's inner `UnsafeCell`.
+///
+/// # Safety
+///
+/// This bit controls which side of the channel is permitted to access the
+/// `UnsafeCell`. If it is set, the `UnsafeCell` may ONLY be accessed by the
+/// receiver. If this bit is NOT set, the `UnsafeCell` may ONLY be accessed by
+/// the sender.
const VALUE_SENT: usize = 0b00010;
const CLOSED: usize = 0b00100;
+
+/// Indicates that a waker for the sending task has been set.
+///
+/// # Safety
+///
+/// If this bit is not set, the `tx_task` field may be uninitialized.
const TX_TASK_SET: usize = 0b01000;
impl State {
@@ -852,11 +1091,38 @@ impl State {
}
fn set_complete(cell: &AtomicUsize) -> State {
- // TODO: This could be `Release`, followed by an `Acquire` fence *if*
- // the `RX_TASK_SET` flag is set. However, `loom` does not support
- // fences yet.
- let val = cell.fetch_or(VALUE_SENT, AcqRel);
- State(val)
+ // This method is a compare-and-swap loop rather than a fetch-or like
+ // other `set_$WHATEVER` methods on `State`. This is because we must
+ // check if the state has been closed before setting the `VALUE_SENT`
+ // bit.
+ //
+ // We don't want to set the `VALUE_SENT` bit if the `CLOSED`
+ // bit is already set, because `VALUE_SENT` will tell the receiver that
+ // it's okay to access the inner `UnsafeCell`. Immediately after calling
+ // `set_complete`, if the channel was closed, the sender will _also_
+ // access the `UnsafeCell` to take the value back out, so if a
+ // `poll_recv` or `try_recv` call is occurring concurrently, both
+ // threads may try to access the `UnsafeCell` if we were to set the
+ // `VALUE_SENT` bit on a closed channel.
+ let mut state = cell.load(Ordering::Relaxed);
+ loop {
+ if State(state).is_closed() {
+ break;
+ }
+ // TODO: This could be `Release`, followed by an `Acquire` fence *if*
+ // the `RX_TASK_SET` flag is set. However, `loom` does not support
+ // fences yet.
+ match cell.compare_exchange_weak(
+ state,
+ state | VALUE_SENT,
+ Ordering::AcqRel,
+ Ordering::Acquire,
+ ) {
+ Ok(_) => break,
+ Err(actual) => state = actual,
+ }
+ }
+ State(state)
}
fn is_rx_task_set(self) -> bool {
diff --git a/src/sync/rwlock/owned_read_guard.rs b/src/sync/rwlock/owned_read_guard.rs
index b7f3926..1881295 100644
--- a/src/sync/rwlock/owned_read_guard.rs
+++ b/src/sync/rwlock/owned_read_guard.rs
@@ -22,7 +22,7 @@ pub struct OwnedRwLockReadGuard<T: ?Sized, U: ?Sized = T> {
}
impl<T: ?Sized, U: ?Sized> OwnedRwLockReadGuard<T, U> {
- /// Make a new `OwnedRwLockReadGuard` for a component of the locked data.
+ /// Makes a new `OwnedRwLockReadGuard` for a component of the locked data.
/// This operation cannot fail as the `OwnedRwLockReadGuard` passed in
/// already locked the data.
///
diff --git a/src/sync/rwlock/owned_write_guard.rs b/src/sync/rwlock/owned_write_guard.rs
index 91b6595..0a78d28 100644
--- a/src/sync/rwlock/owned_write_guard.rs
+++ b/src/sync/rwlock/owned_write_guard.rs
@@ -24,7 +24,7 @@ pub struct OwnedRwLockWriteGuard<T: ?Sized> {
}
impl<T: ?Sized> OwnedRwLockWriteGuard<T> {
- /// Make a new [`OwnedRwLockMappedWriteGuard`] for a component of the locked
+ /// Makes a new [`OwnedRwLockMappedWriteGuard`] for a component of the locked
/// data.
///
/// This operation cannot fail as the `OwnedRwLockWriteGuard` passed in
diff --git a/src/sync/rwlock/owned_write_guard_mapped.rs b/src/sync/rwlock/owned_write_guard_mapped.rs
index 6453236..d88ee01 100644
--- a/src/sync/rwlock/owned_write_guard_mapped.rs
+++ b/src/sync/rwlock/owned_write_guard_mapped.rs
@@ -23,7 +23,7 @@ pub struct OwnedRwLockMappedWriteGuard<T: ?Sized, U: ?Sized = T> {
}
impl<T: ?Sized, U: ?Sized> OwnedRwLockMappedWriteGuard<T, U> {
- /// Make a new `OwnedRwLockMappedWriteGuard` for a component of the locked
+ /// Makes a new `OwnedRwLockMappedWriteGuard` for a component of the locked
/// data.
///
/// This operation cannot fail as the `OwnedRwLockMappedWriteGuard` passed
diff --git a/src/sync/rwlock/read_guard.rs b/src/sync/rwlock/read_guard.rs
index 38eec77..090b297 100644
--- a/src/sync/rwlock/read_guard.rs
+++ b/src/sync/rwlock/read_guard.rs
@@ -19,7 +19,7 @@ pub struct RwLockReadGuard<'a, T: ?Sized> {
}
impl<'a, T: ?Sized> RwLockReadGuard<'a, T> {
- /// Make a new `RwLockReadGuard` for a component of the locked data.
+ /// Makes a new `RwLockReadGuard` for a component of the locked data.
///
/// This operation cannot fail as the `RwLockReadGuard` passed in already
/// locked the data.
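A brief usage sketch of the guard mapping these doc tweaks cover: the mapped guard borrows only one component of the protected value. The `Config` type and its fields are made up for the example.

```
use tokio::sync::{RwLock, RwLockReadGuard};

struct Config {
    name: String,
    retries: u32,
}

#[tokio::main]
async fn main() {
    let lock = RwLock::new(Config { name: "tokio".into(), retries: 3 });

    let guard = lock.read().await;
    let name: RwLockReadGuard<'_, String> = RwLockReadGuard::map(guard, |c| &c.name);
    assert_eq!(&*name, "tokio");
    drop(name);

    // The original lock is still usable once the mapped guard is released.
    assert_eq!(lock.read().await.retries, 3);
}
```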
diff --git a/src/sync/rwlock/write_guard.rs b/src/sync/rwlock/write_guard.rs
index 865a121..8c80ee7 100644
--- a/src/sync/rwlock/write_guard.rs
+++ b/src/sync/rwlock/write_guard.rs
@@ -22,7 +22,7 @@ pub struct RwLockWriteGuard<'a, T: ?Sized> {
}
impl<'a, T: ?Sized> RwLockWriteGuard<'a, T> {
- /// Make a new [`RwLockMappedWriteGuard`] for a component of the locked data.
+ /// Makes a new [`RwLockMappedWriteGuard`] for a component of the locked data.
///
/// This operation cannot fail as the `RwLockWriteGuard` passed in already
/// locked the data.
diff --git a/src/sync/rwlock/write_guard_mapped.rs b/src/sync/rwlock/write_guard_mapped.rs
index 9c5b1e7..3cf69de 100644
--- a/src/sync/rwlock/write_guard_mapped.rs
+++ b/src/sync/rwlock/write_guard_mapped.rs
@@ -21,7 +21,7 @@ pub struct RwLockMappedWriteGuard<'a, T: ?Sized> {
}
impl<'a, T: ?Sized> RwLockMappedWriteGuard<'a, T> {
- /// Make a new `RwLockMappedWriteGuard` for a component of the locked data.
+ /// Makes a new `RwLockMappedWriteGuard` for a component of the locked data.
///
/// This operation cannot fail as the `RwLockMappedWriteGuard` passed in already
/// locked the data.
diff --git a/src/sync/task/atomic_waker.rs b/src/sync/task/atomic_waker.rs
index 8616007..e1330fb 100644
--- a/src/sync/task/atomic_waker.rs
+++ b/src/sync/task/atomic_waker.rs
@@ -4,6 +4,7 @@ use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::{self, AtomicUsize};
use std::fmt;
+use std::panic::{resume_unwind, AssertUnwindSafe, RefUnwindSafe, UnwindSafe};
use std::sync::atomic::Ordering::{AcqRel, Acquire, Release};
use std::task::Waker;
@@ -27,6 +28,9 @@ pub(crate) struct AtomicWaker {
waker: UnsafeCell<Option<Waker>>,
}
+impl RefUnwindSafe for AtomicWaker {}
+impl UnwindSafe for AtomicWaker {}
+
// `AtomicWaker` is a multi-consumer, single-producer transfer cell. The cell
// stores a `Waker` value produced by calls to `register` and many threads can
// race to take the waker by calling `wake`.
@@ -84,7 +88,7 @@ pub(crate) struct AtomicWaker {
// back to `WAITING`. This transition must succeed as, at this point, the state
// cannot be transitioned by another thread.
//
-// If the thread is unable to obtain the lock, the `WAKING` bit is still.
+// If the thread is unable to obtain the lock, the `WAKING` bit is still set.
// This is because it has either been set by the current thread but the previous
// value included the `REGISTERING` bit **or** a concurrent thread is in the
// `WAKING` critical section. Either way, no action must be taken.
@@ -123,7 +127,7 @@ pub(crate) struct AtomicWaker {
// Thread A still holds the `wake` lock, the call to `register` will result
// in the task waking itself and get scheduled again.
-/// Idle state
+/// Idle state.
const WAITING: usize = 0;
/// A new waker value is being registered with the `AtomicWaker` cell.
@@ -171,6 +175,10 @@ impl AtomicWaker {
where
W: WakerRef,
{
+ fn catch_unwind<F: FnOnce() -> R, R>(f: F) -> std::thread::Result<R> {
+ std::panic::catch_unwind(AssertUnwindSafe(f))
+ }
+
match self
.state
.compare_exchange(WAITING, REGISTERING, Acquire, Acquire)
@@ -178,8 +186,24 @@ impl AtomicWaker {
{
WAITING => {
unsafe {
- // Locked acquired, update the waker cell
- self.waker.with_mut(|t| *t = Some(waker.into_waker()));
+ // If `into_waker` panics (because it's code outside of
+ // AtomicWaker) we need to prime a guard that is called on
+ // unwind to restore the waker to a WAITING state. Otherwise
+ // any future calls to register will incorrectly be stuck
+ // believing it's being updated by someone else.
+ let new_waker_or_panic = catch_unwind(move || waker.into_waker());
+
+ // Set the field to contain the new waker, or if
+ // `into_waker` panicked, leave the old value.
+ let mut maybe_panic = None;
+ let mut old_waker = None;
+ match new_waker_or_panic {
+ Ok(new_waker) => {
+ old_waker = self.waker.with_mut(|t| (*t).take());
+ self.waker.with_mut(|t| *t = Some(new_waker));
+ }
+ Err(panic) => maybe_panic = Some(panic),
+ }
// Release the lock. If the state transitioned to include
// the `WAKING` bit, this means that a wake has been
@@ -193,33 +217,67 @@ impl AtomicWaker {
.compare_exchange(REGISTERING, WAITING, AcqRel, Acquire);
match res {
- Ok(_) => {}
+ Ok(_) => {
+ // We don't want to give the caller the panic if it
+ // was someone else who put in that waker.
+ let _ = catch_unwind(move || {
+ drop(old_waker);
+ });
+ }
Err(actual) => {
// This branch can only be reached if a
// concurrent thread called `wake`. In this
// case, `actual` **must** be `REGISTERING |
- // `WAKING`.
+ // WAKING`.
debug_assert_eq!(actual, REGISTERING | WAKING);
// Take the waker to wake once the atomic operation has
// completed.
- let waker = self.waker.with_mut(|t| (*t).take()).unwrap();
+ let mut waker = self.waker.with_mut(|t| (*t).take());
// Just swap, because no one could change state
// while state == `Registering | `Waking`
self.state.swap(WAITING, AcqRel);
- // The atomic swap was complete, now
- // wake the waker and return.
- waker.wake();
+ // If `into_waker` panicked, then the waker in the
+ // waker slot is actually the old waker.
+ if maybe_panic.is_some() {
+ old_waker = waker.take();
+ }
+
+ // We don't want to give the caller the panic if it
+ // was someone else who put in that waker.
+ if let Some(old_waker) = old_waker {
+ let _ = catch_unwind(move || {
+ old_waker.wake();
+ });
+ }
+
+ // The atomic swap was complete, now wake the waker
+ // and return.
+ //
+ // If this panics, we end up in a consumed state and
+ // return the panic to the caller.
+ if let Some(waker) = waker {
+ debug_assert!(maybe_panic.is_none());
+ waker.wake();
+ }
}
}
+
+ if let Some(panic) = maybe_panic {
+ // If `into_waker` panicked, return the panic to the caller.
+ resume_unwind(panic);
+ }
}
}
WAKING => {
// Currently in the process of waking the task, i.e.,
// `wake` is currently being called on the old waker.
// So, we call wake on the new waker.
+ //
+ // If this panics, someone else is responsible for restoring the
+ // state of the waker.
waker.wake();
// This is equivalent to a spin lock, so use a spin hint.
@@ -245,6 +303,8 @@ impl AtomicWaker {
/// If `register` has not been called yet, then this does nothing.
pub(crate) fn wake(&self) {
if let Some(waker) = self.take_waker() {
+ // If wake panics, we've consumed the waker which is a legitimate
+ // outcome.
waker.wake();
}
}
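The state-machine comments above boil down to a register-then-recheck pattern on the consumer side and a store-then-wake pattern on the producer side. The sketch below shows that pattern using the `futures` crate's public `AtomicWaker`, which exposes an equivalent `register`/`wake` pair (the type changed in this diff is crate-internal); it assumes `futures` and `tokio` as dependencies.

```
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::Poll;

use futures::task::AtomicWaker;

struct Shared {
    ready: AtomicBool,
    waker: AtomicWaker,
}

#[tokio::main]
async fn main() {
    let shared = Arc::new(Shared {
        ready: AtomicBool::new(false),
        waker: AtomicWaker::new(),
    });
    let producer = shared.clone();

    std::thread::spawn(move || {
        producer.ready.store(true, Ordering::Release);
        // Store first, then wake: the "wake" half of the state machine.
        producer.waker.wake();
    });

    futures::future::poll_fn(|cx| {
        // Register first, then re-check the flag, so a wake that lands
        // between the check and the registration cannot be lost.
        shared.waker.register(cx.waker());
        if shared.ready.load(Ordering::Acquire) {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    })
    .await;
}
```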
diff --git a/src/sync/tests/atomic_waker.rs b/src/sync/tests/atomic_waker.rs
index c832d62..b167a5d 100644
--- a/src/sync/tests/atomic_waker.rs
+++ b/src/sync/tests/atomic_waker.rs
@@ -32,3 +32,42 @@ fn wake_without_register() {
assert!(!waker.is_woken());
}
+
+#[test]
+fn atomic_waker_panic_safe() {
+ use std::panic;
+ use std::ptr;
+ use std::task::{RawWaker, RawWakerVTable, Waker};
+
+ static PANICKING_VTABLE: RawWakerVTable = RawWakerVTable::new(
+ |_| panic!("clone"),
+ |_| unimplemented!("wake"),
+ |_| unimplemented!("wake_by_ref"),
+ |_| (),
+ );
+
+ static NONPANICKING_VTABLE: RawWakerVTable = RawWakerVTable::new(
+ |_| RawWaker::new(ptr::null(), &NONPANICKING_VTABLE),
+ |_| unimplemented!("wake"),
+ |_| unimplemented!("wake_by_ref"),
+ |_| (),
+ );
+
+ let panicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &PANICKING_VTABLE)) };
+ let nonpanicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NONPANICKING_VTABLE)) };
+
+ let atomic_waker = AtomicWaker::new();
+
+ let panicking = panic::AssertUnwindSafe(&panicking);
+
+ let result = panic::catch_unwind(|| {
+ let panic::AssertUnwindSafe(panicking) = panicking;
+ atomic_waker.register_by_ref(panicking);
+ });
+
+ assert!(result.is_err());
+ assert!(atomic_waker.take_waker().is_none());
+
+ atomic_waker.register_by_ref(&nonpanicking);
+ assert!(atomic_waker.take_waker().is_some());
+}
diff --git a/src/sync/tests/loom_atomic_waker.rs b/src/sync/tests/loom_atomic_waker.rs
index c148bcb..f8bae65 100644
--- a/src/sync/tests/loom_atomic_waker.rs
+++ b/src/sync/tests/loom_atomic_waker.rs
@@ -43,3 +43,58 @@ fn basic_notification() {
}));
});
}
+
+#[test]
+fn test_panicky_waker() {
+ use std::panic;
+ use std::ptr;
+ use std::task::{RawWaker, RawWakerVTable, Waker};
+
+ static PANICKING_VTABLE: RawWakerVTable =
+ RawWakerVTable::new(|_| panic!("clone"), |_| (), |_| (), |_| ());
+
+ let panicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &PANICKING_VTABLE)) };
+
+ // If you're working with this test (and I sure hope you never have to!),
+ // uncomment the following section because there will be a lot of panics
+ // which would otherwise log.
+ //
+ // We can't, however, leave it uncommented, because it's global.
+ // panic::set_hook(Box::new(|_| ()));
+
+ const NUM_NOTIFY: usize = 2;
+
+ loom::model(move || {
+ let chan = Arc::new(Chan {
+ num: AtomicUsize::new(0),
+ task: AtomicWaker::new(),
+ });
+
+ for _ in 0..NUM_NOTIFY {
+ let chan = chan.clone();
+
+ thread::spawn(move || {
+ chan.num.fetch_add(1, Relaxed);
+ chan.task.wake();
+ });
+ }
+
+ // Note: this panic should have no effect on the overall state of the
+ // waker and it should proceed as normal.
+ //
+ // A thread above might race to flag a wakeup, and a WAKING state will
+ // be preserved if this expected panic races with that, so the procedure
+ // below should be allowed to continue uninterrupted.
+ let _ = panic::catch_unwind(|| chan.task.register_by_ref(&panicking));
+
+ block_on(poll_fn(move |cx| {
+ chan.task.register_by_ref(cx.waker());
+
+ if NUM_NOTIFY == chan.num.load(Relaxed) {
+ return Ready(());
+ }
+
+ Pending
+ }));
+ });
+}
diff --git a/src/sync/tests/loom_oneshot.rs b/src/sync/tests/loom_oneshot.rs
index 9729cfb..c5f7972 100644
--- a/src/sync/tests/loom_oneshot.rs
+++ b/src/sync/tests/loom_oneshot.rs
@@ -55,6 +55,35 @@ fn changing_rx_task() {
});
}
+#[test]
+fn try_recv_close() {
+ // reproduces https://github.com/tokio-rs/tokio/issues/4225
+ loom::model(|| {
+ let (tx, mut rx) = oneshot::channel();
+ thread::spawn(move || {
+ let _ = tx.send(());
+ });
+
+ rx.close();
+ let _ = rx.try_recv();
+ })
+}
+
+#[test]
+fn recv_closed() {
+ // reproduces https://github.com/tokio-rs/tokio/issues/4225
+ loom::model(|| {
+ let (tx, mut rx) = oneshot::channel();
+
+ thread::spawn(move || {
+ let _ = tx.send(1);
+ });
+
+ rx.close();
+ let _ = block_on(rx);
+ });
+}
+
// TODO: Move this into `oneshot` proper.
use std::future::Future;
diff --git a/src/sync/tests/mod.rs b/src/sync/tests/mod.rs
index c5d5601..ee76418 100644
--- a/src/sync/tests/mod.rs
+++ b/src/sync/tests/mod.rs
@@ -1,5 +1,6 @@
cfg_not_loom! {
mod atomic_waker;
+ mod notify;
mod semaphore_batch;
}
diff --git a/src/sync/tests/notify.rs b/src/sync/tests/notify.rs
new file mode 100644
index 0000000..8c9a573
--- /dev/null
+++ b/src/sync/tests/notify.rs
@@ -0,0 +1,44 @@
+use crate::sync::Notify;
+use std::future::Future;
+use std::mem::ManuallyDrop;
+use std::sync::Arc;
+use std::task::{Context, RawWaker, RawWakerVTable, Waker};
+
+#[test]
+fn notify_clones_waker_before_lock() {
+ const VTABLE: &RawWakerVTable = &RawWakerVTable::new(clone_w, wake, wake_by_ref, drop_w);
+
+ unsafe fn clone_w(data: *const ()) -> RawWaker {
+ let arc = ManuallyDrop::new(Arc::<Notify>::from_raw(data as *const Notify));
+ // Or some other arbitrary code that shouldn't be executed while the
+ // Notify wait list is locked.
+ arc.notify_one();
+ let _arc_clone: ManuallyDrop<_> = arc.clone();
+ RawWaker::new(data, VTABLE)
+ }
+
+ unsafe fn drop_w(data: *const ()) {
+ let _ = Arc::<Notify>::from_raw(data as *const Notify);
+ }
+
+ unsafe fn wake(_data: *const ()) {
+ unreachable!()
+ }
+
+ unsafe fn wake_by_ref(_data: *const ()) {
+ unreachable!()
+ }
+
+ let notify = Arc::new(Notify::new());
+ let notify2 = notify.clone();
+
+ let waker =
+ unsafe { Waker::from_raw(RawWaker::new(Arc::into_raw(notify2) as *const _, VTABLE)) };
+ let mut cx = Context::from_waker(&waker);
+
+ let future = notify.notified();
+ pin!(future);
+
+ // The result doesn't matter, we're just testing that we don't deadlock.
+ let _ = future.poll(&mut cx);
+}
diff --git a/src/sync/watch.rs b/src/sync/watch.rs
index b5da218..7e45c11 100644
--- a/src/sync/watch.rs
+++ b/src/sync/watch.rs
@@ -58,6 +58,7 @@ use crate::sync::notify::Notify;
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::atomic::Ordering::Relaxed;
use crate::loom::sync::{Arc, RwLock, RwLockReadGuard};
+use std::mem;
use std::ops;
/// Receives values from the associated [`Sender`](struct@Sender).
@@ -85,7 +86,7 @@ pub struct Sender<T> {
shared: Arc<Shared<T>>,
}
-/// Returns a reference to the inner value
+/// Returns a reference to the inner value.
///
/// Outstanding borrows hold a read lock on the inner value. This means that
/// long lived borrows could cause the produce half to block. It is recommended
@@ -97,27 +98,27 @@ pub struct Ref<'a, T> {
#[derive(Debug)]
struct Shared<T> {
- /// The most recent value
+ /// The most recent value.
value: RwLock<T>,
- /// The current version
+ /// The current version.
///
/// The lowest bit represents a "closed" state. The rest of the bits
/// represent the current version.
state: AtomicState,
- /// Tracks the number of `Receiver` instances
+ /// Tracks the number of `Receiver` instances.
ref_count_rx: AtomicUsize,
/// Notifies waiting receivers that the value changed.
notify_rx: Notify,
- /// Notifies any task listening for `Receiver` dropped events
+ /// Notifies any task listening for `Receiver` dropped events.
notify_tx: Notify,
}
pub mod error {
- //! Watch error types
+ //! Watch error types.
use std::fmt;
@@ -317,7 +318,7 @@ impl<T> Receiver<T> {
Ref { inner }
}
- /// Wait for a change notification, then mark the newest value as seen.
+ /// Waits for a change notification, then marks the newest value as seen.
///
/// If the newest value in the channel has not yet been marked seen when
/// this method is called, the method marks that value seen and returns
@@ -432,10 +433,31 @@ impl<T> Sender<T> {
return Err(error::SendError(value));
}
- {
+ self.send_replace(value);
+ Ok(())
+ }
+
+ /// Sends a new value via the channel, notifying all receivers and returning
+ /// the previous value in the channel.
+ ///
+ /// This can be useful for reusing the buffers inside a watched value.
+ /// Additionally, this method permits sending values even when there are no
+ /// receivers.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::watch;
+ ///
+ /// let (tx, _rx) = watch::channel(1);
+ /// assert_eq!(tx.send_replace(2), 1);
+ /// assert_eq!(tx.send_replace(3), 2);
+ /// ```
+ pub fn send_replace(&self, value: T) -> T {
+ let old = {
// Acquire the write lock and update the value.
let mut lock = self.shared.value.write().unwrap();
- *lock = value;
+ let old = mem::replace(&mut *lock, value);
self.shared.state.increment_version();
@@ -445,12 +467,14 @@ impl<T> Sender<T> {
// that receivers are able to figure out the version number of the
// value they are currently looking at.
drop(lock);
- }
+
+ old
+ };
// Notify all watchers
self.shared.notify_rx.notify_waiters();
- Ok(())
+ old
}
/// Returns a reference to the most recently sent value
@@ -595,7 +619,7 @@ impl<T> Sender<T> {
Receiver::from_shared(version, shared)
}
- /// Returns the number of receivers that currently exist
+ /// Returns the number of receivers that currently exist.
///
/// # Examples
///
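A sketch of the buffer-reuse idea mentioned in the `send_replace` docs added above: the previous value comes back to the sender, so its allocation can be recycled for the next update.

```
use tokio::sync::watch;

fn main() {
    let (tx, rx) = watch::channel(vec![0u8; 1024]);

    // Take the previous buffer back while publishing a new one.
    let mut old: Vec<u8> = tx.send_replace(vec![1u8; 1024]);

    // Reuse the returned allocation for the next update.
    old.clear();
    old.extend_from_slice(b"updated");
    tx.send_replace(old);

    assert_eq!(&rx.borrow()[..7], b"updated");
}
```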