path: root/src/sync/broadcast.rs
Diffstat (limited to 'src/sync/broadcast.rs')
-rw-r--r--  src/sync/broadcast.rs | 368
1 file changed, 314 insertions(+), 54 deletions(-)
diff --git a/src/sync/broadcast.rs b/src/sync/broadcast.rs
index 1c6b2ca..42cde81 100644
--- a/src/sync/broadcast.rs
+++ b/src/sync/broadcast.rs
@@ -4,7 +4,7 @@
//! A [`Sender`] is used to broadcast values to **all** connected [`Receiver`]
//! values. [`Sender`] handles are clone-able, allowing concurrent send and
//! receive actions. [`Sender`] and [`Receiver`] are both `Send` and `Sync` as
-//! long as `T` is also `Send` or `Sync` respectively.
+//! long as `T` is `Send`.
//!
//! When a value is sent, **all** [`Receiver`] handles are notified and will
//! receive the value. The value is stored once inside the channel and cloned on
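A quick way to check the tightened bound documented above, as a compile-time
sketch using only the public API: `Cell<i32>` is `Send` and `Clone` but not
`Sync`, yet both handles remain `Send` and `Sync`.

    use std::cell::Cell;
    use tokio::sync::broadcast;

    // Compile-time probes: each only accepts values with the given bound.
    fn assert_send<T: Send>(_: &T) {}
    fn assert_sync<T: Sync>(_: &T) {}

    fn main() {
        // Cell<i32> is Send + Clone but not Sync; per the updated bound,
        // Sender and Receiver are still both Send and Sync.
        let (tx, rx) = broadcast::channel::<Cell<i32>>(1);
        assert_send(&tx);
        assert_sync(&tx);
        assert_send(&rx);
        assert_sync(&rx);
    }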
@@ -54,6 +54,10 @@
//! all values retained by the channel, the next call to [`recv`] will return
//! with [`RecvError::Closed`].
//!
+//! When a [`Receiver`] handle is dropped, any messages not read by the receiver
+//! will be marked as read. If this receiver was the only one not to have read
+//! that message, the message will be dropped at this point.
+//!
//! [`Sender`]: crate::sync::broadcast::Sender
//! [`Sender::subscribe`]: crate::sync::broadcast::Sender::subscribe
//! [`Receiver`]: crate::sync::broadcast::Receiver
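A minimal sketch of the drop behavior described in the new paragraph, using
only the public API: the dropped handle's unread message is marked as read,
while the remaining receiver still gets the value.

    use tokio::sync::broadcast;

    #[tokio::main]
    async fn main() {
        let (tx, mut rx1) = broadcast::channel(16);
        let rx2 = tx.subscribe();

        tx.send("hello").unwrap();

        // Dropping rx2 marks its unread copy of "hello" as read. rx1 can
        // still receive the value; once rx1 has read it, the value itself
        // is dropped from the channel.
        drop(rx2);
        assert_eq!(rx1.recv().await.unwrap(), "hello");
    }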
@@ -114,8 +118,9 @@
use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::AtomicUsize;
-use crate::loom::sync::{Arc, Mutex, RwLock, RwLockReadGuard};
-use crate::util::linked_list::{self, LinkedList};
+use crate::loom::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard};
+use crate::util::linked_list::{self, GuardedLinkedList, LinkedList};
+use crate::util::WakeList;
use std::fmt;
use std::future::Future;
@@ -361,6 +366,17 @@ struct Waiter {
_p: PhantomPinned,
}
+impl Waiter {
+ fn new() -> Self {
+ Self {
+ queued: false,
+ waker: None,
+ pointers: linked_list::Pointers::new(),
+ _p: PhantomPinned,
+ }
+ }
+}
+
generate_addr_of_methods! {
impl<> Waiter {
unsafe fn addr_of_pointers(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Waiter>> {
@@ -439,42 +455,13 @@ const MAX_RECEIVERS: usize = usize::MAX >> 2;
/// This will panic if `capacity` is equal to `0` or larger
/// than `usize::MAX / 2`.
#[track_caller]
-pub fn channel<T: Clone>(mut capacity: usize) -> (Sender<T>, Receiver<T>) {
- assert!(capacity > 0, "capacity is empty");
- assert!(capacity <= usize::MAX >> 1, "requested capacity too large");
-
- // Round to a power of two
- capacity = capacity.next_power_of_two();
-
- let mut buffer = Vec::with_capacity(capacity);
-
- for i in 0..capacity {
- buffer.push(RwLock::new(Slot {
- rem: AtomicUsize::new(0),
- pos: (i as u64).wrapping_sub(capacity as u64),
- val: UnsafeCell::new(None),
- }));
- }
-
- let shared = Arc::new(Shared {
- buffer: buffer.into_boxed_slice(),
- mask: capacity - 1,
- tail: Mutex::new(Tail {
- pos: 0,
- rx_cnt: 1,
- closed: false,
- waiters: LinkedList::new(),
- }),
- num_tx: AtomicUsize::new(1),
- });
-
+pub fn channel<T: Clone>(capacity: usize) -> (Sender<T>, Receiver<T>) {
+ // SAFETY: In the line below we are creating one extra receiver, so there will be 1 in total.
+ let tx = unsafe { Sender::new_with_receiver_count(1, capacity) };
let rx = Receiver {
- shared: shared.clone(),
+ shared: tx.shared.clone(),
next: 0,
};
-
- let tx = Sender { shared };
-
(tx, rx)
}
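The rewritten constructor keeps the observable behavior: exactly one receiver
exists up front, which the `receiver_count` test at the bottom of this diff
also asserts. A quick sketch:

    use tokio::sync::broadcast;

    fn main() {
        // `channel` now delegates to `Sender::new_with_receiver_count(1, ..)`,
        // so the count starts at exactly one for the returned `Receiver`.
        let (tx, _rx) = broadcast::channel::<i32>(16);
        assert_eq!(tx.receiver_count(), 1);
    }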
@@ -485,6 +472,65 @@ unsafe impl<T: Send> Send for Receiver<T> {}
unsafe impl<T: Send> Sync for Receiver<T> {}
impl<T> Sender<T> {
+ /// Creates the sending-half of the [`broadcast`] channel.
+ ///
+ /// See the documentation of [`broadcast::channel`] for errors when calling this function.
+ ///
+ /// [`broadcast`]: crate::sync::broadcast
+ /// [`broadcast::channel`]: crate::sync::broadcast
+ #[track_caller]
+ pub fn new(capacity: usize) -> Self {
+ // SAFETY: We don't create extra receivers, so there are 0.
+ unsafe { Self::new_with_receiver_count(0, capacity) }
+ }
+
+ /// Creates the sending-half of the [`broadcast`](self) channel, and provides
+ /// the receiver count.
+ ///
+ /// See the documentation of [`broadcast::channel`](self::channel) for errors
+ /// when calling this function.
+ ///
+ /// # Safety
+ ///
+ /// The caller must ensure that the number of receivers for this `Sender` is
+ /// correct before the channel functionality is used; the count is zero by
+ /// default, as this function does not create any receivers by itself.
+ #[track_caller]
+ unsafe fn new_with_receiver_count(receiver_count: usize, mut capacity: usize) -> Self {
+ assert!(capacity > 0, "broadcast channel capacity cannot be zero");
+ assert!(
+ capacity <= usize::MAX >> 1,
+ "broadcast channel capacity exceeded `usize::MAX / 2`"
+ );
+
+ // Round to a power of two
+ capacity = capacity.next_power_of_two();
+
+ let mut buffer = Vec::with_capacity(capacity);
+
+ for i in 0..capacity {
+ buffer.push(RwLock::new(Slot {
+ rem: AtomicUsize::new(0),
+ pos: (i as u64).wrapping_sub(capacity as u64),
+ val: UnsafeCell::new(None),
+ }));
+ }
+
+ let shared = Arc::new(Shared {
+ buffer: buffer.into_boxed_slice(),
+ mask: capacity - 1,
+ tail: Mutex::new(Tail {
+ pos: 0,
+ rx_cnt: receiver_count,
+ closed: false,
+ waiters: LinkedList::new(),
+ }),
+ num_tx: AtomicUsize::new(1),
+ });
+
+ Sender { shared }
+ }
+
/// Attempts to send a value to all active [`Receiver`] handles, returning
/// it back if it could not be sent.
///
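A usage sketch for the new `Sender::new` path: the receiver count starts at
zero, so `send` has nobody to deliver to until `subscribe` is called.

    use tokio::sync::broadcast::Sender;

    #[tokio::main]
    async fn main() {
        // A sender with no receivers (receiver count starts at 0).
        let tx = Sender::<i32>::new(16);

        // `send` returns the value back when there are no receivers.
        assert!(tx.send(1).is_err());

        let mut rx = tx.subscribe();
        tx.send(2).unwrap();
        assert_eq!(rx.recv().await.unwrap(), 2);
    }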
@@ -496,7 +542,8 @@ impl<T> Sender<T> {
///
/// On success, the number of subscribed [`Receiver`] handles is returned.
/// This does not mean that this number of receivers will see the message as
- /// a receiver may drop before receiving the message.
+ /// a receiver may drop or lag ([see lagging](self#lagging)) before receiving
+ /// the message.
///
/// # Note
///
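The return value of `send` counts the receivers subscribed at send time, as
the amended doc comment spells out; a short sketch:

    use tokio::sync::broadcast;

    #[tokio::main]
    async fn main() {
        let (tx, mut rx1) = broadcast::channel::<&str>(16);
        let mut rx2 = tx.subscribe();

        // Two receivers are subscribed, so `send` reports 2, even though
        // either one could drop or lag before actually reading the value.
        assert_eq!(tx.send("msg").unwrap(), 2);
        assert_eq!(rx1.recv().await.unwrap(), "msg");
        assert_eq!(rx2.recv().await.unwrap(), "msg");
    }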
@@ -565,12 +612,10 @@ impl<T> Sender<T> {
// Release the slot lock before notifying the receivers.
drop(slot);
- tail.notify_rx();
-
- // Release the mutex. This must happen after the slot lock is released,
- // otherwise the writer lock bit could be cleared while another thread
- // is in the critical section.
- drop(tail);
+ // Notify and release the mutex. This must happen after the slot lock is
+ // released, otherwise the writer lock bit could be cleared while another
+ // thread is in the critical section.
+ self.shared.notify_rx(tail);
Ok(rem)
}
@@ -735,11 +780,34 @@ impl<T> Sender<T> {
tail.rx_cnt
}
+ /// Returns `true` if senders belong to the same channel.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::broadcast;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, _rx) = broadcast::channel::<()>(16);
+ /// let tx2 = tx.clone();
+ ///
+ /// assert!(tx.same_channel(&tx2));
+ ///
+ /// let (tx3, _rx3) = broadcast::channel::<()>(16);
+ ///
+ /// assert!(!tx3.same_channel(&tx2));
+ /// }
+ /// ```
+ pub fn same_channel(&self, other: &Self) -> bool {
+ Arc::ptr_eq(&self.shared, &other.shared)
+ }
+
fn close_channel(&self) {
let mut tail = self.shared.tail.lock();
tail.closed = true;
- tail.notify_rx();
+ self.shared.notify_rx(tail);
}
}
@@ -760,18 +828,110 @@ fn new_receiver<T>(shared: Arc<Shared<T>>) -> Receiver<T> {
Receiver { shared, next }
}
-impl Tail {
- fn notify_rx(&mut self) {
- while let Some(mut waiter) = self.waiters.pop_back() {
- // Safety: `waiters` lock is still held.
- let waiter = unsafe { waiter.as_mut() };
+/// List used in `Shared::notify_rx`. It wraps a guarded linked list
+/// and gates access to it behind the `Shared.tail` mutex. It also
+/// empties the list on drop.
+struct WaitersList<'a, T> {
+ list: GuardedLinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
+ is_empty: bool,
+ shared: &'a Shared<T>,
+}
+
+impl<'a, T> Drop for WaitersList<'a, T> {
+ fn drop(&mut self) {
+ // If the list is not empty, we unlink all waiters from it.
+ // We do not wake the waiters to avoid double panics.
+ if !self.is_empty {
+ let _lock_guard = self.shared.tail.lock();
+ while self.list.pop_back().is_some() {}
+ }
+ }
+}
- assert!(waiter.queued);
- waiter.queued = false;
+impl<'a, T> WaitersList<'a, T> {
+ fn new(
+ unguarded_list: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
+ guard: Pin<&'a Waiter>,
+ shared: &'a Shared<T>,
+ ) -> Self {
+ let guard_ptr = NonNull::from(guard.get_ref());
+ let list = unguarded_list.into_guarded(guard_ptr);
+ WaitersList {
+ list,
+ is_empty: false,
+ shared,
+ }
+ }
- let waker = waiter.waker.take().unwrap();
- waker.wake();
+ /// Removes the last element from the guarded list. Modifying this list
+ /// requires exclusive access to the main list via the `tail` lock.
+ fn pop_back_locked(&mut self, _tail: &mut Tail) -> Option<NonNull<Waiter>> {
+ let result = self.list.pop_back();
+ if result.is_none() {
+ // Save information about emptiness to avoid waiting for the lock
+ // in the destructor.
+ self.is_empty = true;
}
+ result
+ }
+}
+
+impl<T> Shared<T> {
+ fn notify_rx<'a, 'b: 'a>(&'b self, mut tail: MutexGuard<'a, Tail>) {
+ // It is critical for `GuardedLinkedList` safety that the guard node is
+ // pinned in memory and is not dropped until the guarded list is dropped.
+ let guard = Waiter::new();
+ pin!(guard);
+
+ // We move all waiters to a secondary list. It uses a `GuardedLinkedList`
+ // underneath to allow every waiter to safely remove itself from it.
+ //
+ // * This list will still be guarded by the `tail` lock.
+ //   The `WaitersList` wrapper makes sure we hold the lock to modify it.
+ // * This wrapper will empty the list on drop. It is critical for safety
+ //   that we do not leave any list entry with a pointer to the local
+ //   guard node after this function returns / panics.
+ let mut list = WaitersList::new(std::mem::take(&mut tail.waiters), guard.as_ref(), self);
+
+ let mut wakers = WakeList::new();
+ 'outer: loop {
+ while wakers.can_push() {
+ match list.pop_back_locked(&mut tail) {
+ Some(mut waiter) => {
+ // Safety: `tail` lock is still held.
+ let waiter = unsafe { waiter.as_mut() };
+
+ assert!(waiter.queued);
+ waiter.queued = false;
+
+ if let Some(waker) = waiter.waker.take() {
+ wakers.push(waker);
+ }
+ }
+ None => {
+ break 'outer;
+ }
+ }
+ }
+
+ // Release the lock before waking.
+ drop(tail);
+
+ // Before we acquire the lock again, all sorts of things can happen:
+ // some waiters may remove themselves from the list and new waiters
+ // may be added. This is fine since at worst we will unnecessarily
+ // wake up waiters which will then queue themselves again.
+
+ wakers.wake_all();
+
+ // Acquire the lock again.
+ tail = self.tail.lock();
+ }
+
+ // Release the lock before waking.
+ drop(tail);
+
+ wakers.wake_all();
}
}
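The loop above wakes waiters in fixed-size batches and always releases the
`tail` lock before invoking any waker, so a waker that immediately re-queues
its task cannot deadlock against `notify_rx`. A self-contained sketch of the
same two-phase shape, substituting thread parking for task wakers; `BATCH`
and `notify_all` are illustrative names, not tokio API:

    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};
    use std::thread::{self, Thread};

    const BATCH: usize = 32; // stand-in for WakeList's fixed capacity

    fn notify_all(waiters: &Mutex<VecDeque<Thread>>) {
        loop {
            // Phase 1: drain up to BATCH waiters while holding the lock.
            let batch: Vec<Thread> = {
                let mut list = waiters.lock().unwrap();
                let n = list.len().min(BATCH);
                list.drain(..n).collect()
            }; // lock released here
            if batch.is_empty() {
                return;
            }
            // Phase 2: wake outside the lock; a woken waiter that takes
            // the lock to re-queue itself cannot deadlock against us.
            for t in batch {
                t.unpark();
            }
        }
    }

    fn main() {
        let waiters = Arc::new(Mutex::new(VecDeque::new()));
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let w = waiters.clone();
                thread::spawn(move || {
                    w.lock().unwrap().push_back(thread::current());
                    thread::park(); // unparked by notify_all
                })
            })
            .collect();
        // Crude synchronization, fine for a sketch: let threads enqueue.
        thread::sleep(std::time::Duration::from_millis(50));
        notify_all(&waiters);
        for h in handles {
            h.join().unwrap();
        }
    }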
@@ -860,6 +1020,29 @@ impl<T> Receiver<T> {
self.len() == 0
}
+ /// Returns `true` if receivers belong to the same channel.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::broadcast;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, rx) = broadcast::channel::<()>(16);
+ /// let rx2 = tx.subscribe();
+ ///
+ /// assert!(rx.same_channel(&rx2));
+ ///
+ /// let (_tx3, rx3) = broadcast::channel::<()>(16);
+ ///
+ /// assert!(!rx3.same_channel(&rx2));
+ /// }
+ /// ```
+ pub fn same_channel(&self, other: &Self) -> bool {
+ Arc::ptr_eq(&self.shared, &other.shared)
+ }
+
/// Locks the next value if there is one.
fn recv_ref(
&mut self,
@@ -880,6 +1063,8 @@ impl<T> Receiver<T> {
// the slot lock.
drop(slot);
+ let mut old_waker = None;
+
let mut tail = self.shared.tail.lock();
// Acquire slot lock again
@@ -912,7 +1097,10 @@ impl<T> Receiver<T> {
match (*ptr).waker {
Some(ref w) if w.will_wake(waker) => {}
_ => {
- (*ptr).waker = Some(waker.clone());
+ old_waker = std::mem::replace(
+ &mut (*ptr).waker,
+ Some(waker.clone()),
+ );
}
}
@@ -924,6 +1112,11 @@ impl<T> Receiver<T> {
}
}
+ // Drop the old waker after releasing the locks.
+ drop(slot);
+ drop(tail);
+ drop(old_waker);
+
return Err(TryRecvError::Empty);
}
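Deferring `drop(old_waker)` until both the slot and tail locks are released
matters because a `Waker`'s destructor can run arbitrary code, possibly code
that takes the same locks. A minimal sketch of that shape with a std `Mutex`;
`replace_then_drop` is a hypothetical helper, not part of this change:

    use std::sync::Mutex;

    // Swap in a new value under the lock, but let the old value's
    // destructor run only after the lock is released.
    fn replace_then_drop<T>(slot: &Mutex<Option<T>>, new: T) {
        let old = {
            let mut guard = slot.lock().unwrap();
            guard.replace(new)
        }; // guard dropped: lock released here
        drop(old); // the old value's Drop runs without the lock held
    }

    fn main() {
        let slot = Mutex::new(Some(String::from("old waker")));
        replace_then_drop(&slot, String::from("new waker"));
        assert_eq!(slot.lock().unwrap().as_deref(), Some("new waker"));
    }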
@@ -1106,6 +1299,33 @@ impl<T: Clone> Receiver<T> {
let guard = self.recv_ref(None)?;
guard.clone_value().ok_or(TryRecvError::Closed)
}
+
+ /// Blocking receive to call outside of asynchronous contexts.
+ ///
+ /// # Panics
+ ///
+ /// This function panics if called within an asynchronous execution
+ /// context.
+ ///
+ /// # Examples
+ /// ```
+ /// use std::thread;
+ /// use tokio::sync::broadcast;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, mut rx) = broadcast::channel(16);
+ ///
+ /// let sync_code = thread::spawn(move || {
+ /// assert_eq!(rx.blocking_recv(), Ok(10));
+ /// });
+ ///
+ /// let _ = tx.send(10);
+ /// sync_code.join().unwrap();
+ /// }
+ /// ```
+ pub fn blocking_recv(&mut self) -> Result<T, RecvError> {
+ crate::future::block_on(self.recv())
+ }
}
impl<T> Drop for Receiver<T> {
@@ -1164,6 +1384,8 @@ where
type Output = Result<T, RecvError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
+ ready!(crate::trace::trace_leaf(cx));
+
let (receiver, waiter) = self.project();
let guard = match receiver.recv_ref(Some((waiter, cx.waker()))) {
@@ -1252,3 +1474,41 @@ impl<'a, T> Drop for RecvGuard<'a, T> {
}
fn is_unpin<T: Unpin>() {}
+
+#[cfg(not(loom))]
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn receiver_count_on_sender_constructor() {
+ let sender = Sender::<i32>::new(16);
+ assert_eq!(sender.receiver_count(), 0);
+
+ let rx_1 = sender.subscribe();
+ assert_eq!(sender.receiver_count(), 1);
+
+ let rx_2 = rx_1.resubscribe();
+ assert_eq!(sender.receiver_count(), 2);
+
+ let rx_3 = sender.subscribe();
+ assert_eq!(sender.receiver_count(), 3);
+
+ drop(rx_3);
+ drop(rx_1);
+ assert_eq!(sender.receiver_count(), 1);
+
+ drop(rx_2);
+ assert_eq!(sender.receiver_count(), 0);
+ }
+
+ #[cfg(not(loom))]
+ #[test]
+ fn receiver_count_on_channel_constructor() {
+ let (sender, rx) = channel::<i32>(16);
+ assert_eq!(sender.receiver_count(), 1);
+
+ let _rx_2 = rx.resubscribe();
+ assert_eq!(sender.receiver_count(), 2);
+ }
+}