diff options
Diffstat (limited to 'src/entry.rs')
-rw-r--r-- | src/entry.rs | 269 |
1 files changed, 176 insertions, 93 deletions
diff --git a/src/entry.rs b/src/entry.rs index f82a329..01f8ee4 100644 --- a/src/entry.rs +++ b/src/entry.rs @@ -1,7 +1,10 @@ use proc_macro::TokenStream; use proc_macro2::Span; -use quote::quote; -use syn::spanned::Spanned; +use quote::{quote, quote_spanned, ToTokens}; +use syn::parse::Parser; + +// syn::AttributeArgs does not implement syn::Parse +type AttributeArgs = syn::punctuated::Punctuated<syn::NestedMeta, syn::Token![,]>; #[derive(Clone, Copy, PartialEq)] enum RuntimeFlavor { @@ -28,12 +31,20 @@ struct FinalConfig { start_paused: Option<bool>, } +/// Config used in case of the attribute not being able to build a valid config +const DEFAULT_ERROR_CONFIG: FinalConfig = FinalConfig { + flavor: RuntimeFlavor::CurrentThread, + worker_threads: None, + start_paused: None, +}; + struct Configuration { rt_multi_thread_available: bool, default_flavor: RuntimeFlavor, flavor: Option<RuntimeFlavor>, worker_threads: Option<(usize, Span)>, start_paused: Option<(bool, Span)>, + is_test: bool, } impl Configuration { @@ -47,6 +58,7 @@ impl Configuration { flavor: None, worker_threads: None, start_paused: None, + is_test, } } @@ -92,16 +104,25 @@ impl Configuration { Ok(()) } + fn macro_name(&self) -> &'static str { + if self.is_test { + "tokio::test" + } else { + "tokio::main" + } + } + fn build(&self) -> Result<FinalConfig, syn::Error> { let flavor = self.flavor.unwrap_or(self.default_flavor); use RuntimeFlavor::*; let worker_threads = match (flavor, self.worker_threads) { (CurrentThread, Some((_, worker_threads_span))) => { - return Err(syn::Error::new( - worker_threads_span, - "The `worker_threads` option requires the `multi_thread` runtime flavor.", - )) + let msg = format!( + "The `worker_threads` option requires the `multi_thread` runtime flavor. Use `#[{}(flavor = \"multi_thread\")]`", + self.macro_name(), + ); + return Err(syn::Error::new(worker_threads_span, msg)); } (CurrentThread, None) => None, (Threaded, worker_threads) if self.rt_multi_thread_available => { @@ -119,10 +140,11 @@ impl Configuration { let start_paused = match (flavor, self.start_paused) { (Threaded, Some((_, start_paused_span))) => { - return Err(syn::Error::new( - start_paused_span, - "The `start_paused` option requires the `current_thread` runtime flavor.", - )); + let msg = format!( + "The `start_paused` option requires the `current_thread` runtime flavor. Use `#[{}(flavor = \"current_thread\")]`", + self.macro_name(), + ); + return Err(syn::Error::new(start_paused_span, msg)); } (CurrentThread, Some((start_paused, _))) => Some(start_paused), (_, None) => None, @@ -142,12 +164,12 @@ fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error Ok(value) => Ok(value), Err(e) => Err(syn::Error::new( span, - format!("Failed to parse {} as integer: {}", field, e), + format!("Failed to parse value of `{}` as integer: {}", field, e), )), }, _ => Err(syn::Error::new( span, - format!("Failed to parse {} as integer.", field), + format!("Failed to parse value of `{}` as integer.", field), )), } } @@ -158,7 +180,7 @@ fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::E syn::Lit::Verbatim(s) => Ok(s.to_string()), _ => Err(syn::Error::new( span, - format!("Failed to parse {} as string.", field), + format!("Failed to parse value of `{}` as string.", field), )), } } @@ -168,71 +190,74 @@ fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Erro syn::Lit::Bool(b) => Ok(b.value), _ => Err(syn::Error::new( span, - format!("Failed to parse {} as bool.", field), + format!("Failed to parse value of `{}` as bool.", field), )), } } -fn parse_knobs( - mut input: syn::ItemFn, - args: syn::AttributeArgs, +fn build_config( + input: syn::ItemFn, + args: AttributeArgs, is_test: bool, rt_multi_thread: bool, -) -> Result<TokenStream, syn::Error> { - let sig = &mut input.sig; - let body = &input.block; - let attrs = &input.attrs; - let vis = input.vis; - - if sig.asyncness.is_none() { - let msg = "the async keyword is missing from the function declaration"; - return Err(syn::Error::new_spanned(sig.fn_token, msg)); +) -> Result<FinalConfig, syn::Error> { + if input.sig.asyncness.is_none() { + let msg = "the `async` keyword is missing from the function declaration"; + return Err(syn::Error::new_spanned(input.sig.fn_token, msg)); } - sig.asyncness = None; - - let macro_name = if is_test { - "tokio::test" - } else { - "tokio::main" - }; let mut config = Configuration::new(is_test, rt_multi_thread); + let macro_name = config.macro_name(); for arg in args { match arg { syn::NestedMeta::Meta(syn::Meta::NameValue(namevalue)) => { - let ident = namevalue.path.get_ident(); - if ident.is_none() { - let msg = "Must have specified ident"; - return Err(syn::Error::new_spanned(namevalue, msg)); - } - match ident.unwrap().to_string().to_lowercase().as_str() { + let ident = namevalue + .path + .get_ident() + .ok_or_else(|| { + syn::Error::new_spanned(&namevalue, "Must have specified ident") + })? + .to_string() + .to_lowercase(); + match ident.as_str() { "worker_threads" => { - config.set_worker_threads(namevalue.lit.clone(), namevalue.span())?; + config.set_worker_threads( + namevalue.lit.clone(), + syn::spanned::Spanned::span(&namevalue.lit), + )?; } "flavor" => { - config.set_flavor(namevalue.lit.clone(), namevalue.span())?; + config.set_flavor( + namevalue.lit.clone(), + syn::spanned::Spanned::span(&namevalue.lit), + )?; } "start_paused" => { - config.set_start_paused(namevalue.lit.clone(), namevalue.span())?; + config.set_start_paused( + namevalue.lit.clone(), + syn::spanned::Spanned::span(&namevalue.lit), + )?; } "core_threads" => { let msg = "Attribute `core_threads` is renamed to `worker_threads`"; return Err(syn::Error::new_spanned(namevalue, msg)); } name => { - let msg = format!("Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`", name); + let msg = format!( + "Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`", + name, + ); return Err(syn::Error::new_spanned(namevalue, msg)); } } } syn::NestedMeta::Meta(syn::Meta::Path(path)) => { - let ident = path.get_ident(); - if ident.is_none() { - let msg = "Must have specified ident"; - return Err(syn::Error::new_spanned(path, msg)); - } - let name = ident.unwrap().to_string().to_lowercase(); + let name = path + .get_ident() + .ok_or_else(|| syn::Error::new_spanned(&path, "Must have specified ident"))? + .to_string() + .to_lowercase(); let msg = match name.as_str() { "threaded_scheduler" | "multi_thread" => { format!( @@ -264,13 +289,35 @@ fn parse_knobs( } } - let config = config.build()?; + config.build() +} + +fn parse_knobs(mut input: syn::ItemFn, is_test: bool, config: FinalConfig) -> TokenStream { + input.sig.asyncness = None; + + // If type mismatch occurs, the current rustc points to the last statement. + let (last_stmt_start_span, last_stmt_end_span) = { + let mut last_stmt = input + .block + .stmts + .last() + .map(ToTokens::into_token_stream) + .unwrap_or_default() + .into_iter(); + // `Span` on stable Rust has a limitation that only points to the first + // token, not the whole tokens. We can work around this limitation by + // using the first/last span of the tokens like + // `syn::Error::new_spanned` does. + let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span()); + let end = last_stmt.last().map_or(start, |t| t.span()); + (start, end) + }; let mut rt = match config.flavor { - RuntimeFlavor::CurrentThread => quote! { + RuntimeFlavor::CurrentThread => quote_spanned! {last_stmt_start_span=> tokio::runtime::Builder::new_current_thread() }, - RuntimeFlavor::Threaded => quote! { + RuntimeFlavor::Threaded => quote_spanned! {last_stmt_start_span=> tokio::runtime::Builder::new_multi_thread() }, }; @@ -281,65 +328,101 @@ fn parse_knobs( rt = quote! { #rt.start_paused(#v) }; } - let header = { - if is_test { - quote! { - #[::core::prelude::v1::test] - } - } else { - quote! {} + let header = if is_test { + quote! { + #[::core::prelude::v1::test] } + } else { + quote! {} }; - let result = quote! { - #header - #(#attrs)* - #vis #sig { - #rt + let body = &input.block; + let brace_token = input.block.brace_token; + let (tail_return, tail_semicolon) = match body.stmts.last() { + Some(syn::Stmt::Semi(expr, _)) => match expr { + syn::Expr::Return(_) => (quote! { return }, quote! { ; }), + _ => match &input.sig.output { + syn::ReturnType::Type(_, ty) if matches!(&**ty, syn::Type::Tuple(ty) if ty.elems.is_empty()) => + { + (quote! {}, quote! { ; }) // unit + } + syn::ReturnType::Default => (quote! {}, quote! { ; }), // unit + syn::ReturnType::Type(..) => (quote! {}, quote! {}), // ! or another + }, + }, + _ => (quote! {}, quote! {}), + }; + input.block = syn::parse2(quote_spanned! {last_stmt_end_span=> + { + let body = async #body; + #[allow(clippy::expect_used)] + #tail_return #rt .enable_all() .build() - .unwrap() - .block_on(async #body) + .expect("Failed building the Runtime") + .block_on(body)#tail_semicolon } + }) + .expect("Parsing failure"); + input.block.brace_token = brace_token; + + let result = quote! { + #header + #input }; - Ok(result.into()) + result.into() +} + +fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream { + tokens.extend(TokenStream::from(error.into_compile_error())); + tokens } #[cfg(not(test))] // Work around for rust-lang/rust#62127 pub(crate) fn main(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream { - let input = syn::parse_macro_input!(item as syn::ItemFn); - let args = syn::parse_macro_input!(args as syn::AttributeArgs); + // If any of the steps for this macro fail, we still want to expand to an item that is as close + // to the expected output as possible. This helps out IDEs such that completions and other + // related features keep working. + let input: syn::ItemFn = match syn::parse(item.clone()) { + Ok(it) => it, + Err(e) => return token_stream_with_error(item, e), + }; - if input.sig.ident == "main" && !input.sig.inputs.is_empty() { + let config = if input.sig.ident == "main" && !input.sig.inputs.is_empty() { let msg = "the main function cannot accept arguments"; - return syn::Error::new_spanned(&input.sig.ident, msg) - .to_compile_error() - .into(); - } + Err(syn::Error::new_spanned(&input.sig.ident, msg)) + } else { + AttributeArgs::parse_terminated + .parse(args) + .and_then(|args| build_config(input.clone(), args, false, rt_multi_thread)) + }; - parse_knobs(input, args, false, rt_multi_thread).unwrap_or_else(|e| e.to_compile_error().into()) + match config { + Ok(config) => parse_knobs(input, false, config), + Err(e) => token_stream_with_error(parse_knobs(input, false, DEFAULT_ERROR_CONFIG), e), + } } pub(crate) fn test(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream { - let input = syn::parse_macro_input!(item as syn::ItemFn); - let args = syn::parse_macro_input!(args as syn::AttributeArgs); - - for attr in &input.attrs { - if attr.path.is_ident("test") { - let msg = "second test attribute is supplied"; - return syn::Error::new_spanned(&attr, msg) - .to_compile_error() - .into(); - } - } + // If any of the steps for this macro fail, we still want to expand to an item that is as close + // to the expected output as possible. This helps out IDEs such that completions and other + // related features keep working. + let input: syn::ItemFn = match syn::parse(item.clone()) { + Ok(it) => it, + Err(e) => return token_stream_with_error(item, e), + }; + let config = if let Some(attr) = input.attrs.iter().find(|attr| attr.path.is_ident("test")) { + let msg = "second test attribute is supplied"; + Err(syn::Error::new_spanned(&attr, msg)) + } else { + AttributeArgs::parse_terminated + .parse(args) + .and_then(|args| build_config(input.clone(), args, true, rt_multi_thread)) + }; - if !input.sig.inputs.is_empty() { - let msg = "the test function cannot accept arguments"; - return syn::Error::new_spanned(&input.sig.inputs, msg) - .to_compile_error() - .into(); + match config { + Ok(config) => parse_knobs(input, true, config), + Err(e) => token_stream_with_error(parse_knobs(input, true, DEFAULT_ERROR_CONFIG), e), } - - parse_knobs(input, args, true, rt_multi_thread).unwrap_or_else(|e| e.to_compile_error().into()) } |