diff options
Diffstat (limited to 'src/sync/broadcast.rs')
-rw-r--r-- | src/sync/broadcast.rs | 368 |
1 files changed, 314 insertions, 54 deletions
diff --git a/src/sync/broadcast.rs b/src/sync/broadcast.rs index 1c6b2ca..42cde81 100644 --- a/src/sync/broadcast.rs +++ b/src/sync/broadcast.rs @@ -4,7 +4,7 @@ //! A [`Sender`] is used to broadcast values to **all** connected [`Receiver`] //! values. [`Sender`] handles are clone-able, allowing concurrent send and //! receive actions. [`Sender`] and [`Receiver`] are both `Send` and `Sync` as -//! long as `T` is also `Send` or `Sync` respectively. +//! long as `T` is `Send`. //! //! When a value is sent, **all** [`Receiver`] handles are notified and will //! receive the value. The value is stored once inside the channel and cloned on @@ -54,6 +54,10 @@ //! all values retained by the channel, the next call to [`recv`] will return //! with [`RecvError::Closed`]. //! +//! When a [`Receiver`] handle is dropped, any messages not read by the receiver +//! will be marked as read. If this receiver was the only one not to have read +//! that message, the message will be dropped at this point. +//! //! [`Sender`]: crate::sync::broadcast::Sender //! [`Sender::subscribe`]: crate::sync::broadcast::Sender::subscribe //! [`Receiver`]: crate::sync::broadcast::Receiver @@ -114,8 +118,9 @@ use crate::loom::cell::UnsafeCell; use crate::loom::sync::atomic::AtomicUsize; -use crate::loom::sync::{Arc, Mutex, RwLock, RwLockReadGuard}; -use crate::util::linked_list::{self, LinkedList}; +use crate::loom::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard}; +use crate::util::linked_list::{self, GuardedLinkedList, LinkedList}; +use crate::util::WakeList; use std::fmt; use std::future::Future; @@ -361,6 +366,17 @@ struct Waiter { _p: PhantomPinned, } +impl Waiter { + fn new() -> Self { + Self { + queued: false, + waker: None, + pointers: linked_list::Pointers::new(), + _p: PhantomPinned, + } + } +} + generate_addr_of_methods! { impl<> Waiter { unsafe fn addr_of_pointers(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Waiter>> { @@ -439,42 +455,13 @@ const MAX_RECEIVERS: usize = usize::MAX >> 2; /// This will panic if `capacity` is equal to `0` or larger /// than `usize::MAX / 2`. #[track_caller] -pub fn channel<T: Clone>(mut capacity: usize) -> (Sender<T>, Receiver<T>) { - assert!(capacity > 0, "capacity is empty"); - assert!(capacity <= usize::MAX >> 1, "requested capacity too large"); - - // Round to a power of two - capacity = capacity.next_power_of_two(); - - let mut buffer = Vec::with_capacity(capacity); - - for i in 0..capacity { - buffer.push(RwLock::new(Slot { - rem: AtomicUsize::new(0), - pos: (i as u64).wrapping_sub(capacity as u64), - val: UnsafeCell::new(None), - })); - } - - let shared = Arc::new(Shared { - buffer: buffer.into_boxed_slice(), - mask: capacity - 1, - tail: Mutex::new(Tail { - pos: 0, - rx_cnt: 1, - closed: false, - waiters: LinkedList::new(), - }), - num_tx: AtomicUsize::new(1), - }); - +pub fn channel<T: Clone>(capacity: usize) -> (Sender<T>, Receiver<T>) { + // SAFETY: In the line below we are creating one extra receiver, so there will be 1 in total. + let tx = unsafe { Sender::new_with_receiver_count(1, capacity) }; let rx = Receiver { - shared: shared.clone(), + shared: tx.shared.clone(), next: 0, }; - - let tx = Sender { shared }; - (tx, rx) } @@ -485,6 +472,65 @@ unsafe impl<T: Send> Send for Receiver<T> {} unsafe impl<T: Send> Sync for Receiver<T> {} impl<T> Sender<T> { + /// Creates the sending-half of the [`broadcast`] channel. + /// + /// See documentation of [`broadcast::channel`] for errors when calling this function. + /// + /// [`broadcast`]: crate::sync::broadcast + /// [`broadcast::channel`]: crate::sync::broadcast + #[track_caller] + pub fn new(capacity: usize) -> Self { + // SAFETY: We don't create extra receivers, so there are 0. + unsafe { Self::new_with_receiver_count(0, capacity) } + } + + /// Creates the sending-half of the [`broadcast`](self) channel, and provide the receiver + /// count. + /// + /// See the documentation of [`broadcast::channel`](self::channel) for more errors when + /// calling this function. + /// + /// # Safety: + /// + /// The caller must ensure that the amount of receivers for this Sender is correct before + /// the channel functionalities are used, the count is zero by default, as this function + /// does not create any receivers by itself. + #[track_caller] + unsafe fn new_with_receiver_count(receiver_count: usize, mut capacity: usize) -> Self { + assert!(capacity > 0, "broadcast channel capacity cannot be zero"); + assert!( + capacity <= usize::MAX >> 1, + "broadcast channel capacity exceeded `usize::MAX / 2`" + ); + + // Round to a power of two + capacity = capacity.next_power_of_two(); + + let mut buffer = Vec::with_capacity(capacity); + + for i in 0..capacity { + buffer.push(RwLock::new(Slot { + rem: AtomicUsize::new(0), + pos: (i as u64).wrapping_sub(capacity as u64), + val: UnsafeCell::new(None), + })); + } + + let shared = Arc::new(Shared { + buffer: buffer.into_boxed_slice(), + mask: capacity - 1, + tail: Mutex::new(Tail { + pos: 0, + rx_cnt: receiver_count, + closed: false, + waiters: LinkedList::new(), + }), + num_tx: AtomicUsize::new(1), + }); + + Sender { shared } + } + /// Attempts to send a value to all active [`Receiver`] handles, returning /// it back if it could not be sent. /// @@ -496,7 +542,8 @@ impl<T> Sender<T> { /// /// On success, the number of subscribed [`Receiver`] handles is returned. /// This does not mean that this number of receivers will see the message as - /// a receiver may drop before receiving the message. + /// a receiver may drop or lag ([see lagging](self#lagging)) before receiving + /// the message. /// /// # Note /// @@ -565,12 +612,10 @@ impl<T> Sender<T> { // Release the slot lock before notifying the receivers. drop(slot); - tail.notify_rx(); - - // Release the mutex. This must happen after the slot lock is released, - // otherwise the writer lock bit could be cleared while another thread - // is in the critical section. - drop(tail); + // Notify and release the mutex. This must happen after the slot lock is + // released, otherwise the writer lock bit could be cleared while another + // thread is in the critical section. + self.shared.notify_rx(tail); Ok(rem) } @@ -735,11 +780,34 @@ impl<T> Sender<T> { tail.rx_cnt } + /// Returns `true` if senders belong to the same channel. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, _rx) = broadcast::channel::<()>(16); + /// let tx2 = tx.clone(); + /// + /// assert!(tx.same_channel(&tx2)); + /// + /// let (tx3, _rx3) = broadcast::channel::<()>(16); + /// + /// assert!(!tx3.same_channel(&tx2)); + /// } + /// ``` + pub fn same_channel(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.shared, &other.shared) + } + fn close_channel(&self) { let mut tail = self.shared.tail.lock(); tail.closed = true; - tail.notify_rx(); + self.shared.notify_rx(tail); } } @@ -760,18 +828,110 @@ fn new_receiver<T>(shared: Arc<Shared<T>>) -> Receiver<T> { Receiver { shared, next } } -impl Tail { - fn notify_rx(&mut self) { - while let Some(mut waiter) = self.waiters.pop_back() { - // Safety: `waiters` lock is still held. - let waiter = unsafe { waiter.as_mut() }; +/// List used in `Shared::notify_rx`. It wraps a guarded linked list +/// and gates the access to it on the `Shared.tail` mutex. It also empties +/// the list on drop. +struct WaitersList<'a, T> { + list: GuardedLinkedList<Waiter, <Waiter as linked_list::Link>::Target>, + is_empty: bool, + shared: &'a Shared<T>, +} + +impl<'a, T> Drop for WaitersList<'a, T> { + fn drop(&mut self) { + // If the list is not empty, we unlink all waiters from it. + // We do not wake the waiters to avoid double panics. + if !self.is_empty { + let _lock_guard = self.shared.tail.lock(); + while self.list.pop_back().is_some() {} + } + } +} - assert!(waiter.queued); - waiter.queued = false; +impl<'a, T> WaitersList<'a, T> { + fn new( + unguarded_list: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>, + guard: Pin<&'a Waiter>, + shared: &'a Shared<T>, + ) -> Self { + let guard_ptr = NonNull::from(guard.get_ref()); + let list = unguarded_list.into_guarded(guard_ptr); + WaitersList { + list, + is_empty: false, + shared, + } + } - let waker = waiter.waker.take().unwrap(); - waker.wake(); + /// Removes the last element from the guarded list. Modifying this list + /// requires an exclusive access to the main list in `Notify`. + fn pop_back_locked(&mut self, _tail: &mut Tail) -> Option<NonNull<Waiter>> { + let result = self.list.pop_back(); + if result.is_none() { + // Save information about emptiness to avoid waiting for lock + // in the destructor. + self.is_empty = true; } + result + } +} + +impl<T> Shared<T> { + fn notify_rx<'a, 'b: 'a>(&'b self, mut tail: MutexGuard<'a, Tail>) { + // It is critical for `GuardedLinkedList` safety that the guard node is + // pinned in memory and is not dropped until the guarded list is dropped. + let guard = Waiter::new(); + pin!(guard); + + // We move all waiters to a secondary list. It uses a `GuardedLinkedList` + // underneath to allow every waiter to safely remove itself from it. + // + // * This list will be still guarded by the `waiters` lock. + // `NotifyWaitersList` wrapper makes sure we hold the lock to modify it. + // * This wrapper will empty the list on drop. It is critical for safety + // that we will not leave any list entry with a pointer to the local + // guard node after this function returns / panics. + let mut list = WaitersList::new(std::mem::take(&mut tail.waiters), guard.as_ref(), self); + + let mut wakers = WakeList::new(); + 'outer: loop { + while wakers.can_push() { + match list.pop_back_locked(&mut tail) { + Some(mut waiter) => { + // Safety: `tail` lock is still held. + let waiter = unsafe { waiter.as_mut() }; + + assert!(waiter.queued); + waiter.queued = false; + + if let Some(waker) = waiter.waker.take() { + wakers.push(waker); + } + } + None => { + break 'outer; + } + } + } + + // Release the lock before waking. + drop(tail); + + // Before we acquire the lock again all sorts of things can happen: + // some waiters may remove themselves from the list and new waiters + // may be added. This is fine since at worst we will unnecessarily + // wake up waiters which will then queue themselves again. + + wakers.wake_all(); + + // Acquire the lock again. + tail = self.tail.lock(); + } + + // Release the lock before waking. + drop(tail); + + wakers.wake_all(); } } @@ -860,6 +1020,29 @@ impl<T> Receiver<T> { self.len() == 0 } + /// Returns `true` if receivers belong to the same channel. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = broadcast::channel::<()>(16); + /// let rx2 = tx.subscribe(); + /// + /// assert!(rx.same_channel(&rx2)); + /// + /// let (_tx3, rx3) = broadcast::channel::<()>(16); + /// + /// assert!(!rx3.same_channel(&rx2)); + /// } + /// ``` + pub fn same_channel(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.shared, &other.shared) + } + /// Locks the next value if there is one. fn recv_ref( &mut self, @@ -880,6 +1063,8 @@ impl<T> Receiver<T> { // the slot lock. drop(slot); + let mut old_waker = None; + let mut tail = self.shared.tail.lock(); // Acquire slot lock again @@ -912,7 +1097,10 @@ impl<T> Receiver<T> { match (*ptr).waker { Some(ref w) if w.will_wake(waker) => {} _ => { - (*ptr).waker = Some(waker.clone()); + old_waker = std::mem::replace( + &mut (*ptr).waker, + Some(waker.clone()), + ); } } @@ -924,6 +1112,11 @@ impl<T> Receiver<T> { } } + // Drop the old waker after releasing the locks. + drop(slot); + drop(tail); + drop(old_waker); + return Err(TryRecvError::Empty); } @@ -1106,6 +1299,33 @@ impl<T: Clone> Receiver<T> { let guard = self.recv_ref(None)?; guard.clone_value().ok_or(TryRecvError::Closed) } + + /// Blocking receive to call outside of asynchronous contexts. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution + /// context. + /// + /// # Examples + /// ``` + /// use std::thread; + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = broadcast::channel(16); + /// + /// let sync_code = thread::spawn(move || { + /// assert_eq!(rx.blocking_recv(), Ok(10)); + /// }); + /// + /// let _ = tx.send(10); + /// sync_code.join().unwrap(); + /// } + pub fn blocking_recv(&mut self) -> Result<T, RecvError> { + crate::future::block_on(self.recv()) + } } impl<T> Drop for Receiver<T> { @@ -1164,6 +1384,8 @@ where type Output = Result<T, RecvError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> { + ready!(crate::trace::trace_leaf(cx)); + let (receiver, waiter) = self.project(); let guard = match receiver.recv_ref(Some((waiter, cx.waker()))) { @@ -1252,3 +1474,41 @@ impl<'a, T> Drop for RecvGuard<'a, T> { } fn is_unpin<T: Unpin>() {} + +#[cfg(not(loom))] +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn receiver_count_on_sender_constructor() { + let sender = Sender::<i32>::new(16); + assert_eq!(sender.receiver_count(), 0); + + let rx_1 = sender.subscribe(); + assert_eq!(sender.receiver_count(), 1); + + let rx_2 = rx_1.resubscribe(); + assert_eq!(sender.receiver_count(), 2); + + let rx_3 = sender.subscribe(); + assert_eq!(sender.receiver_count(), 3); + + drop(rx_3); + drop(rx_1); + assert_eq!(sender.receiver_count(), 1); + + drop(rx_2); + assert_eq!(sender.receiver_count(), 0); + } + + #[cfg(not(loom))] + #[test] + fn receiver_count_on_channel_constructor() { + let (sender, rx) = channel::<i32>(16); + assert_eq!(sender.receiver_count(), 1); + + let _rx_2 = rx.resubscribe(); + assert_eq!(sender.receiver_count(), 2); + } +} |