aboutsummaryrefslogtreecommitdiff
path: root/src/stream/stream/flatten_unordered.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/stream/stream/flatten_unordered.rs')
-rw-r--r--src/stream/stream/flatten_unordered.rs509
1 files changed, 509 insertions, 0 deletions
diff --git a/src/stream/stream/flatten_unordered.rs b/src/stream/stream/flatten_unordered.rs
new file mode 100644
index 0000000..07f971c
--- /dev/null
+++ b/src/stream/stream/flatten_unordered.rs
@@ -0,0 +1,509 @@
+use alloc::sync::Arc;
+use core::{
+ cell::UnsafeCell,
+ convert::identity,
+ fmt,
+ num::NonZeroUsize,
+ pin::Pin,
+ sync::atomic::{AtomicU8, Ordering},
+};
+
+use pin_project_lite::pin_project;
+
+use futures_core::{
+ future::Future,
+ ready,
+ stream::{FusedStream, Stream},
+ task::{Context, Poll, Waker},
+};
+#[cfg(feature = "sink")]
+use futures_sink::Sink;
+use futures_task::{waker, ArcWake};
+
+use crate::stream::FuturesUnordered;
+
+/// There is nothing to poll and stream isn't being
+/// polled or waking at the moment.
+const NONE: u8 = 0;
+
+/// Inner streams need to be polled.
+const NEED_TO_POLL_INNER_STREAMS: u8 = 1;
+
+/// The base stream needs to be polled.
+const NEED_TO_POLL_STREAM: u8 = 0b10;
+
+/// It needs to poll base stream and inner streams.
+const NEED_TO_POLL_ALL: u8 = NEED_TO_POLL_INNER_STREAMS | NEED_TO_POLL_STREAM;
+
+/// The current stream is being polled at the moment.
+const POLLING: u8 = 0b100;
+
+/// Inner streams are being woken at the moment.
+const WAKING_INNER_STREAMS: u8 = 0b1000;
+
+/// The base stream is being woken at the moment.
+const WAKING_STREAM: u8 = 0b10000;
+
+/// The base stream and inner streams are being woken at the moment.
+const WAKING_ALL: u8 = WAKING_STREAM | WAKING_INNER_STREAMS;
+
+/// The stream was waked and will be polled.
+const WOKEN: u8 = 0b100000;
+
+/// Determines what needs to be polled, and is stream being polled at the
+/// moment or not.
+#[derive(Clone, Debug)]
+struct SharedPollState {
+ state: Arc<AtomicU8>,
+}
+
+impl SharedPollState {
+ /// Constructs new `SharedPollState` with the given state.
+ fn new(value: u8) -> SharedPollState {
+ SharedPollState { state: Arc::new(AtomicU8::new(value)) }
+ }
+
+ /// Attempts to start polling, returning stored state in case of success.
+ /// Returns `None` if some waker is waking at the moment.
+ fn start_polling(
+ &self,
+ ) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)> {
+ let value = self
+ .state
+ .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
+ if value & WAKING_ALL == NONE {
+ Some(POLLING)
+ } else {
+ None
+ }
+ })
+ .ok()?;
+ let bomb = PollStateBomb::new(self, SharedPollState::reset);
+
+ Some((value, bomb))
+ }
+
+ /// Starts the waking process and performs bitwise or with the given value.
+ fn start_waking(
+ &self,
+ to_poll: u8,
+ waking: u8,
+ ) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)> {
+ let value = self
+ .state
+ .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
+ // Waking process for this waker already started
+ if value & waking != NONE {
+ return None;
+ }
+ let mut next_value = value | to_poll;
+ // Only start the waking process if we're not in the polling phase and the stream isn't woken already
+ if value & (WOKEN | POLLING) == NONE {
+ next_value |= waking;
+ }
+
+ if next_value != value {
+ Some(next_value)
+ } else {
+ None
+ }
+ })
+ .ok()?;
+
+ if value & (WOKEN | POLLING) == NONE {
+ let bomb = PollStateBomb::new(self, move |state| state.stop_waking(waking));
+
+ Some((value, bomb))
+ } else {
+ None
+ }
+ }
+
+ /// Sets current state to
+ /// - `!POLLING` allowing to use wakers
+ /// - `WOKEN` if the state was changed during `POLLING` phase as waker will be called,
+ /// or `will_be_woken` flag supplied
+ /// - `!WAKING_ALL` as
+ /// * Wakers called during the `POLLING` phase won't propagate their calls
+ /// * `POLLING` phase can't start if some of the wakers are active
+ /// So no wrapped waker can touch the inner waker's cell, it's safe to poll again.
+ fn stop_polling(&self, to_poll: u8, will_be_woken: bool) -> u8 {
+ self.state
+ .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |mut value| {
+ let mut next_value = to_poll;
+
+ value &= NEED_TO_POLL_ALL;
+ if value != NONE || will_be_woken {
+ next_value |= WOKEN;
+ }
+ next_value |= value;
+
+ Some(next_value & !POLLING & !WAKING_ALL)
+ })
+ .unwrap()
+ }
+
+ /// Toggles state to non-waking, allowing to start polling.
+ fn stop_waking(&self, waking: u8) -> u8 {
+ self.state
+ .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
+ let mut next_value = value & !waking;
+ // Waker will be called only if the current waking state is the same as the specified waker state
+ if value & WAKING_ALL == waking {
+ next_value |= WOKEN;
+ }
+
+ if next_value != value {
+ Some(next_value)
+ } else {
+ None
+ }
+ })
+ .unwrap_or_else(identity)
+ }
+
+ /// Resets current state allowing to poll the stream and wake up wakers.
+ fn reset(&self) -> u8 {
+ self.state.swap(NEED_TO_POLL_ALL, Ordering::AcqRel)
+ }
+}
+
+/// Used to execute some function on the given state when dropped.
+struct PollStateBomb<'a, F: FnOnce(&SharedPollState) -> u8> {
+ state: &'a SharedPollState,
+ drop: Option<F>,
+}
+
+impl<'a, F: FnOnce(&SharedPollState) -> u8> PollStateBomb<'a, F> {
+ /// Constructs new bomb with the given state.
+ fn new(state: &'a SharedPollState, drop: F) -> Self {
+ Self { state, drop: Some(drop) }
+ }
+
+ /// Deactivates bomb, forces it to not call provided function when dropped.
+ fn deactivate(mut self) {
+ self.drop.take();
+ }
+
+ /// Manually fires the bomb, returning supplied state.
+ fn fire(mut self) -> Option<u8> {
+ self.drop.take().map(|drop| (drop)(self.state))
+ }
+}
+
+impl<F: FnOnce(&SharedPollState) -> u8> Drop for PollStateBomb<'_, F> {
+ fn drop(&mut self) {
+ if let Some(drop) = self.drop.take() {
+ (drop)(self.state);
+ }
+ }
+}
+
+/// Will update state with the provided value on `wake_by_ref` call
+/// and then, if there is a need, call `inner_waker`.
+struct InnerWaker {
+ inner_waker: UnsafeCell<Option<Waker>>,
+ poll_state: SharedPollState,
+ need_to_poll: u8,
+}
+
+unsafe impl Send for InnerWaker {}
+unsafe impl Sync for InnerWaker {}
+
+impl InnerWaker {
+ /// Replaces given waker's inner_waker for polling stream/futures which will
+ /// update poll state on `wake_by_ref` call. Use only if you need several
+ /// contexts.
+ ///
+ /// ## Safety
+ ///
+ /// This function will modify waker's `inner_waker` via `UnsafeCell`, so
+ /// it should be used only during `POLLING` phase.
+ unsafe fn replace_waker(self_arc: &mut Arc<Self>, cx: &Context<'_>) -> Waker {
+ *self_arc.inner_waker.get() = cx.waker().clone().into();
+ waker(self_arc.clone())
+ }
+
+ /// Attempts to start the waking process for the waker with the given value.
+ /// If succeeded, then the stream isn't yet woken and not being polled at the moment.
+ fn start_waking(&self) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)> {
+ self.poll_state.start_waking(self.need_to_poll, self.waking_state())
+ }
+
+ /// Returns the corresponding waking state toggled by this waker.
+ fn waking_state(&self) -> u8 {
+ self.need_to_poll << 3
+ }
+}
+
+impl ArcWake for InnerWaker {
+ fn wake_by_ref(self_arc: &Arc<Self>) {
+ if let Some((_, state_bomb)) = self_arc.start_waking() {
+ // Safety: now state is not `POLLING`
+ let waker_opt = unsafe { self_arc.inner_waker.get().as_ref().unwrap() };
+
+ if let Some(inner_waker) = waker_opt.clone() {
+ // Stop waking to allow polling stream
+ let poll_state_value = state_bomb.fire().unwrap();
+
+ // Here we want to call waker only if stream isn't woken yet and
+ // also to optimize the case when two wakers are called at the same time.
+ //
+ // In this case the best strategy will be to propagate only the latest waker's awake,
+ // and then poll both entities in a single `poll_next` call
+ if poll_state_value & (WOKEN | WAKING_ALL) == self_arc.waking_state() {
+ // Wake up inner waker
+ inner_waker.wake();
+ }
+ }
+ }
+ }
+}
+
+pin_project! {
+ /// Future which contains optional stream.
+ ///
+ /// If it's `Some`, it will attempt to call `poll_next` on it,
+ /// returning `Some((item, next_item_fut))` in case of `Poll::Ready(Some(...))`
+ /// or `None` in case of `Poll::Ready(None)`.
+ ///
+ /// If `poll_next` will return `Poll::Pending`, it will be forwarded to
+ /// the future and current task will be notified by waker.
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
+ struct PollStreamFut<St> {
+ #[pin]
+ stream: Option<St>,
+ }
+}
+
+impl<St> PollStreamFut<St> {
+ /// Constructs new `PollStreamFut` using given `stream`.
+ fn new(stream: impl Into<Option<St>>) -> Self {
+ Self { stream: stream.into() }
+ }
+}
+
+impl<St: Stream + Unpin> Future for PollStreamFut<St> {
+ type Output = Option<(St::Item, PollStreamFut<St>)>;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let mut stream = self.project().stream;
+
+ let item = if let Some(stream) = stream.as_mut().as_pin_mut() {
+ ready!(stream.poll_next(cx))
+ } else {
+ None
+ };
+ let next_item_fut = PollStreamFut::new(stream.get_mut().take());
+ let out = item.map(|item| (item, next_item_fut));
+
+ Poll::Ready(out)
+ }
+}
+
+pin_project! {
+ /// Stream for the [`flatten_unordered`](super::StreamExt::flatten_unordered)
+ /// method.
+ #[project = FlattenUnorderedProj]
+ #[must_use = "streams do nothing unless polled"]
+ pub struct FlattenUnordered<St> where St: Stream {
+ #[pin]
+ inner_streams: FuturesUnordered<PollStreamFut<St::Item>>,
+ #[pin]
+ stream: St,
+ poll_state: SharedPollState,
+ limit: Option<NonZeroUsize>,
+ is_stream_done: bool,
+ inner_streams_waker: Arc<InnerWaker>,
+ stream_waker: Arc<InnerWaker>,
+ }
+}
+
+impl<St> fmt::Debug for FlattenUnordered<St>
+where
+ St: Stream + fmt::Debug,
+ St::Item: Stream + fmt::Debug,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("FlattenUnordered")
+ .field("poll_state", &self.poll_state)
+ .field("inner_streams", &self.inner_streams)
+ .field("limit", &self.limit)
+ .field("stream", &self.stream)
+ .field("is_stream_done", &self.is_stream_done)
+ .finish()
+ }
+}
+
+impl<St> FlattenUnordered<St>
+where
+ St: Stream,
+ St::Item: Stream + Unpin,
+{
+ pub(super) fn new(stream: St, limit: Option<usize>) -> FlattenUnordered<St> {
+ let poll_state = SharedPollState::new(NEED_TO_POLL_STREAM);
+
+ FlattenUnordered {
+ inner_streams: FuturesUnordered::new(),
+ stream,
+ is_stream_done: false,
+ limit: limit.and_then(NonZeroUsize::new),
+ inner_streams_waker: Arc::new(InnerWaker {
+ inner_waker: UnsafeCell::new(None),
+ poll_state: poll_state.clone(),
+ need_to_poll: NEED_TO_POLL_INNER_STREAMS,
+ }),
+ stream_waker: Arc::new(InnerWaker {
+ inner_waker: UnsafeCell::new(None),
+ poll_state: poll_state.clone(),
+ need_to_poll: NEED_TO_POLL_STREAM,
+ }),
+ poll_state,
+ }
+ }
+
+ delegate_access_inner!(stream, St, ());
+}
+
+impl<St> FlattenUnorderedProj<'_, St>
+where
+ St: Stream,
+{
+ /// Checks if current `inner_streams` size is less than optional limit.
+ fn is_exceeded_limit(&self) -> bool {
+ self.limit.map_or(false, |limit| self.inner_streams.len() >= limit.get())
+ }
+}
+
+impl<St> FusedStream for FlattenUnordered<St>
+where
+ St: FusedStream,
+ St::Item: FusedStream + Unpin,
+{
+ fn is_terminated(&self) -> bool {
+ self.stream.is_terminated() && self.inner_streams.is_empty()
+ }
+}
+
+impl<St> Stream for FlattenUnordered<St>
+where
+ St: Stream,
+ St::Item: Stream + Unpin,
+{
+ type Item = <St::Item as Stream>::Item;
+
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ let mut next_item = None;
+ let mut need_to_poll_next = NONE;
+
+ let mut this = self.as_mut().project();
+
+ let (mut poll_state_value, state_bomb) = match this.poll_state.start_polling() {
+ Some(value) => value,
+ _ => {
+ // Waker was called, just wait for the next poll
+ return Poll::Pending;
+ }
+ };
+
+ if poll_state_value & NEED_TO_POLL_STREAM != NONE {
+ // Safety: now state is `POLLING`.
+ let stream_waker = unsafe { InnerWaker::replace_waker(this.stream_waker, cx) };
+
+ // Here we need to poll the base stream.
+ //
+ // To improve performance, we will attempt to place as many items as we can
+ // to the `FuturesUnordered` bucket before polling inner streams
+ loop {
+ if this.is_exceeded_limit() || *this.is_stream_done {
+ // We either exceeded the limit or the stream is exhausted
+ if !*this.is_stream_done {
+ // The stream needs to be polled in the next iteration
+ need_to_poll_next |= NEED_TO_POLL_STREAM;
+ }
+
+ break;
+ } else {
+ match this.stream.as_mut().poll_next(&mut Context::from_waker(&stream_waker)) {
+ Poll::Ready(Some(inner_stream)) => {
+ // Add new stream to the inner streams bucket
+ this.inner_streams.as_mut().push(PollStreamFut::new(inner_stream));
+ // Inner streams must be polled afterward
+ poll_state_value |= NEED_TO_POLL_INNER_STREAMS;
+ }
+ Poll::Ready(None) => {
+ // Mark the stream as done
+ *this.is_stream_done = true;
+ }
+ Poll::Pending => {
+ break;
+ }
+ }
+ }
+ }
+ }
+
+ if poll_state_value & NEED_TO_POLL_INNER_STREAMS != NONE {
+ // Safety: now state is `POLLING`.
+ let inner_streams_waker =
+ unsafe { InnerWaker::replace_waker(this.inner_streams_waker, cx) };
+
+ match this
+ .inner_streams
+ .as_mut()
+ .poll_next(&mut Context::from_waker(&inner_streams_waker))
+ {
+ Poll::Ready(Some(Some((item, next_item_fut)))) => {
+ // Push next inner stream item future to the list of inner streams futures
+ this.inner_streams.as_mut().push(next_item_fut);
+ // Take the received item
+ next_item = Some(item);
+ // On the next iteration, inner streams must be polled again
+ need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS;
+ }
+ Poll::Ready(Some(None)) => {
+ // On the next iteration, inner streams must be polled again
+ need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS;
+ }
+ _ => {}
+ }
+ }
+
+ // We didn't have any `poll_next` panic, so it's time to deactivate the bomb
+ state_bomb.deactivate();
+
+ let mut force_wake =
+ // we need to poll the stream and didn't reach the limit yet
+ need_to_poll_next & NEED_TO_POLL_STREAM != NONE && !this.is_exceeded_limit()
+ // or we need to poll inner streams again
+ || need_to_poll_next & NEED_TO_POLL_INNER_STREAMS != NONE;
+
+ // Stop polling and swap the latest state
+ poll_state_value = this.poll_state.stop_polling(need_to_poll_next, force_wake);
+ // If state was changed during `POLLING` phase, need to manually call a waker
+ force_wake |= poll_state_value & NEED_TO_POLL_ALL != NONE;
+
+ let is_done = *this.is_stream_done && this.inner_streams.is_empty();
+
+ if next_item.is_some() || is_done {
+ Poll::Ready(next_item)
+ } else {
+ if force_wake {
+ cx.waker().wake_by_ref();
+ }
+
+ Poll::Pending
+ }
+ }
+}
+
+// Forwarding impl of Sink from the underlying stream
+#[cfg(feature = "sink")]
+impl<St, Item> Sink<Item> for FlattenUnordered<St>
+where
+ St: Stream + Sink<Item>,
+{
+ type Error = St::Error;
+
+ delegate_sink!(stream, Item);
+}