diff options
Diffstat (limited to 'src/stream/select.rs')
-rw-r--r-- | src/stream/select.rs | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/src/stream/select.rs b/src/stream/select.rs index b5fb813..36503e4 100644 --- a/src/stream/select.rs +++ b/src/stream/select.rs @@ -2,18 +2,20 @@ use crate::stream::{StreamExt, Fuse}; use core::pin::Pin; use futures_core::stream::{FusedStream, Stream}; use futures_core::task::{Context, Poll}; +use pin_project::{pin_project, project}; /// Stream for the [`select()`] function. +#[pin_project] #[derive(Debug)] #[must_use = "streams do nothing unless polled"] pub struct Select<St1, St2> { + #[pin] stream1: Fuse<St1>, + #[pin] stream2: Fuse<St2>, flag: bool, } -impl<St1: Unpin, St2: Unpin> Unpin for Select<St1, St2> {} - /// This function will attempt to pull items from both streams. Each /// stream will be polled in a round-robin fashion, and whenever a stream is /// ready to yield an item that item is yielded. @@ -56,11 +58,11 @@ impl<St1, St2> Select<St1, St2> { /// /// Note that care must be taken to avoid tampering with the state of the /// stream which may otherwise confuse this combinator. + #[project] pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut St1>, Pin<&mut St2>) { - unsafe { - let Self { stream1, stream2, .. } = self.get_unchecked_mut(); - (Pin::new_unchecked(stream1).get_pin_mut(), Pin::new_unchecked(stream2).get_pin_mut()) - } + #[project] + let Select { stream1, stream2, .. } = self.project(); + (stream1.get_pin_mut(), stream2.get_pin_mut()) } /// Consumes this combinator, returning the underlying streams. @@ -87,14 +89,13 @@ impl<St1, St2> Stream for Select<St1, St2> { type Item = St1::Item; + #[project] fn poll_next( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<St1::Item>> { - let Select { flag, stream1, stream2 } = - unsafe { self.get_unchecked_mut() }; - let stream1 = unsafe { Pin::new_unchecked(stream1) }; - let stream2 = unsafe { Pin::new_unchecked(stream2) }; + #[project] + let Select { flag, stream1, stream2 } = self.project(); if !*flag { poll_inner(flag, stream1, stream2, cx) |