aboutsummaryrefslogtreecommitdiff
path: root/tests/stream.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tests/stream.rs')
-rw-r--r--tests/stream.rs42
1 files changed, 42 insertions, 0 deletions
diff --git a/tests/stream.rs b/tests/stream.rs
index 71ec654..5cde458 100644
--- a/tests/stream.rs
+++ b/tests/stream.rs
@@ -1,5 +1,9 @@
+use std::cell::Cell;
use std::iter;
+use std::pin::Pin;
+use std::rc::Rc;
use std::sync::Arc;
+use std::task::Context;
use futures::channel::mpsc;
use futures::executor::block_on;
@@ -9,6 +13,7 @@ use futures::sink::SinkExt;
use futures::stream::{self, StreamExt};
use futures::task::Poll;
use futures::{ready, FutureExt};
+use futures_core::Stream;
use futures_test::task::noop_context;
#[test]
@@ -419,3 +424,40 @@ fn ready_chunks() {
assert_eq!(s.next().await.unwrap(), vec![4]);
});
}
+
+struct SlowStream {
+ times_should_poll: usize,
+ times_polled: Rc<Cell<usize>>,
+}
+impl Stream for SlowStream {
+ type Item = usize;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ self.times_polled.set(self.times_polled.get() + 1);
+ if self.times_polled.get() % 2 == 0 {
+ cx.waker().wake_by_ref();
+ return Poll::Pending;
+ }
+ if self.times_polled.get() >= self.times_should_poll {
+ return Poll::Ready(None);
+ }
+ Poll::Ready(Some(self.times_polled.get()))
+ }
+}
+
+#[test]
+fn select_with_strategy_doesnt_terminate_early() {
+ for side in [stream::PollNext::Left, stream::PollNext::Right] {
+ let times_should_poll = 10;
+ let count = Rc::new(Cell::new(0));
+ let b = stream::iter([10, 20]);
+
+ let mut selected = stream::select_with_strategy(
+ SlowStream { times_should_poll, times_polled: count.clone() },
+ b,
+ |_: &mut ()| side,
+ );
+ block_on(async move { while selected.next().await.is_some() {} });
+ assert_eq!(count.get(), times_should_poll + 1);
+ }
+}