diff options
Diffstat (limited to 'src/sync')
-rw-r--r-- | src/sync/batch_semaphore.rs | 8 | ||||
-rw-r--r-- | src/sync/broadcast.rs | 32 | ||||
-rw-r--r-- | src/sync/mpsc/block.rs | 4 | ||||
-rw-r--r-- | src/sync/mpsc/bounded.rs | 20 | ||||
-rw-r--r-- | src/sync/mpsc/chan.rs | 6 | ||||
-rw-r--r-- | src/sync/mpsc/error.rs | 2 | ||||
-rw-r--r-- | src/sync/mpsc/list.rs | 4 | ||||
-rw-r--r-- | src/sync/mpsc/unbounded.rs | 2 | ||||
-rw-r--r-- | src/sync/mutex.rs | 38 | ||||
-rw-r--r-- | src/sync/notify.rs | 25 | ||||
-rw-r--r-- | src/sync/once_cell.rs | 8 | ||||
-rw-r--r-- | src/sync/oneshot.rs | 286 | ||||
-rw-r--r-- | src/sync/rwlock/owned_read_guard.rs | 2 | ||||
-rw-r--r-- | src/sync/rwlock/owned_write_guard.rs | 2 | ||||
-rw-r--r-- | src/sync/rwlock/owned_write_guard_mapped.rs | 2 | ||||
-rw-r--r-- | src/sync/rwlock/read_guard.rs | 2 | ||||
-rw-r--r-- | src/sync/rwlock/write_guard.rs | 2 | ||||
-rw-r--r-- | src/sync/rwlock/write_guard_mapped.rs | 2 | ||||
-rw-r--r-- | src/sync/task/atomic_waker.rs | 80 | ||||
-rw-r--r-- | src/sync/tests/atomic_waker.rs | 39 | ||||
-rw-r--r-- | src/sync/tests/loom_atomic_waker.rs | 55 | ||||
-rw-r--r-- | src/sync/tests/loom_oneshot.rs | 29 | ||||
-rw-r--r-- | src/sync/tests/mod.rs | 1 | ||||
-rw-r--r-- | src/sync/tests/notify.rs | 44 | ||||
-rw-r--r-- | src/sync/watch.rs | 48 |
25 files changed, 651 insertions, 92 deletions
diff --git a/src/sync/batch_semaphore.rs b/src/sync/batch_semaphore.rs index 9b43404..b5c39d2 100644 --- a/src/sync/batch_semaphore.rs +++ b/src/sync/batch_semaphore.rs @@ -1,5 +1,5 @@ #![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))] -//! # Implementation Details +//! # Implementation Details. //! //! The semaphore is implemented using an intrusive linked list of waiters. An //! atomic counter tracks the number of available permits. If the semaphore does @@ -138,7 +138,7 @@ impl Semaphore { } } - /// Creates a new semaphore with the initial number of permits + /// Creates a new semaphore with the initial number of permits. /// /// Maximum number of permits on 32-bit platforms is `1<<29`. /// @@ -159,7 +159,7 @@ impl Semaphore { } } - /// Returns the current number of available permits + /// Returns the current number of available permits. pub(crate) fn available_permits(&self) -> usize { self.permits.load(Acquire) >> Self::PERMIT_SHIFT } @@ -197,7 +197,7 @@ impl Semaphore { } } - /// Returns true if the semaphore is closed + /// Returns true if the semaphore is closed. pub(crate) fn is_closed(&self) -> bool { self.permits.load(Acquire) & Self::CLOSED == Self::CLOSED } diff --git a/src/sync/broadcast.rs b/src/sync/broadcast.rs index a2ca445..0d9cd3b 100644 --- a/src/sync/broadcast.rs +++ b/src/sync/broadcast.rs @@ -293,37 +293,37 @@ pub mod error { use self::error::*; -/// Data shared between senders and receivers +/// Data shared between senders and receivers. struct Shared<T> { - /// slots in the channel + /// slots in the channel. buffer: Box<[RwLock<Slot<T>>]>, - /// Mask a position -> index + /// Mask a position -> index. mask: usize, /// Tail of the queue. Includes the rx wait list. tail: Mutex<Tail>, - /// Number of outstanding Sender handles + /// Number of outstanding Sender handles. num_tx: AtomicUsize, } -/// Next position to write a value +/// Next position to write a value. struct Tail { - /// Next position to write to + /// Next position to write to. pos: u64, - /// Number of active receivers + /// Number of active receivers. rx_cnt: usize, - /// True if the channel is closed + /// True if the channel is closed. closed: bool, - /// Receivers waiting for a value + /// Receivers waiting for a value. waiters: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>, } -/// Slot in the buffer +/// Slot in the buffer. struct Slot<T> { /// Remaining number of receivers that are expected to see this value. /// @@ -333,7 +333,7 @@ struct Slot<T> { /// acquired. rem: AtomicUsize, - /// Uniquely identifies the `send` stored in the slot + /// Uniquely identifies the `send` stored in the slot. pos: u64, /// True signals the channel is closed. @@ -346,9 +346,9 @@ struct Slot<T> { val: UnsafeCell<Option<T>>, } -/// An entry in the wait queue +/// An entry in the wait queue. struct Waiter { - /// True if queued + /// True if queued. queued: bool, /// Task waiting on the broadcast channel. @@ -365,12 +365,12 @@ struct RecvGuard<'a, T> { slot: RwLockReadGuard<'a, Slot<T>>, } -/// Receive a value future +/// Receive a value future. struct Recv<'a, T> { - /// Receiver being waited on + /// Receiver being waited on. receiver: &'a mut Receiver<T>, - /// Entry in the waiter `LinkedList` + /// Entry in the waiter `LinkedList`. waiter: UnsafeCell<Waiter>, } diff --git a/src/sync/mpsc/block.rs b/src/sync/mpsc/block.rs index 6e7b700..58f4a9f 100644 --- a/src/sync/mpsc/block.rs +++ b/src/sync/mpsc/block.rs @@ -40,7 +40,7 @@ struct Values<T>([UnsafeCell<MaybeUninit<T>>; BLOCK_CAP]); use super::BLOCK_CAP; -/// Masks an index to get the block identifier +/// Masks an index to get the block identifier. const BLOCK_MASK: usize = !(BLOCK_CAP - 1); /// Masks an index to get the value offset in a block. @@ -89,7 +89,7 @@ impl<T> Block<T> { } } - /// Returns `true` if the block matches the given index + /// Returns `true` if the block matches the given index. pub(crate) fn is_at_index(&self, index: usize) -> bool { debug_assert!(offset(index) == 0); self.start_index == index diff --git a/src/sync/mpsc/bounded.rs b/src/sync/mpsc/bounded.rs index bcad84d..5a2bfa6 100644 --- a/src/sync/mpsc/bounded.rs +++ b/src/sync/mpsc/bounded.rs @@ -10,7 +10,7 @@ cfg_time! { use std::fmt; use std::task::{Context, Poll}; -/// Send values to the associated `Receiver`. +/// Sends values to the associated `Receiver`. /// /// Instances are created by the [`channel`](channel) function. /// @@ -22,7 +22,7 @@ pub struct Sender<T> { chan: chan::Tx<T, Semaphore>, } -/// Permit to send one value into the channel. +/// Permits to send one value into the channel. /// /// `Permit` values are returned by [`Sender::reserve()`] and [`Sender::try_reserve()`] /// and are used to guarantee channel capacity before generating a message to send. @@ -49,7 +49,7 @@ pub struct OwnedPermit<T> { chan: Option<chan::Tx<T, Semaphore>>, } -/// Receive values from the associated `Sender`. +/// Receives values from the associated `Sender`. /// /// Instances are created by the [`channel`](channel) function. /// @@ -57,7 +57,7 @@ pub struct OwnedPermit<T> { /// /// [`ReceiverStream`]: https://docs.rs/tokio-stream/0.1/tokio_stream/wrappers/struct.ReceiverStream.html pub struct Receiver<T> { - /// The channel receiver + /// The channel receiver. chan: chan::Rx<T, Semaphore>, } @@ -187,7 +187,7 @@ impl<T> Receiver<T> { poll_fn(|cx| self.chan.recv(cx)).await } - /// Try to receive the next value for this receiver. + /// Tries to receive the next value for this receiver. /// /// This method returns the [`Empty`] error if the channel is currently /// empty, but there are still outstanding [senders] or [permits]. @@ -672,7 +672,7 @@ impl<T> Sender<T> { self.chan.is_closed() } - /// Wait for channel capacity. Once capacity to send one message is + /// Waits for channel capacity. Once capacity to send one message is /// available, it is reserved for the caller. /// /// If the channel is full, the function waits for the number of unreceived @@ -721,7 +721,7 @@ impl<T> Sender<T> { Ok(Permit { chan: &self.chan }) } - /// Wait for channel capacity, moving the `Sender` and returning an owned + /// Waits for channel capacity, moving the `Sender` and returning an owned /// permit. Once capacity to send one message is available, it is reserved /// for the caller. /// @@ -815,7 +815,7 @@ impl<T> Sender<T> { } } - /// Try to acquire a slot in the channel without waiting for the slot to become + /// Tries to acquire a slot in the channel without waiting for the slot to become /// available. /// /// If the channel is full this function will return [`TrySendError`], otherwise @@ -868,7 +868,7 @@ impl<T> Sender<T> { Ok(Permit { chan: &self.chan }) } - /// Try to acquire a slot in the channel without waiting for the slot to become + /// Tries to acquire a slot in the channel without waiting for the slot to become /// available, returning an owned permit. /// /// This moves the sender _by value_, and returns an owned permit that can @@ -1117,7 +1117,7 @@ impl<T> OwnedPermit<T> { Sender { chan } } - /// Release the reserved capacity *without* sending a message, returning the + /// Releases the reserved capacity *without* sending a message, returning the /// [`Sender`]. /// /// # Examples diff --git a/src/sync/mpsc/chan.rs b/src/sync/mpsc/chan.rs index 637ae1f..c3007de 100644 --- a/src/sync/mpsc/chan.rs +++ b/src/sync/mpsc/chan.rs @@ -14,7 +14,7 @@ use std::sync::atomic::Ordering::{AcqRel, Relaxed}; use std::task::Poll::{Pending, Ready}; use std::task::{Context, Poll}; -/// Channel sender +/// Channel sender. pub(crate) struct Tx<T, S> { inner: Arc<Chan<T, S>>, } @@ -25,7 +25,7 @@ impl<T, S: fmt::Debug> fmt::Debug for Tx<T, S> { } } -/// Channel receiver +/// Channel receiver. pub(crate) struct Rx<T, S: Semaphore> { inner: Arc<Chan<T, S>>, } @@ -47,7 +47,7 @@ pub(crate) trait Semaphore { } struct Chan<T, S> { - /// Notifies all tasks listening for the receiver being dropped + /// Notifies all tasks listening for the receiver being dropped. notify_rx_closed: Notify, /// Handle to the push half of the lock-free list. diff --git a/src/sync/mpsc/error.rs b/src/sync/mpsc/error.rs index 48ca379..b7b9cf7 100644 --- a/src/sync/mpsc/error.rs +++ b/src/sync/mpsc/error.rs @@ -1,4 +1,4 @@ -//! Channel error types +//! Channel error types. use std::error::Error; use std::fmt; diff --git a/src/sync/mpsc/list.rs b/src/sync/mpsc/list.rs index 53c34d2..e4eeb45 100644 --- a/src/sync/mpsc/list.rs +++ b/src/sync/mpsc/list.rs @@ -8,7 +8,7 @@ use std::fmt; use std::ptr::NonNull; use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release}; -/// List queue transmit handle +/// List queue transmit handle. pub(crate) struct Tx<T> { /// Tail in the `Block` mpmc list. block_tail: AtomicPtr<Block<T>>, @@ -79,7 +79,7 @@ impl<T> Tx<T> { } } - /// Closes the send half of the list + /// Closes the send half of the list. /// /// Similar process as pushing a value, but instead of writing the value & /// setting the ready flag, the TX_CLOSED flag is set on the block. diff --git a/src/sync/mpsc/unbounded.rs b/src/sync/mpsc/unbounded.rs index 8961930..b133f9f 100644 --- a/src/sync/mpsc/unbounded.rs +++ b/src/sync/mpsc/unbounded.rs @@ -129,7 +129,7 @@ impl<T> UnboundedReceiver<T> { poll_fn(|cx| self.poll_recv(cx)).await } - /// Try to receive the next value for this receiver. + /// Tries to receive the next value for this receiver. /// /// This method returns the [`Empty`] error if the channel is currently /// empty, but there are still outstanding [senders] or [permits]. diff --git a/src/sync/mutex.rs b/src/sync/mutex.rs index 6acd28b..4d9f988 100644 --- a/src/sync/mutex.rs +++ b/src/sync/mutex.rs @@ -301,6 +301,40 @@ impl<T: ?Sized> Mutex<T> { MutexGuard { lock: self } } + /// Blocking lock this mutex. When the lock has been acquired, function returns a + /// [`MutexGuard`]. + /// + /// This method is intended for use cases where you + /// need to use this mutex in asynchronous code as well as in synchronous code. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::Mutex; + /// + /// #[tokio::main] + /// async fn main() { + /// let mutex = Arc::new(Mutex::new(1)); + /// + /// let mutex1 = Arc::clone(&mutex); + /// let sync_code = tokio::task::spawn_blocking(move || { + /// let mut n = mutex1.blocking_lock(); + /// *n = 2; + /// }); + /// + /// sync_code.await.unwrap(); + /// + /// let n = mutex.lock().await; + /// assert_eq!(*n, 2); + /// } + /// + /// ``` + #[cfg(feature = "sync")] + pub fn blocking_lock(&self) -> MutexGuard<'_, T> { + crate::future::block_on(self.lock()) + } + /// Locks this mutex, causing the current task to yield until the lock has /// been acquired. When the lock has been acquired, this returns an /// [`OwnedMutexGuard`]. @@ -462,14 +496,14 @@ where } } -impl<T> std::fmt::Debug for Mutex<T> +impl<T: ?Sized> std::fmt::Debug for Mutex<T> where T: std::fmt::Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut d = f.debug_struct("Mutex"); match self.try_lock() { - Ok(inner) => d.field("data", &*inner), + Ok(inner) => d.field("data", &&*inner), Err(_) => d.field("data", &format_args!("<locked>")), }; d.finish() diff --git a/src/sync/notify.rs b/src/sync/notify.rs index 74b97cc..c93ce3b 100644 --- a/src/sync/notify.rs +++ b/src/sync/notify.rs @@ -20,7 +20,7 @@ use std::task::{Context, Poll, Waker}; type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>; -/// Notify a single task to wake up. +/// Notifies a single task to wake up. /// /// `Notify` provides a basic mechanism to notify a single task of an event. /// `Notify` itself does not carry any data. Instead, it is to be used to signal @@ -57,13 +57,16 @@ type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>; /// let notify = Arc::new(Notify::new()); /// let notify2 = notify.clone(); /// -/// tokio::spawn(async move { +/// let handle = tokio::spawn(async move { /// notify2.notified().await; /// println!("received notification"); /// }); /// /// println!("sending notification"); /// notify.notify_one(); +/// +/// // Wait for task to receive notification. +/// handle.await.unwrap(); /// } /// ``` /// @@ -128,10 +131,10 @@ enum NotificationType { #[derive(Debug)] struct Waiter { - /// Intrusive linked-list pointers + /// Intrusive linked-list pointers. pointers: linked_list::Pointers<Waiter>, - /// Waiting task's waker + /// Waiting task's waker. waker: Option<Waker>, /// `true` if the notification has been assigned to this waiter. @@ -168,13 +171,13 @@ const NOTIFY_WAITERS_SHIFT: usize = 2; const STATE_MASK: usize = (1 << NOTIFY_WAITERS_SHIFT) - 1; const NOTIFY_WAITERS_CALLS_MASK: usize = !STATE_MASK; -/// Initial "idle" state +/// Initial "idle" state. const EMPTY: usize = 0; /// One or more threads are currently waiting to be notified. const WAITING: usize = 1; -/// Pending notification +/// Pending notification. const NOTIFIED: usize = 2; fn set_state(data: usize, state: usize) -> usize { @@ -289,7 +292,7 @@ impl Notify { } } - /// Notifies a waiting task + /// Notifies a waiting task. /// /// If a task is currently waiting, that task is notified. Otherwise, a /// permit is stored in this `Notify` value and the **next** call to @@ -359,7 +362,7 @@ impl Notify { } } - /// Notifies all waiting tasks + /// Notifies all waiting tasks. /// /// If a task is currently waiting, that task is notified. Unlike with /// `notify_one()`, no permit is stored to be used by the next call to @@ -551,6 +554,10 @@ impl Future for Notified<'_> { return Poll::Ready(()); } + // Clone the waker before locking, a waker clone can be + // triggering arbitrary code. + let waker = cx.waker().clone(); + // Acquire the lock and attempt to transition to the waiting // state. let mut waiters = notify.waiters.lock(); @@ -612,7 +619,7 @@ impl Future for Notified<'_> { // Safety: called while locked. unsafe { - (*waiter.get()).waker = Some(cx.waker().clone()); + (*waiter.get()).waker = Some(waker); } // Insert the waiter into the linked list diff --git a/src/sync/once_cell.rs b/src/sync/once_cell.rs index 91705a5..d31a40e 100644 --- a/src/sync/once_cell.rs +++ b/src/sync/once_cell.rs @@ -245,7 +245,7 @@ impl<T> OnceCell<T> { } } - /// Set the value of the `OnceCell` to the given value if the `OnceCell` is + /// Sets the value of the `OnceCell` to the given value if the `OnceCell` is /// empty. /// /// If the `OnceCell` already has a value, this call will fail with an @@ -283,7 +283,7 @@ impl<T> OnceCell<T> { } } - /// Get the value currently in the `OnceCell`, or initialize it with the + /// Gets the value currently in the `OnceCell`, or initialize it with the /// given asynchronous operation. /// /// If some other task is currently working on initializing the `OnceCell`, @@ -331,7 +331,7 @@ impl<T> OnceCell<T> { } } - /// Get the value currently in the `OnceCell`, or initialize it with the + /// Gets the value currently in the `OnceCell`, or initialize it with the /// given asynchronous operation. /// /// If some other task is currently working on initializing the `OnceCell`, @@ -382,7 +382,7 @@ impl<T> OnceCell<T> { } } - /// Take the value from the cell, destroying the cell in the process. + /// Takes the value from the cell, destroying the cell in the process. /// Returns `None` if the cell is empty. pub fn into_inner(mut self) -> Option<T> { if self.initialized_mut() { diff --git a/src/sync/oneshot.rs b/src/sync/oneshot.rs index 0df6037..4fb22ec 100644 --- a/src/sync/oneshot.rs +++ b/src/sync/oneshot.rs @@ -51,6 +51,70 @@ //! } //! } //! ``` +//! +//! To use a oneshot channel in a `tokio::select!` loop, add `&mut` in front of +//! the channel. +//! +//! ``` +//! use tokio::sync::oneshot; +//! use tokio::time::{interval, sleep, Duration}; +//! +//! #[tokio::main] +//! # async fn _doc() {} +//! # #[tokio::main(flavor = "current_thread", start_paused = true)] +//! async fn main() { +//! let (send, mut recv) = oneshot::channel(); +//! let mut interval = interval(Duration::from_millis(100)); +//! +//! # let handle = +//! tokio::spawn(async move { +//! sleep(Duration::from_secs(1)).await; +//! send.send("shut down").unwrap(); +//! }); +//! +//! loop { +//! tokio::select! { +//! _ = interval.tick() => println!("Another 100ms"), +//! msg = &mut recv => { +//! println!("Got message: {}", msg.unwrap()); +//! break; +//! } +//! } +//! } +//! # handle.await.unwrap(); +//! } +//! ``` +//! +//! To use a `Sender` from a destructor, put it in an [`Option`] and call +//! [`Option::take`]. +//! +//! ``` +//! use tokio::sync::oneshot; +//! +//! struct SendOnDrop { +//! sender: Option<oneshot::Sender<&'static str>>, +//! } +//! impl Drop for SendOnDrop { +//! fn drop(&mut self) { +//! if let Some(sender) = self.sender.take() { +//! // Using `let _ =` to ignore send errors. +//! let _ = sender.send("I got dropped!"); +//! } +//! } +//! } +//! +//! #[tokio::main] +//! # async fn _doc() {} +//! # #[tokio::main(flavor = "current_thread")] +//! async fn main() { +//! let (send, recv) = oneshot::channel(); +//! +//! let send_on_drop = SendOnDrop { sender: Some(send) }; +//! drop(send_on_drop); +//! +//! assert_eq!(recv.await, Ok("I got dropped!")); +//! } +//! ``` use crate::loom::cell::UnsafeCell; use crate::loom::sync::atomic::AtomicUsize; @@ -68,16 +132,98 @@ use std::task::{Context, Poll, Waker}; /// /// A pair of both a [`Sender`] and a [`Receiver`] are created by the /// [`channel`](fn@channel) function. +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, rx) = oneshot::channel(); +/// +/// tokio::spawn(async move { +/// if let Err(_) = tx.send(3) { +/// println!("the receiver dropped"); +/// } +/// }); +/// +/// match rx.await { +/// Ok(v) => println!("got = {:?}", v), +/// Err(_) => println!("the sender dropped"), +/// } +/// } +/// ``` +/// +/// If the sender is dropped without sending, the receiver will fail with +/// [`error::RecvError`]: +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, rx) = oneshot::channel::<u32>(); +/// +/// tokio::spawn(async move { +/// drop(tx); +/// }); +/// +/// match rx.await { +/// Ok(_) => panic!("This doesn't happen"), +/// Err(_) => println!("the sender dropped"), +/// } +/// } +/// ``` +/// +/// To use a `Sender` from a destructor, put it in an [`Option`] and call +/// [`Option::take`]. +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// struct SendOnDrop { +/// sender: Option<oneshot::Sender<&'static str>>, +/// } +/// impl Drop for SendOnDrop { +/// fn drop(&mut self) { +/// if let Some(sender) = self.sender.take() { +/// // Using `let _ =` to ignore send errors. +/// let _ = sender.send("I got dropped!"); +/// } +/// } +/// } +/// +/// #[tokio::main] +/// # async fn _doc() {} +/// # #[tokio::main(flavor = "current_thread")] +/// async fn main() { +/// let (send, recv) = oneshot::channel(); +/// +/// let send_on_drop = SendOnDrop { sender: Some(send) }; +/// drop(send_on_drop); +/// +/// assert_eq!(recv.await, Ok("I got dropped!")); +/// } +/// ``` +/// +/// [`Option`]: std::option::Option +/// [`Option::take`]: std::option::Option::take #[derive(Debug)] pub struct Sender<T> { inner: Option<Arc<Inner<T>>>, } -/// Receive a value from the associated [`Sender`]. +/// Receives a value from the associated [`Sender`]. /// /// A pair of both a [`Sender`] and a [`Receiver`] are created by the /// [`channel`](fn@channel) function. /// +/// This channel has no `recv` method because the receiver itself implements the +/// [`Future`] trait. To receive a value, `.await` the `Receiver` object directly. +/// +/// [`Future`]: trait@std::future::Future +/// /// # Examples /// /// ``` @@ -120,13 +266,46 @@ pub struct Sender<T> { /// } /// } /// ``` +/// +/// To use a `Receiver` in a `tokio::select!` loop, add `&mut` in front of the +/// channel. +/// +/// ``` +/// use tokio::sync::oneshot; +/// use tokio::time::{interval, sleep, Duration}; +/// +/// #[tokio::main] +/// # async fn _doc() {} +/// # #[tokio::main(flavor = "current_thread", start_paused = true)] +/// async fn main() { +/// let (send, mut recv) = oneshot::channel(); +/// let mut interval = interval(Duration::from_millis(100)); +/// +/// # let handle = +/// tokio::spawn(async move { +/// sleep(Duration::from_secs(1)).await; +/// send.send("shut down").unwrap(); +/// }); +/// +/// loop { +/// tokio::select! { +/// _ = interval.tick() => println!("Another 100ms"), +/// msg = &mut recv => { +/// println!("Got message: {}", msg.unwrap()); +/// break; +/// } +/// } +/// } +/// # handle.await.unwrap(); +/// } +/// ``` #[derive(Debug)] pub struct Receiver<T> { inner: Option<Arc<Inner<T>>>, } pub mod error { - //! Oneshot error types + //! Oneshot error types. use std::fmt; @@ -171,7 +350,7 @@ pub mod error { use self::error::*; struct Inner<T> { - /// Manages the state of the inner cell + /// Manages the state of the inner cell. state: AtomicUsize, /// The value. This is set by `Sender` and read by `Receiver`. The state of @@ -179,9 +358,19 @@ struct Inner<T> { value: UnsafeCell<Option<T>>, /// The task to notify when the receiver drops without consuming the value. + /// + /// ## Safety + /// + /// The `TX_TASK_SET` bit in the `state` field is set if this field is + /// initialized. If that bit is unset, this field may be uninitialized. tx_task: Task, /// The task to notify when the value is sent. + /// + /// ## Safety + /// + /// The `RX_TASK_SET` bit in the `state` field is set if this field is + /// initialized. If that bit is unset, this field may be uninitialized. rx_task: Task, } @@ -220,7 +409,7 @@ impl Task { #[derive(Clone, Copy)] struct State(usize); -/// Create a new one-shot channel for sending single values across asynchronous +/// Creates a new one-shot channel for sending single values across asynchronous /// tasks. /// /// The function returns separate "send" and "receive" handles. The `Sender` @@ -311,11 +500,24 @@ impl<T> Sender<T> { let inner = self.inner.take().unwrap(); inner.value.with_mut(|ptr| unsafe { + // SAFETY: The receiver will not access the `UnsafeCell` unless the + // channel has been marked as "complete" (the `VALUE_SENT` state bit + // is set). + // That bit is only set by the sender later on in this method, and + // calling this method consumes `self`. Therefore, if it was possible to + // call this method, we know that the `VALUE_SENT` bit is unset, and + // the receiver is not currently accessing the `UnsafeCell`. *ptr = Some(t); }); if !inner.complete() { unsafe { + // SAFETY: The receiver will not access the `UnsafeCell` unless + // the channel has been marked as "complete". Calling + // `complete()` will return true if this bit is set, and false + // if it is not set. Thus, if `complete()` returned false, it is + // safe for us to access the value, because we know that the + // receiver will not. return Err(inner.consume_value().unwrap()); } } @@ -430,7 +632,7 @@ impl<T> Sender<T> { state.is_closed() } - /// Check whether the oneshot channel has been closed, and if not, schedules the + /// Checks whether the oneshot channel has been closed, and if not, schedules the /// `Waker` in the provided `Context` to receive a notification when the channel is /// closed. /// @@ -661,6 +863,11 @@ impl<T> Receiver<T> { let state = State::load(&inner.state, Acquire); if state.is_complete() { + // SAFETY: If `state.is_complete()` returns true, then the + // `VALUE_SENT` bit has been set and the sender side of the + // channel will no longer attempt to access the inner + // `UnsafeCell`. Therefore, it is now safe for us to access the + // cell. match unsafe { inner.consume_value() } { Some(value) => Ok(value), None => Err(TryRecvError::Closed), @@ -751,6 +958,11 @@ impl<T> Inner<T> { State::set_rx_task(&self.state); coop.made_progress(); + // SAFETY: If `state.is_complete()` returns true, then the + // `VALUE_SENT` bit has been set and the sender side of the + // channel will no longer attempt to access the inner + // `UnsafeCell`. Therefore, it is now safe for us to access the + // cell. return match unsafe { self.consume_value() } { Some(value) => Ready(Ok(value)), None => Ready(Err(RecvError(()))), @@ -797,6 +1009,14 @@ impl<T> Inner<T> { } /// Consumes the value. This function does not check `state`. + /// + /// # Safety + /// + /// Calling this method concurrently on multiple threads will result in a + /// data race. The `VALUE_SENT` state bit is used to ensure that only the + /// sender *or* the receiver will call this method at a given point in time. + /// If `VALUE_SENT` is not set, then only the sender may call this method; + /// if it is set, then only the receiver may call this method. unsafe fn consume_value(&self) -> Option<T> { self.value.with_mut(|ptr| (*ptr).take()) } @@ -837,9 +1057,28 @@ impl<T: fmt::Debug> fmt::Debug for Inner<T> { } } +/// Indicates that a waker for the receiving task has been set. +/// +/// # Safety +/// +/// If this bit is not set, the `rx_task` field may be uninitialized. const RX_TASK_SET: usize = 0b00001; +/// Indicates that a value has been stored in the channel's inner `UnsafeCell`. +/// +/// # Safety +/// +/// This bit controls which side of the channel is permitted to access the +/// `UnsafeCell`. If it is set, the `UnsafeCell` may ONLY be accessed by the +/// receiver. If this bit is NOT set, the `UnsafeCell` may ONLY be accessed by +/// the sender. const VALUE_SENT: usize = 0b00010; const CLOSED: usize = 0b00100; + +/// Indicates that a waker for the sending task has been set. +/// +/// # Safety +/// +/// If this bit is not set, the `tx_task` field may be uninitialized. const TX_TASK_SET: usize = 0b01000; impl State { @@ -852,11 +1091,38 @@ impl State { } fn set_complete(cell: &AtomicUsize) -> State { - // TODO: This could be `Release`, followed by an `Acquire` fence *if* - // the `RX_TASK_SET` flag is set. However, `loom` does not support - // fences yet. - let val = cell.fetch_or(VALUE_SENT, AcqRel); - State(val) + // This method is a compare-and-swap loop rather than a fetch-or like + // other `set_$WHATEVER` methods on `State`. This is because we must + // check if the state has been closed before setting the `VALUE_SENT` + // bit. + // + // We don't want to set both the `VALUE_SENT` bit if the `CLOSED` + // bit is already set, because `VALUE_SENT` will tell the receiver that + // it's okay to access the inner `UnsafeCell`. Immediately after calling + // `set_complete`, if the channel was closed, the sender will _also_ + // access the `UnsafeCell` to take the value back out, so if a + // `poll_recv` or `try_recv` call is occurring concurrently, both + // threads may try to access the `UnsafeCell` if we were to set the + // `VALUE_SENT` bit on a closed channel. + let mut state = cell.load(Ordering::Relaxed); + loop { + if State(state).is_closed() { + break; + } + // TODO: This could be `Release`, followed by an `Acquire` fence *if* + // the `RX_TASK_SET` flag is set. However, `loom` does not support + // fences yet. + match cell.compare_exchange_weak( + state, + state | VALUE_SENT, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => break, + Err(actual) => state = actual, + } + } + State(state) } fn is_rx_task_set(self) -> bool { diff --git a/src/sync/rwlock/owned_read_guard.rs b/src/sync/rwlock/owned_read_guard.rs index b7f3926..1881295 100644 --- a/src/sync/rwlock/owned_read_guard.rs +++ b/src/sync/rwlock/owned_read_guard.rs @@ -22,7 +22,7 @@ pub struct OwnedRwLockReadGuard<T: ?Sized, U: ?Sized = T> { } impl<T: ?Sized, U: ?Sized> OwnedRwLockReadGuard<T, U> { - /// Make a new `OwnedRwLockReadGuard` for a component of the locked data. + /// Makes a new `OwnedRwLockReadGuard` for a component of the locked data. /// This operation cannot fail as the `OwnedRwLockReadGuard` passed in /// already locked the data. /// diff --git a/src/sync/rwlock/owned_write_guard.rs b/src/sync/rwlock/owned_write_guard.rs index 91b6595..0a78d28 100644 --- a/src/sync/rwlock/owned_write_guard.rs +++ b/src/sync/rwlock/owned_write_guard.rs @@ -24,7 +24,7 @@ pub struct OwnedRwLockWriteGuard<T: ?Sized> { } impl<T: ?Sized> OwnedRwLockWriteGuard<T> { - /// Make a new [`OwnedRwLockMappedWriteGuard`] for a component of the locked + /// Makes a new [`OwnedRwLockMappedWriteGuard`] for a component of the locked /// data. /// /// This operation cannot fail as the `OwnedRwLockWriteGuard` passed in diff --git a/src/sync/rwlock/owned_write_guard_mapped.rs b/src/sync/rwlock/owned_write_guard_mapped.rs index 6453236..d88ee01 100644 --- a/src/sync/rwlock/owned_write_guard_mapped.rs +++ b/src/sync/rwlock/owned_write_guard_mapped.rs @@ -23,7 +23,7 @@ pub struct OwnedRwLockMappedWriteGuard<T: ?Sized, U: ?Sized = T> { } impl<T: ?Sized, U: ?Sized> OwnedRwLockMappedWriteGuard<T, U> { - /// Make a new `OwnedRwLockMappedWriteGuard` for a component of the locked + /// Makes a new `OwnedRwLockMappedWriteGuard` for a component of the locked /// data. /// /// This operation cannot fail as the `OwnedRwLockMappedWriteGuard` passed diff --git a/src/sync/rwlock/read_guard.rs b/src/sync/rwlock/read_guard.rs index 38eec77..090b297 100644 --- a/src/sync/rwlock/read_guard.rs +++ b/src/sync/rwlock/read_guard.rs @@ -19,7 +19,7 @@ pub struct RwLockReadGuard<'a, T: ?Sized> { } impl<'a, T: ?Sized> RwLockReadGuard<'a, T> { - /// Make a new `RwLockReadGuard` for a component of the locked data. + /// Makes a new `RwLockReadGuard` for a component of the locked data. /// /// This operation cannot fail as the `RwLockReadGuard` passed in already /// locked the data. diff --git a/src/sync/rwlock/write_guard.rs b/src/sync/rwlock/write_guard.rs index 865a121..8c80ee7 100644 --- a/src/sync/rwlock/write_guard.rs +++ b/src/sync/rwlock/write_guard.rs @@ -22,7 +22,7 @@ pub struct RwLockWriteGuard<'a, T: ?Sized> { } impl<'a, T: ?Sized> RwLockWriteGuard<'a, T> { - /// Make a new [`RwLockMappedWriteGuard`] for a component of the locked data. + /// Makes a new [`RwLockMappedWriteGuard`] for a component of the locked data. /// /// This operation cannot fail as the `RwLockWriteGuard` passed in already /// locked the data. diff --git a/src/sync/rwlock/write_guard_mapped.rs b/src/sync/rwlock/write_guard_mapped.rs index 9c5b1e7..3cf69de 100644 --- a/src/sync/rwlock/write_guard_mapped.rs +++ b/src/sync/rwlock/write_guard_mapped.rs @@ -21,7 +21,7 @@ pub struct RwLockMappedWriteGuard<'a, T: ?Sized> { } impl<'a, T: ?Sized> RwLockMappedWriteGuard<'a, T> { - /// Make a new `RwLockMappedWriteGuard` for a component of the locked data. + /// Makes a new `RwLockMappedWriteGuard` for a component of the locked data. /// /// This operation cannot fail as the `RwLockMappedWriteGuard` passed in already /// locked the data. diff --git a/src/sync/task/atomic_waker.rs b/src/sync/task/atomic_waker.rs index 8616007..e1330fb 100644 --- a/src/sync/task/atomic_waker.rs +++ b/src/sync/task/atomic_waker.rs @@ -4,6 +4,7 @@ use crate::loom::cell::UnsafeCell; use crate::loom::sync::atomic::{self, AtomicUsize}; use std::fmt; +use std::panic::{resume_unwind, AssertUnwindSafe, RefUnwindSafe, UnwindSafe}; use std::sync::atomic::Ordering::{AcqRel, Acquire, Release}; use std::task::Waker; @@ -27,6 +28,9 @@ pub(crate) struct AtomicWaker { waker: UnsafeCell<Option<Waker>>, } +impl RefUnwindSafe for AtomicWaker {} +impl UnwindSafe for AtomicWaker {} + // `AtomicWaker` is a multi-consumer, single-producer transfer cell. The cell // stores a `Waker` value produced by calls to `register` and many threads can // race to take the waker by calling `wake`. @@ -84,7 +88,7 @@ pub(crate) struct AtomicWaker { // back to `WAITING`. This transition must succeed as, at this point, the state // cannot be transitioned by another thread. // -// If the thread is unable to obtain the lock, the `WAKING` bit is still. +// If the thread is unable to obtain the lock, the `WAKING` bit is still set. // This is because it has either been set by the current thread but the previous // value included the `REGISTERING` bit **or** a concurrent thread is in the // `WAKING` critical section. Either way, no action must be taken. @@ -123,7 +127,7 @@ pub(crate) struct AtomicWaker { // Thread A still holds the `wake` lock, the call to `register` will result // in the task waking itself and get scheduled again. -/// Idle state +/// Idle state. const WAITING: usize = 0; /// A new waker value is being registered with the `AtomicWaker` cell. @@ -171,6 +175,10 @@ impl AtomicWaker { where W: WakerRef, { + fn catch_unwind<F: FnOnce() -> R, R>(f: F) -> std::thread::Result<R> { + std::panic::catch_unwind(AssertUnwindSafe(f)) + } + match self .state .compare_exchange(WAITING, REGISTERING, Acquire, Acquire) @@ -178,8 +186,24 @@ impl AtomicWaker { { WAITING => { unsafe { - // Locked acquired, update the waker cell - self.waker.with_mut(|t| *t = Some(waker.into_waker())); + // If `into_waker` panics (because it's code outside of + // AtomicWaker) we need to prime a guard that is called on + // unwind to restore the waker to a WAITING state. Otherwise + // any future calls to register will incorrectly be stuck + // believing it's being updated by someone else. + let new_waker_or_panic = catch_unwind(move || waker.into_waker()); + + // Set the field to contain the new waker, or if + // `into_waker` panicked, leave the old value. + let mut maybe_panic = None; + let mut old_waker = None; + match new_waker_or_panic { + Ok(new_waker) => { + old_waker = self.waker.with_mut(|t| (*t).take()); + self.waker.with_mut(|t| *t = Some(new_waker)); + } + Err(panic) => maybe_panic = Some(panic), + } // Release the lock. If the state transitioned to include // the `WAKING` bit, this means that a wake has been @@ -193,33 +217,67 @@ impl AtomicWaker { .compare_exchange(REGISTERING, WAITING, AcqRel, Acquire); match res { - Ok(_) => {} + Ok(_) => { + // We don't want to give the caller the panic if it + // was someone else who put in that waker. + let _ = catch_unwind(move || { + drop(old_waker); + }); + } Err(actual) => { // This branch can only be reached if a // concurrent thread called `wake`. In this // case, `actual` **must** be `REGISTERING | - // `WAKING`. + // WAKING`. debug_assert_eq!(actual, REGISTERING | WAKING); // Take the waker to wake once the atomic operation has // completed. - let waker = self.waker.with_mut(|t| (*t).take()).unwrap(); + let mut waker = self.waker.with_mut(|t| (*t).take()); // Just swap, because no one could change state // while state == `Registering | `Waking` self.state.swap(WAITING, AcqRel); - // The atomic swap was complete, now - // wake the waker and return. - waker.wake(); + // If `into_waker` panicked, then the waker in the + // waker slot is actually the old waker. + if maybe_panic.is_some() { + old_waker = waker.take(); + } + + // We don't want to give the caller the panic if it + // was someone else who put in that waker. + if let Some(old_waker) = old_waker { + let _ = catch_unwind(move || { + old_waker.wake(); + }); + } + + // The atomic swap was complete, now wake the waker + // and return. + // + // If this panics, we end up in a consumed state and + // return the panic to the caller. + if let Some(waker) = waker { + debug_assert!(maybe_panic.is_none()); + waker.wake(); + } } } + + if let Some(panic) = maybe_panic { + // If `into_waker` panicked, return the panic to the caller. + resume_unwind(panic); + } } } WAKING => { // Currently in the process of waking the task, i.e., // `wake` is currently being called on the old waker. // So, we call wake on the new waker. + // + // If this panics, someone else is responsible for restoring the + // state of the waker. waker.wake(); // This is equivalent to a spin lock, so use a spin hint. @@ -245,6 +303,8 @@ impl AtomicWaker { /// If `register` has not been called yet, then this does nothing. pub(crate) fn wake(&self) { if let Some(waker) = self.take_waker() { + // If wake panics, we've consumed the waker which is a legitimate + // outcome. waker.wake(); } } diff --git a/src/sync/tests/atomic_waker.rs b/src/sync/tests/atomic_waker.rs index c832d62..b167a5d 100644 --- a/src/sync/tests/atomic_waker.rs +++ b/src/sync/tests/atomic_waker.rs @@ -32,3 +32,42 @@ fn wake_without_register() { assert!(!waker.is_woken()); } + +#[test] +fn atomic_waker_panic_safe() { + use std::panic; + use std::ptr; + use std::task::{RawWaker, RawWakerVTable, Waker}; + + static PANICKING_VTABLE: RawWakerVTable = RawWakerVTable::new( + |_| panic!("clone"), + |_| unimplemented!("wake"), + |_| unimplemented!("wake_by_ref"), + |_| (), + ); + + static NONPANICKING_VTABLE: RawWakerVTable = RawWakerVTable::new( + |_| RawWaker::new(ptr::null(), &NONPANICKING_VTABLE), + |_| unimplemented!("wake"), + |_| unimplemented!("wake_by_ref"), + |_| (), + ); + + let panicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &PANICKING_VTABLE)) }; + let nonpanicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NONPANICKING_VTABLE)) }; + + let atomic_waker = AtomicWaker::new(); + + let panicking = panic::AssertUnwindSafe(&panicking); + + let result = panic::catch_unwind(|| { + let panic::AssertUnwindSafe(panicking) = panicking; + atomic_waker.register_by_ref(panicking); + }); + + assert!(result.is_err()); + assert!(atomic_waker.take_waker().is_none()); + + atomic_waker.register_by_ref(&nonpanicking); + assert!(atomic_waker.take_waker().is_some()); +} diff --git a/src/sync/tests/loom_atomic_waker.rs b/src/sync/tests/loom_atomic_waker.rs index c148bcb..f8bae65 100644 --- a/src/sync/tests/loom_atomic_waker.rs +++ b/src/sync/tests/loom_atomic_waker.rs @@ -43,3 +43,58 @@ fn basic_notification() { })); }); } + +#[test] +fn test_panicky_waker() { + use std::panic; + use std::ptr; + use std::task::{RawWaker, RawWakerVTable, Waker}; + + static PANICKING_VTABLE: RawWakerVTable = + RawWakerVTable::new(|_| panic!("clone"), |_| (), |_| (), |_| ()); + + let panicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &PANICKING_VTABLE)) }; + + // If you're working with this test (and I sure hope you never have to!), + // uncomment the following section because there will be a lot of panics + // which would otherwise log. + // + // We can't however leaved it uncommented, because it's global. + // panic::set_hook(Box::new(|_| ())); + + const NUM_NOTIFY: usize = 2; + + loom::model(move || { + let chan = Arc::new(Chan { + num: AtomicUsize::new(0), + task: AtomicWaker::new(), + }); + + for _ in 0..NUM_NOTIFY { + let chan = chan.clone(); + + thread::spawn(move || { + chan.num.fetch_add(1, Relaxed); + chan.task.wake(); + }); + } + + // Note: this panic should have no effect on the overall state of the + // waker and it should proceed as normal. + // + // A thread above might race to flag a wakeup, and a WAKING state will + // be preserved if this expected panic races with that so the below + // procedure should be allowed to continue uninterrupted. + let _ = panic::catch_unwind(|| chan.task.register_by_ref(&panicking)); + + block_on(poll_fn(move |cx| { + chan.task.register_by_ref(cx.waker()); + + if NUM_NOTIFY == chan.num.load(Relaxed) { + return Ready(()); + } + + Pending + })); + }); +} diff --git a/src/sync/tests/loom_oneshot.rs b/src/sync/tests/loom_oneshot.rs index 9729cfb..c5f7972 100644 --- a/src/sync/tests/loom_oneshot.rs +++ b/src/sync/tests/loom_oneshot.rs @@ -55,6 +55,35 @@ fn changing_rx_task() { }); } +#[test] +fn try_recv_close() { + // reproduces https://github.com/tokio-rs/tokio/issues/4225 + loom::model(|| { + let (tx, mut rx) = oneshot::channel(); + thread::spawn(move || { + let _ = tx.send(()); + }); + + rx.close(); + let _ = rx.try_recv(); + }) +} + +#[test] +fn recv_closed() { + // reproduces https://github.com/tokio-rs/tokio/issues/4225 + loom::model(|| { + let (tx, mut rx) = oneshot::channel(); + + thread::spawn(move || { + let _ = tx.send(1); + }); + + rx.close(); + let _ = block_on(rx); + }); +} + // TODO: Move this into `oneshot` proper. use std::future::Future; diff --git a/src/sync/tests/mod.rs b/src/sync/tests/mod.rs index c5d5601..ee76418 100644 --- a/src/sync/tests/mod.rs +++ b/src/sync/tests/mod.rs @@ -1,5 +1,6 @@ cfg_not_loom! { mod atomic_waker; + mod notify; mod semaphore_batch; } diff --git a/src/sync/tests/notify.rs b/src/sync/tests/notify.rs new file mode 100644 index 0000000..8c9a573 --- /dev/null +++ b/src/sync/tests/notify.rs @@ -0,0 +1,44 @@ +use crate::sync::Notify; +use std::future::Future; +use std::mem::ManuallyDrop; +use std::sync::Arc; +use std::task::{Context, RawWaker, RawWakerVTable, Waker}; + +#[test] +fn notify_clones_waker_before_lock() { + const VTABLE: &RawWakerVTable = &RawWakerVTable::new(clone_w, wake, wake_by_ref, drop_w); + + unsafe fn clone_w(data: *const ()) -> RawWaker { + let arc = ManuallyDrop::new(Arc::<Notify>::from_raw(data as *const Notify)); + // Or some other arbitrary code that shouldn't be executed while the + // Notify wait list is locked. + arc.notify_one(); + let _arc_clone: ManuallyDrop<_> = arc.clone(); + RawWaker::new(data, VTABLE) + } + + unsafe fn drop_w(data: *const ()) { + let _ = Arc::<Notify>::from_raw(data as *const Notify); + } + + unsafe fn wake(_data: *const ()) { + unreachable!() + } + + unsafe fn wake_by_ref(_data: *const ()) { + unreachable!() + } + + let notify = Arc::new(Notify::new()); + let notify2 = notify.clone(); + + let waker = + unsafe { Waker::from_raw(RawWaker::new(Arc::into_raw(notify2) as *const _, VTABLE)) }; + let mut cx = Context::from_waker(&waker); + + let future = notify.notified(); + pin!(future); + + // The result doesn't matter, we're just testing that we don't deadlock. + let _ = future.poll(&mut cx); +} diff --git a/src/sync/watch.rs b/src/sync/watch.rs index b5da218..7e45c11 100644 --- a/src/sync/watch.rs +++ b/src/sync/watch.rs @@ -58,6 +58,7 @@ use crate::sync::notify::Notify; use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::atomic::Ordering::Relaxed; use crate::loom::sync::{Arc, RwLock, RwLockReadGuard}; +use std::mem; use std::ops; /// Receives values from the associated [`Sender`](struct@Sender). @@ -85,7 +86,7 @@ pub struct Sender<T> { shared: Arc<Shared<T>>, } -/// Returns a reference to the inner value +/// Returns a reference to the inner value. /// /// Outstanding borrows hold a read lock on the inner value. This means that /// long lived borrows could cause the produce half to block. It is recommended @@ -97,27 +98,27 @@ pub struct Ref<'a, T> { #[derive(Debug)] struct Shared<T> { - /// The most recent value + /// The most recent value. value: RwLock<T>, - /// The current version + /// The current version. /// /// The lowest bit represents a "closed" state. The rest of the bits /// represent the current version. state: AtomicState, - /// Tracks the number of `Receiver` instances + /// Tracks the number of `Receiver` instances. ref_count_rx: AtomicUsize, /// Notifies waiting receivers that the value changed. notify_rx: Notify, - /// Notifies any task listening for `Receiver` dropped events + /// Notifies any task listening for `Receiver` dropped events. notify_tx: Notify, } pub mod error { - //! Watch error types + //! Watch error types. use std::fmt; @@ -317,7 +318,7 @@ impl<T> Receiver<T> { Ref { inner } } - /// Wait for a change notification, then mark the newest value as seen. + /// Waits for a change notification, then marks the newest value as seen. /// /// If the newest value in the channel has not yet been marked seen when /// this method is called, the method marks that value seen and returns @@ -432,10 +433,31 @@ impl<T> Sender<T> { return Err(error::SendError(value)); } - { + self.send_replace(value); + Ok(()) + } + + /// Sends a new value via the channel, notifying all receivers and returning + /// the previous value in the channel. + /// + /// This can be useful for reusing the buffers inside a watched value. + /// Additionally, this method permits sending values even when there are no + /// receivers. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// let (tx, _rx) = watch::channel(1); + /// assert_eq!(tx.send_replace(2), 1); + /// assert_eq!(tx.send_replace(3), 2); + /// ``` + pub fn send_replace(&self, value: T) -> T { + let old = { // Acquire the write lock and update the value. let mut lock = self.shared.value.write().unwrap(); - *lock = value; + let old = mem::replace(&mut *lock, value); self.shared.state.increment_version(); @@ -445,12 +467,14 @@ impl<T> Sender<T> { // that receivers are able to figure out the version number of the // value they are currently looking at. drop(lock); - } + + old + }; // Notify all watchers self.shared.notify_rx.notify_waiters(); - Ok(()) + old } /// Returns a reference to the most recently sent value @@ -595,7 +619,7 @@ impl<T> Sender<T> { Receiver::from_shared(version, shared) } - /// Returns the number of receivers that currently exist + /// Returns the number of receivers that currently exist. /// /// # Examples /// |