summaryrefslogtreecommitdiff
path: root/src/join.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/join.rs')
-rw-r--r--src/join.rs170
1 files changed, 170 insertions, 0 deletions
diff --git a/src/join.rs b/src/join.rs
new file mode 100644
index 0000000..2a7b11c
--- /dev/null
+++ b/src/join.rs
@@ -0,0 +1,170 @@
+//! The futures-rs `join! macro implementation.
+
+use proc_macro::TokenStream;
+use proc_macro2::{Span, TokenStream as TokenStream2};
+use quote::{format_ident, quote};
+use syn::parse::{Parse, ParseStream};
+use syn::{parenthesized, parse_quote, Expr, Ident, Token};
+
+mod kw {
+ syn::custom_keyword!(futures_crate_path);
+}
+
+#[derive(Default)]
+struct Join {
+ futures_crate_path: Option<syn::Path>,
+ fut_exprs: Vec<Expr>,
+}
+
+impl Parse for Join {
+ fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
+ let mut join = Join::default();
+
+ // When `futures_crate_path(::path::to::futures::lib)` is provided,
+ // it sets the path through which futures library functions will be
+ // accessed.
+ if input.peek(kw::futures_crate_path) {
+ input.parse::<kw::futures_crate_path>()?;
+ let content;
+ parenthesized!(content in input);
+ join.futures_crate_path = Some(content.parse()?);
+ }
+
+ while !input.is_empty() {
+ join.fut_exprs.push(input.parse::<Expr>()?);
+
+ if !input.is_empty() {
+ input.parse::<Token![,]>()?;
+ }
+ }
+
+ Ok(join)
+ }
+}
+
+fn bind_futures(
+ futures_crate: &syn::Path,
+ fut_exprs: Vec<Expr>,
+ span: Span,
+) -> (Vec<TokenStream2>, Vec<Ident>) {
+ let mut future_let_bindings = Vec::with_capacity(fut_exprs.len());
+ let future_names: Vec<_> = fut_exprs
+ .into_iter()
+ .enumerate()
+ .map(|(i, expr)| {
+ let name = format_ident!("_fut{}", i, span = span);
+ future_let_bindings.push(quote! {
+ // Move future into a local so that it is pinned in one place and
+ // is no longer accessible by the end user.
+ let mut #name = #futures_crate::future::maybe_done(#expr);
+ });
+ name
+ })
+ .collect();
+
+ (future_let_bindings, future_names)
+}
+
+/// The `join!` macro.
+pub(crate) fn join(input: TokenStream) -> TokenStream {
+ let parsed = syn::parse_macro_input!(input as Join);
+
+ let futures_crate = parsed
+ .futures_crate_path
+ .unwrap_or_else(|| parse_quote!(::futures_util));
+
+ // should be def_site, but that's unstable
+ let span = Span::call_site();
+
+ let (future_let_bindings, future_names) = bind_futures(&futures_crate, parsed.fut_exprs, span);
+
+ let poll_futures = future_names.iter().map(|fut| {
+ quote! {
+ __all_done &= #futures_crate::core_reexport::future::Future::poll(
+ unsafe { #futures_crate::core_reexport::pin::Pin::new_unchecked(&mut #fut) }, __cx).is_ready();
+ }
+ });
+ let take_outputs = future_names.iter().map(|fut| {
+ quote! {
+ unsafe { #futures_crate::core_reexport::pin::Pin::new_unchecked(&mut #fut) }.take_output().unwrap(),
+ }
+ });
+
+ TokenStream::from(quote! { {
+ #( #future_let_bindings )*
+
+ #futures_crate::future::poll_fn(move |__cx: &mut #futures_crate::task::Context<'_>| {
+ let mut __all_done = true;
+ #( #poll_futures )*
+ if __all_done {
+ #futures_crate::core_reexport::task::Poll::Ready((
+ #( #take_outputs )*
+ ))
+ } else {
+ #futures_crate::core_reexport::task::Poll::Pending
+ }
+ }).await
+ } })
+}
+
+/// The `try_join!` macro.
+pub(crate) fn try_join(input: TokenStream) -> TokenStream {
+ let parsed = syn::parse_macro_input!(input as Join);
+
+ let futures_crate = parsed
+ .futures_crate_path
+ .unwrap_or_else(|| parse_quote!(::futures_util));
+
+ // should be def_site, but that's unstable
+ let span = Span::call_site();
+
+ let (future_let_bindings, future_names) = bind_futures(&futures_crate, parsed.fut_exprs, span);
+
+ let poll_futures = future_names.iter().map(|fut| {
+ quote! {
+ if #futures_crate::core_reexport::future::Future::poll(
+ unsafe { #futures_crate::core_reexport::pin::Pin::new_unchecked(&mut #fut) }, __cx).is_pending()
+ {
+ __all_done = false;
+ } else if unsafe { #futures_crate::core_reexport::pin::Pin::new_unchecked(&mut #fut) }.output_mut().unwrap().is_err() {
+ // `.err().unwrap()` rather than `.unwrap_err()` so that we don't introduce
+ // a `T: Debug` bound.
+ // Also, for an error type of ! any code after `err().unwrap()` is unreachable.
+ #[allow(unreachable_code)]
+ return #futures_crate::core_reexport::task::Poll::Ready(
+ #futures_crate::core_reexport::result::Result::Err(
+ unsafe { #futures_crate::core_reexport::pin::Pin::new_unchecked(&mut #fut) }.take_output().unwrap().err().unwrap()
+ )
+ );
+ }
+ }
+ });
+ let take_outputs = future_names.iter().map(|fut| {
+ quote! {
+ // `.ok().unwrap()` rather than `.unwrap()` so that we don't introduce
+ // an `E: Debug` bound.
+ // Also, for an ok type of ! any code after `ok().unwrap()` is unreachable.
+ #[allow(unreachable_code)]
+ unsafe { #futures_crate::core_reexport::pin::Pin::new_unchecked(&mut #fut) }.take_output().unwrap().ok().unwrap(),
+ }
+ });
+
+ TokenStream::from(quote! { {
+ #( #future_let_bindings )*
+
+ #[allow(clippy::diverging_sub_expression)]
+ #futures_crate::future::poll_fn(move |__cx: &mut #futures_crate::task::Context<'_>| {
+ let mut __all_done = true;
+ #( #poll_futures )*
+ if __all_done {
+ #futures_crate::core_reexport::task::Poll::Ready(
+ #futures_crate::core_reexport::result::Result::Ok((
+ #( #take_outputs )*
+ ))
+ )
+ } else {
+ #futures_crate::core_reexport::task::Poll::Pending
+ }
+ }).await
+ } })
+}