diff options
Diffstat (limited to 'src/entry.rs')
-rw-r--r-- | src/entry.rs | 116 |
1 files changed, 91 insertions, 25 deletions
diff --git a/src/entry.rs b/src/entry.rs index 5cb4a49..6460e70 100644 --- a/src/entry.rs +++ b/src/entry.rs @@ -1,5 +1,5 @@ use proc_macro::TokenStream; -use proc_macro2::Span; +use proc_macro2::{Ident, Span}; use quote::{quote, quote_spanned, ToTokens}; use syn::parse::Parser; @@ -29,6 +29,7 @@ struct FinalConfig { flavor: RuntimeFlavor, worker_threads: Option<usize>, start_paused: Option<bool>, + crate_name: Option<String>, } /// Config used in case of the attribute not being able to build a valid config @@ -36,6 +37,7 @@ const DEFAULT_ERROR_CONFIG: FinalConfig = FinalConfig { flavor: RuntimeFlavor::CurrentThread, worker_threads: None, start_paused: None, + crate_name: None, }; struct Configuration { @@ -45,6 +47,7 @@ struct Configuration { worker_threads: Option<(usize, Span)>, start_paused: Option<(bool, Span)>, is_test: bool, + crate_name: Option<String>, } impl Configuration { @@ -59,6 +62,7 @@ impl Configuration { worker_threads: None, start_paused: None, is_test, + crate_name: None, } } @@ -104,6 +108,15 @@ impl Configuration { Ok(()) } + fn set_crate_name(&mut self, name: syn::Lit, span: Span) -> Result<(), syn::Error> { + if self.crate_name.is_some() { + return Err(syn::Error::new(span, "`crate` set multiple times.")); + } + let name_ident = parse_ident(name, span, "crate")?; + self.crate_name = Some(name_ident.to_string()); + Ok(()) + } + fn macro_name(&self) -> &'static str { if self.is_test { "tokio::test" @@ -151,6 +164,7 @@ impl Configuration { }; Ok(FinalConfig { + crate_name: self.crate_name.clone(), flavor, worker_threads, start_paused, @@ -185,6 +199,27 @@ fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::E } } +fn parse_ident(lit: syn::Lit, span: Span, field: &str) -> Result<Ident, syn::Error> { + match lit { + syn::Lit::Str(s) => { + let err = syn::Error::new( + span, + format!( + "Failed to parse value of `{}` as ident: \"{}\"", + field, + s.value() + ), + ); + let path = s.parse::<syn::Path>().map_err(|_| err.clone())?; + path.get_ident().cloned().ok_or(err) + } + _ => Err(syn::Error::new( + span, + format!("Failed to parse value of `{}` as ident.", field), + )), + } +} + fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Error> { match bool { syn::Lit::Bool(b) => Ok(b.value), @@ -243,9 +278,15 @@ fn build_config( let msg = "Attribute `core_threads` is renamed to `worker_threads`"; return Err(syn::Error::new_spanned(namevalue, msg)); } + "crate" => { + config.set_crate_name( + namevalue.lit.clone(), + syn::spanned::Spanned::span(&namevalue.lit), + )?; + } name => { let msg = format!( - "Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`", + "Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`", name, ); return Err(syn::Error::new_spanned(namevalue, msg)); @@ -275,7 +316,7 @@ fn build_config( format!("The `{}` attribute requires an argument.", name) } name => { - format!("Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`", name) + format!("Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`", name) } }; return Err(syn::Error::new_spanned(path, msg)); @@ -313,12 +354,16 @@ fn parse_knobs(mut input: syn::ItemFn, is_test: bool, config: FinalConfig) -> To (start, end) }; + let crate_name = config.crate_name.as_deref().unwrap_or("tokio"); + + let crate_ident = Ident::new(crate_name, last_stmt_start_span); + let mut rt = match config.flavor { RuntimeFlavor::CurrentThread => quote_spanned! {last_stmt_start_span=> - tokio::runtime::Builder::new_current_thread() + #crate_ident::runtime::Builder::new_current_thread() }, RuntimeFlavor::Threaded => quote_spanned! {last_stmt_start_span=> - tokio::runtime::Builder::new_multi_thread() + #crate_ident::runtime::Builder::new_multi_thread() }, }; if let Some(v) = config.worker_threads { @@ -338,29 +383,50 @@ fn parse_knobs(mut input: syn::ItemFn, is_test: bool, config: FinalConfig) -> To let body = &input.block; let brace_token = input.block.brace_token; - let (tail_return, tail_semicolon) = match body.stmts.last() { - Some(syn::Stmt::Semi(syn::Expr::Return(_), _)) => (quote! { return }, quote! { ; }), - Some(syn::Stmt::Semi(..)) | Some(syn::Stmt::Local(..)) | None => { - match &input.sig.output { - syn::ReturnType::Type(_, ty) if matches!(&**ty, syn::Type::Tuple(ty) if ty.elems.is_empty()) => - { - (quote! {}, quote! { ; }) // unit - } - syn::ReturnType::Default => (quote! {}, quote! { ; }), // unit - syn::ReturnType::Type(..) => (quote! {}, quote! {}), // ! or another - } - } - _ => (quote! {}, quote! {}), - }; - input.block = syn::parse2(quote_spanned! {last_stmt_end_span=> + let body_ident = quote! { body }; + let block_expr = quote_spanned! {last_stmt_end_span=> + #[allow(clippy::expect_used, clippy::diverging_sub_expression)] { - let body = async #body; - #[allow(clippy::expect_used)] - #tail_return #rt + return #rt .enable_all() .build() .expect("Failed building the Runtime") - .block_on(body)#tail_semicolon + .block_on(#body_ident); + } + }; + + // For test functions pin the body to the stack and use `Pin<&mut dyn + // Future>` to reduce the amount of `Runtime::block_on` (and related + // functions) copies we generate during compilation due to the generic + // parameter `F` (the future to block on). This could have an impact on + // performance, but because it's only for testing it's unlikely to be very + // large. + // + // We don't do this for the main function as it should only be used once so + // there will be no benefit. + let body = if is_test { + let output_type = match &input.sig.output { + // For functions with no return value syn doesn't print anything, + // but that doesn't work as `Output` for our boxed `Future`, so + // default to `()` (the same type as the function output). + syn::ReturnType::Default => quote! { () }, + syn::ReturnType::Type(_, ret_type) => quote! { #ret_type }, + }; + quote! { + let body = async #body; + #crate_ident::pin!(body); + let body: ::std::pin::Pin<&mut dyn ::std::future::Future<Output = #output_type>> = body; + } + } else { + quote! { + let body = async #body; + } + }; + + input.block = syn::parse2(quote! { + { + #body + #block_expr } }) .expect("Parsing failure"); @@ -414,7 +480,7 @@ pub(crate) fn test(args: TokenStream, item: TokenStream, rt_multi_thread: bool) }; let config = if let Some(attr) = input.attrs.iter().find(|attr| attr.path.is_ident("test")) { let msg = "second test attribute is supplied"; - Err(syn::Error::new_spanned(&attr, msg)) + Err(syn::Error::new_spanned(attr, msg)) } else { AttributeArgs::parse_terminated .parse(args) |