summaryrefslogtreecommitdiff
path: root/src/stream_select.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/stream_select.rs')
-rw-r--r--src/stream_select.rs113
1 files changed, 113 insertions, 0 deletions
diff --git a/src/stream_select.rs b/src/stream_select.rs
new file mode 100644
index 0000000..9927b53
--- /dev/null
+++ b/src/stream_select.rs
@@ -0,0 +1,113 @@
+use proc_macro2::TokenStream;
+use quote::{format_ident, quote, ToTokens};
+use syn::{parse::Parser, punctuated::Punctuated, Expr, Index, Token};
+
+/// The `stream_select!` macro.
+pub(crate) fn stream_select(input: TokenStream) -> Result<TokenStream, syn::Error> {
+ let args = Punctuated::<Expr, Token![,]>::parse_terminated.parse2(input)?;
+ if args.len() < 2 {
+ return Ok(quote! {
+ compile_error!("stream select macro needs at least two arguments.")
+ });
+ }
+ let generic_idents = (0..args.len()).map(|i| format_ident!("_{}", i)).collect::<Vec<_>>();
+ let field_idents = (0..args.len()).map(|i| format_ident!("__{}", i)).collect::<Vec<_>>();
+ let field_idents_2 = (0..args.len()).map(|i| format_ident!("___{}", i)).collect::<Vec<_>>();
+ let field_indices = (0..args.len()).map(Index::from).collect::<Vec<_>>();
+ let args = args.iter().map(|e| e.to_token_stream());
+
+ Ok(quote! {
+ {
+ #[derive(Debug)]
+ struct StreamSelect<#(#generic_idents),*> (#(Option<#generic_idents>),*);
+
+ enum StreamEnum<#(#generic_idents),*> {
+ #(
+ #generic_idents(#generic_idents)
+ ),*,
+ None,
+ }
+
+ impl<ITEM, #(#generic_idents),*> __futures_crate::stream::Stream for StreamEnum<#(#generic_idents),*>
+ where #(#generic_idents: __futures_crate::stream::Stream<Item=ITEM> + ::std::marker::Unpin,)*
+ {
+ type Item = ITEM;
+
+ fn poll_next(mut self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Option<Self::Item>> {
+ match self.get_mut() {
+ #(
+ Self::#generic_idents(#generic_idents) => ::std::pin::Pin::new(#generic_idents).poll_next(cx)
+ ),*,
+ Self::None => panic!("StreamEnum::None should never be polled!"),
+ }
+ }
+ }
+
+ impl<ITEM, #(#generic_idents),*> __futures_crate::stream::Stream for StreamSelect<#(#generic_idents),*>
+ where #(#generic_idents: __futures_crate::stream::Stream<Item=ITEM> + ::std::marker::Unpin,)*
+ {
+ type Item = ITEM;
+
+ fn poll_next(mut self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Option<Self::Item>> {
+ let Self(#(ref mut #field_idents),*) = self.get_mut();
+ #(
+ let mut #field_idents_2 = false;
+ )*
+ let mut any_pending = false;
+ {
+ let mut stream_array = [#(#field_idents.as_mut().map(|f| StreamEnum::#generic_idents(f)).unwrap_or(StreamEnum::None)),*];
+ __futures_crate::async_await::shuffle(&mut stream_array);
+
+ for mut s in stream_array {
+ if let StreamEnum::None = s {
+ continue;
+ } else {
+ match __futures_crate::stream::Stream::poll_next(::std::pin::Pin::new(&mut s), cx) {
+ r @ __futures_crate::task::Poll::Ready(Some(_)) => {
+ return r;
+ },
+ __futures_crate::task::Poll::Pending => {
+ any_pending = true;
+ },
+ __futures_crate::task::Poll::Ready(None) => {
+ match s {
+ #(
+ StreamEnum::#generic_idents(_) => { #field_idents_2 = true; }
+ ),*,
+ StreamEnum::None => panic!("StreamEnum::None should never be polled!"),
+ }
+ },
+ }
+ }
+ }
+ }
+ #(
+ if #field_idents_2 {
+ *#field_idents = None;
+ }
+ )*
+ if any_pending {
+ __futures_crate::task::Poll::Pending
+ } else {
+ __futures_crate::task::Poll::Ready(None)
+ }
+ }
+
+ fn size_hint(&self) -> (usize, Option<usize>) {
+ let mut s = (0, Some(0));
+ #(
+ if let Some(new_hint) = self.#field_indices.as_ref().map(|s| s.size_hint()) {
+ s.0 += new_hint.0;
+ // We can change this out for `.zip` when the MSRV is 1.46.0 or higher.
+ s.1 = s.1.and_then(|a| new_hint.1.map(|b| a + b));
+ }
+ )*
+ s
+ }
+ }
+
+ StreamSelect(#(Some(#args)),*)
+
+ }
+ })
+}