aboutsummaryrefslogtreecommitdiff
path: root/src/entry.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/entry.rs')
-rw-r--r--src/entry.rs116
1 files changed, 91 insertions, 25 deletions
diff --git a/src/entry.rs b/src/entry.rs
index 5cb4a49..6460e70 100644
--- a/src/entry.rs
+++ b/src/entry.rs
@@ -1,5 +1,5 @@
use proc_macro::TokenStream;
-use proc_macro2::Span;
+use proc_macro2::{Ident, Span};
use quote::{quote, quote_spanned, ToTokens};
use syn::parse::Parser;
@@ -29,6 +29,7 @@ struct FinalConfig {
flavor: RuntimeFlavor,
worker_threads: Option<usize>,
start_paused: Option<bool>,
+ crate_name: Option<String>,
}
/// Config used in case of the attribute not being able to build a valid config
@@ -36,6 +37,7 @@ const DEFAULT_ERROR_CONFIG: FinalConfig = FinalConfig {
flavor: RuntimeFlavor::CurrentThread,
worker_threads: None,
start_paused: None,
+ crate_name: None,
};
struct Configuration {
@@ -45,6 +47,7 @@ struct Configuration {
worker_threads: Option<(usize, Span)>,
start_paused: Option<(bool, Span)>,
is_test: bool,
+ crate_name: Option<String>,
}
impl Configuration {
@@ -59,6 +62,7 @@ impl Configuration {
worker_threads: None,
start_paused: None,
is_test,
+ crate_name: None,
}
}
@@ -104,6 +108,15 @@ impl Configuration {
Ok(())
}
+ fn set_crate_name(&mut self, name: syn::Lit, span: Span) -> Result<(), syn::Error> {
+ if self.crate_name.is_some() {
+ return Err(syn::Error::new(span, "`crate` set multiple times."));
+ }
+ let name_ident = parse_ident(name, span, "crate")?;
+ self.crate_name = Some(name_ident.to_string());
+ Ok(())
+ }
+
fn macro_name(&self) -> &'static str {
if self.is_test {
"tokio::test"
@@ -151,6 +164,7 @@ impl Configuration {
};
Ok(FinalConfig {
+ crate_name: self.crate_name.clone(),
flavor,
worker_threads,
start_paused,
@@ -185,6 +199,27 @@ fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::E
}
}
+fn parse_ident(lit: syn::Lit, span: Span, field: &str) -> Result<Ident, syn::Error> {
+ match lit {
+ syn::Lit::Str(s) => {
+ let err = syn::Error::new(
+ span,
+ format!(
+ "Failed to parse value of `{}` as ident: \"{}\"",
+ field,
+ s.value()
+ ),
+ );
+ let path = s.parse::<syn::Path>().map_err(|_| err.clone())?;
+ path.get_ident().cloned().ok_or(err)
+ }
+ _ => Err(syn::Error::new(
+ span,
+ format!("Failed to parse value of `{}` as ident.", field),
+ )),
+ }
+}
+
fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Error> {
match bool {
syn::Lit::Bool(b) => Ok(b.value),
@@ -243,9 +278,15 @@ fn build_config(
let msg = "Attribute `core_threads` is renamed to `worker_threads`";
return Err(syn::Error::new_spanned(namevalue, msg));
}
+ "crate" => {
+ config.set_crate_name(
+ namevalue.lit.clone(),
+ syn::spanned::Spanned::span(&namevalue.lit),
+ )?;
+ }
name => {
let msg = format!(
- "Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`",
+ "Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`",
name,
);
return Err(syn::Error::new_spanned(namevalue, msg));
@@ -275,7 +316,7 @@ fn build_config(
format!("The `{}` attribute requires an argument.", name)
}
name => {
- format!("Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`", name)
+ format!("Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`", name)
}
};
return Err(syn::Error::new_spanned(path, msg));
@@ -313,12 +354,16 @@ fn parse_knobs(mut input: syn::ItemFn, is_test: bool, config: FinalConfig) -> To
(start, end)
};
+ let crate_name = config.crate_name.as_deref().unwrap_or("tokio");
+
+ let crate_ident = Ident::new(crate_name, last_stmt_start_span);
+
let mut rt = match config.flavor {
RuntimeFlavor::CurrentThread => quote_spanned! {last_stmt_start_span=>
- tokio::runtime::Builder::new_current_thread()
+ #crate_ident::runtime::Builder::new_current_thread()
},
RuntimeFlavor::Threaded => quote_spanned! {last_stmt_start_span=>
- tokio::runtime::Builder::new_multi_thread()
+ #crate_ident::runtime::Builder::new_multi_thread()
},
};
if let Some(v) = config.worker_threads {
@@ -338,29 +383,50 @@ fn parse_knobs(mut input: syn::ItemFn, is_test: bool, config: FinalConfig) -> To
let body = &input.block;
let brace_token = input.block.brace_token;
- let (tail_return, tail_semicolon) = match body.stmts.last() {
- Some(syn::Stmt::Semi(syn::Expr::Return(_), _)) => (quote! { return }, quote! { ; }),
- Some(syn::Stmt::Semi(..)) | Some(syn::Stmt::Local(..)) | None => {
- match &input.sig.output {
- syn::ReturnType::Type(_, ty) if matches!(&**ty, syn::Type::Tuple(ty) if ty.elems.is_empty()) =>
- {
- (quote! {}, quote! { ; }) // unit
- }
- syn::ReturnType::Default => (quote! {}, quote! { ; }), // unit
- syn::ReturnType::Type(..) => (quote! {}, quote! {}), // ! or another
- }
- }
- _ => (quote! {}, quote! {}),
- };
- input.block = syn::parse2(quote_spanned! {last_stmt_end_span=>
+ let body_ident = quote! { body };
+ let block_expr = quote_spanned! {last_stmt_end_span=>
+ #[allow(clippy::expect_used, clippy::diverging_sub_expression)]
{
- let body = async #body;
- #[allow(clippy::expect_used)]
- #tail_return #rt
+ return #rt
.enable_all()
.build()
.expect("Failed building the Runtime")
- .block_on(body)#tail_semicolon
+ .block_on(#body_ident);
+ }
+ };
+
+ // For test functions pin the body to the stack and use `Pin<&mut dyn
+ // Future>` to reduce the amount of `Runtime::block_on` (and related
+ // functions) copies we generate during compilation due to the generic
+ // parameter `F` (the future to block on). This could have an impact on
+ // performance, but because it's only for testing it's unlikely to be very
+ // large.
+ //
+ // We don't do this for the main function as it should only be used once so
+ // there will be no benefit.
+ let body = if is_test {
+ let output_type = match &input.sig.output {
+ // For functions with no return value syn doesn't print anything,
+ // but that doesn't work as `Output` for our boxed `Future`, so
+ // default to `()` (the same type as the function output).
+ syn::ReturnType::Default => quote! { () },
+ syn::ReturnType::Type(_, ret_type) => quote! { #ret_type },
+ };
+ quote! {
+ let body = async #body;
+ #crate_ident::pin!(body);
+ let body: ::std::pin::Pin<&mut dyn ::std::future::Future<Output = #output_type>> = body;
+ }
+ } else {
+ quote! {
+ let body = async #body;
+ }
+ };
+
+ input.block = syn::parse2(quote! {
+ {
+ #body
+ #block_expr
}
})
.expect("Parsing failure");
@@ -414,7 +480,7 @@ pub(crate) fn test(args: TokenStream, item: TokenStream, rt_multi_thread: bool)
};
let config = if let Some(attr) = input.attrs.iter().find(|attr| attr.path.is_ident("test")) {
let msg = "second test attribute is supplied";
- Err(syn::Error::new_spanned(&attr, msg))
+ Err(syn::Error::new_spanned(attr, msg))
} else {
AttributeArgs::parse_terminated
.parse(args)