aboutsummaryrefslogtreecommitdiff
path: root/src/stream/select_with_strategy.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/stream/select_with_strategy.rs')
-rw-r--r--src/stream/select_with_strategy.rs143
1 files changed, 109 insertions, 34 deletions
diff --git a/src/stream/select_with_strategy.rs b/src/stream/select_with_strategy.rs
index bd86990..224d5f8 100644
--- a/src/stream/select_with_strategy.rs
+++ b/src/stream/select_with_strategy.rs
@@ -1,5 +1,4 @@
use super::assert_stream;
-use crate::stream::{Fuse, StreamExt};
use core::{fmt, pin::Pin};
use futures_core::stream::{FusedStream, Stream};
use futures_core::task::{Context, Poll};
@@ -18,13 +17,15 @@ impl PollNext {
/// Toggle the value and return the old one.
pub fn toggle(&mut self) -> Self {
let old = *self;
+ *self = self.other();
+ old
+ }
+ fn other(&self) -> PollNext {
match self {
- PollNext::Left => *self = PollNext::Right,
- PollNext::Right => *self = PollNext::Left,
+ PollNext::Left => PollNext::Right,
+ PollNext::Right => PollNext::Left,
}
-
- old
}
}
@@ -34,14 +35,41 @@ impl Default for PollNext {
}
}
+enum InternalState {
+ Start,
+ LeftFinished,
+ RightFinished,
+ BothFinished,
+}
+
+impl InternalState {
+ fn finish(&mut self, ps: PollNext) {
+ match (&self, ps) {
+ (InternalState::Start, PollNext::Left) => {
+ *self = InternalState::LeftFinished;
+ }
+ (InternalState::Start, PollNext::Right) => {
+ *self = InternalState::RightFinished;
+ }
+ (InternalState::LeftFinished, PollNext::Right)
+ | (InternalState::RightFinished, PollNext::Left) => {
+ *self = InternalState::BothFinished;
+ }
+ _ => {}
+ }
+ }
+}
+
pin_project! {
/// Stream for the [`select_with_strategy()`] function. See function docs for details.
#[must_use = "streams do nothing unless polled"]
+ #[project = SelectWithStrategyProj]
pub struct SelectWithStrategy<St1, St2, Clos, State> {
#[pin]
- stream1: Fuse<St1>,
+ stream1: St1,
#[pin]
- stream2: Fuse<St2>,
+ stream2: St2,
+ internal_state: InternalState,
state: State,
clos: Clos,
}
@@ -120,9 +148,10 @@ where
State: Default,
{
assert_stream::<St1::Item, _>(SelectWithStrategy {
- stream1: stream1.fuse(),
- stream2: stream2.fuse(),
+ stream1,
+ stream2,
state: Default::default(),
+ internal_state: InternalState::Start,
clos: which,
})
}
@@ -131,7 +160,7 @@ impl<St1, St2, Clos, State> SelectWithStrategy<St1, St2, Clos, State> {
/// Acquires a reference to the underlying streams that this combinator is
/// pulling from.
pub fn get_ref(&self) -> (&St1, &St2) {
- (self.stream1.get_ref(), self.stream2.get_ref())
+ (&self.stream1, &self.stream2)
}
/// Acquires a mutable reference to the underlying streams that this
@@ -140,7 +169,7 @@ impl<St1, St2, Clos, State> SelectWithStrategy<St1, St2, Clos, State> {
/// Note that care must be taken to avoid tampering with the state of the
/// stream which may otherwise confuse this combinator.
pub fn get_mut(&mut self) -> (&mut St1, &mut St2) {
- (self.stream1.get_mut(), self.stream2.get_mut())
+ (&mut self.stream1, &mut self.stream2)
}
/// Acquires a pinned mutable reference to the underlying streams that this
@@ -150,7 +179,7 @@ impl<St1, St2, Clos, State> SelectWithStrategy<St1, St2, Clos, State> {
/// stream which may otherwise confuse this combinator.
pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut St1>, Pin<&mut St2>) {
let this = self.project();
- (this.stream1.get_pin_mut(), this.stream2.get_pin_mut())
+ (this.stream1, this.stream2)
}
/// Consumes this combinator, returning the underlying streams.
@@ -158,7 +187,7 @@ impl<St1, St2, Clos, State> SelectWithStrategy<St1, St2, Clos, State> {
/// Note that this may discard intermediate state of this combinator, so
/// care should be taken to avoid losing resources when this is called.
pub fn into_inner(self) -> (St1, St2) {
- (self.stream1.into_inner(), self.stream2.into_inner())
+ (self.stream1, self.stream2)
}
}
@@ -169,47 +198,93 @@ where
Clos: FnMut(&mut State) -> PollNext,
{
fn is_terminated(&self) -> bool {
- self.stream1.is_terminated() && self.stream2.is_terminated()
+ match self.internal_state {
+ InternalState::BothFinished => true,
+ _ => false,
+ }
}
}
-impl<St1, St2, Clos, State> Stream for SelectWithStrategy<St1, St2, Clos, State>
+#[inline]
+fn poll_side<St1, St2, Clos, State>(
+ select: &mut SelectWithStrategyProj<'_, St1, St2, Clos, State>,
+ side: PollNext,
+ cx: &mut Context<'_>,
+) -> Poll<Option<St1::Item>>
where
St1: Stream,
St2: Stream<Item = St1::Item>,
- Clos: FnMut(&mut State) -> PollNext,
{
- type Item = St1::Item;
-
- fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<St1::Item>> {
- let this = self.project();
-
- match (this.clos)(this.state) {
- PollNext::Left => poll_inner(this.stream1, this.stream2, cx),
- PollNext::Right => poll_inner(this.stream2, this.stream1, cx),
- }
+ match side {
+ PollNext::Left => select.stream1.as_mut().poll_next(cx),
+ PollNext::Right => select.stream2.as_mut().poll_next(cx),
}
}
-fn poll_inner<St1, St2>(
- a: Pin<&mut St1>,
- b: Pin<&mut St2>,
+#[inline]
+fn poll_inner<St1, St2, Clos, State>(
+ select: &mut SelectWithStrategyProj<'_, St1, St2, Clos, State>,
+ side: PollNext,
cx: &mut Context<'_>,
) -> Poll<Option<St1::Item>>
where
St1: Stream,
St2: Stream<Item = St1::Item>,
{
- let a_done = match a.poll_next(cx) {
+ let first_done = match poll_side(select, side, cx) {
Poll::Ready(Some(item)) => return Poll::Ready(Some(item)),
- Poll::Ready(None) => true,
+ Poll::Ready(None) => {
+ select.internal_state.finish(side);
+ true
+ }
Poll::Pending => false,
};
+ let other = side.other();
+ match poll_side(select, other, cx) {
+ Poll::Ready(None) => {
+ select.internal_state.finish(other);
+ if first_done {
+ Poll::Ready(None)
+ } else {
+ Poll::Pending
+ }
+ }
+ a => a,
+ }
+}
+
+impl<St1, St2, Clos, State> Stream for SelectWithStrategy<St1, St2, Clos, State>
+where
+ St1: Stream,
+ St2: Stream<Item = St1::Item>,
+ Clos: FnMut(&mut State) -> PollNext,
+{
+ type Item = St1::Item;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<St1::Item>> {
+ let mut this = self.project();
- match b.poll_next(cx) {
- Poll::Ready(Some(item)) => Poll::Ready(Some(item)),
- Poll::Ready(None) if a_done => Poll::Ready(None),
- Poll::Ready(None) | Poll::Pending => Poll::Pending,
+ match this.internal_state {
+ InternalState::Start => {
+ let next_side = (this.clos)(this.state);
+ poll_inner(&mut this, next_side, cx)
+ }
+ InternalState::LeftFinished => match this.stream2.poll_next(cx) {
+ Poll::Ready(None) => {
+ *this.internal_state = InternalState::BothFinished;
+ Poll::Ready(None)
+ }
+ a => a,
+ },
+ InternalState::RightFinished => match this.stream1.poll_next(cx) {
+ Poll::Ready(None) => {
+ *this.internal_state = InternalState::BothFinished;
+ Poll::Ready(None)
+ }
+ a => a,
+ },
+ InternalState::BothFinished => Poll::Ready(None),
+ }
}
}