summaryrefslogtreecommitdiff
path: root/src/expand.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/expand.rs')
-rw-r--r--src/expand.rs162
1 files changed, 129 insertions, 33 deletions
diff --git a/src/expand.rs b/src/expand.rs
index 789eee6..435ad48 100644
--- a/src/expand.rs
+++ b/src/expand.rs
@@ -1,8 +1,13 @@
use crate::ast::{Enum, Field, Input, Struct};
+use crate::attr::Trait;
+use crate::generics::InferredBounds;
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
+use std::collections::BTreeSet as Set;
use syn::spanned::Spanned;
-use syn::{Data, DeriveInput, Member, PathArguments, Result, Type, Visibility};
+use syn::{
+ Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Token, Type, Visibility,
+};
pub fn derive(node: &DeriveInput) -> Result<TokenStream> {
let input = Input::from_syn(node)?;
@@ -16,14 +21,23 @@ pub fn derive(node: &DeriveInput) -> Result<TokenStream> {
fn impl_struct(input: Struct) -> TokenStream {
let ty = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
+ let mut error_inferred_bounds = InferredBounds::new();
let source_body = if input.attrs.transparent.is_some() {
- let only_field = &input.fields[0].member;
+ let only_field = &input.fields[0];
+ if only_field.contains_generic {
+ error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
+ }
+ let member = &only_field.member;
Some(quote! {
- std::error::Error::source(self.#only_field.as_dyn_error())
+ std::error::Error::source(self.#member.as_dyn_error())
})
} else if let Some(source_field) = input.source_field() {
let source = &source_field.member;
+ if source_field.contains_generic {
+ let ty = unoptional_type(source_field.ty);
+ error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
+ }
let asref = if type_is_option(source_field.ty) {
Some(quote_spanned!(source.span()=> .as_ref()?))
} else {
@@ -58,7 +72,9 @@ fn impl_struct(input: Struct) -> TokenStream {
self.#source.as_dyn_error().backtrace()
}
};
- let combinator = if type_is_option(backtrace_field.ty) {
+ let combinator = if source == backtrace {
+ source_backtrace
+ } else if type_is_option(backtrace_field.ty) {
quote! {
#source_backtrace.or(self.#backtrace.as_ref())
}
@@ -87,12 +103,15 @@ fn impl_struct(input: Struct) -> TokenStream {
}
});
+ let mut display_implied_bounds = Set::new();
let display_body = if input.attrs.transparent.is_some() {
let only_field = &input.fields[0].member;
+ display_implied_bounds.insert((0, Trait::Display));
Some(quote! {
std::fmt::Display::fmt(&self.#only_field, __formatter)
})
} else if let Some(display) = &input.attrs.display {
+ display_implied_bounds = display.implied_bounds.clone();
let use_as_display = if display.has_bonus_display {
Some(quote! {
#[allow(unused_imports)]
@@ -112,14 +131,18 @@ fn impl_struct(input: Struct) -> TokenStream {
None
};
let display_impl = display_body.map(|body| {
+ let mut display_inferred_bounds = InferredBounds::new();
+ for (field, bound) in display_implied_bounds {
+ let field = &input.fields[field];
+ if field.contains_generic {
+ display_inferred_bounds.insert(field.ty, bound);
+ }
+ }
+ let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
quote! {
#[allow(unused_qualifications)]
- impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause {
- #[allow(
- // Clippy bug: https://github.com/rust-lang/rust-clippy/issues/7422
- clippy::nonstandard_macro_braces,
- clippy::used_underscore_binding,
- )]
+ impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause {
+ #[allow(clippy::used_underscore_binding)]
fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
#body
}
@@ -128,8 +151,8 @@ fn impl_struct(input: Struct) -> TokenStream {
});
let from_impl = input.from_field().map(|from_field| {
- let backtrace_field = input.backtrace_field();
- let from = from_field.ty;
+ let backtrace_field = input.distinct_backtrace_field();
+ let from = unoptional_type(from_field.ty);
let body = from_initializer(from_field, backtrace_field);
quote! {
#[allow(unused_qualifications)]
@@ -143,10 +166,16 @@ fn impl_struct(input: Struct) -> TokenStream {
});
let error_trait = spanned_error_trait(input.original);
+ if input.generics.type_params().next().is_some() {
+ let self_token = <Token![Self]>::default();
+ error_inferred_bounds.insert(self_token, Trait::Debug);
+ error_inferred_bounds.insert(self_token, Trait::Display);
+ }
+ let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);
quote! {
#[allow(unused_qualifications)]
- impl #impl_generics #error_trait for #ty #ty_generics #where_clause {
+ impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause {
#source_method
#backtrace_method
}
@@ -158,18 +187,27 @@ fn impl_struct(input: Struct) -> TokenStream {
fn impl_enum(input: Enum) -> TokenStream {
let ty = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
+ let mut error_inferred_bounds = InferredBounds::new();
let source_method = if input.has_source() {
let arms = input.variants.iter().map(|variant| {
let ident = &variant.ident;
if variant.attrs.transparent.is_some() {
- let only_field = &variant.fields[0].member;
+ let only_field = &variant.fields[0];
+ if only_field.contains_generic {
+ error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
+ }
+ let member = &only_field.member;
let source = quote!(std::error::Error::source(transparent.as_dyn_error()));
quote! {
- #ty::#ident {#only_field: transparent} => #source,
+ #ty::#ident {#member: transparent} => #source,
}
} else if let Some(source_field) = variant.source_field() {
let source = &source_field.member;
+ if source_field.contains_generic {
+ let ty = unoptional_type(source_field.ty);
+ error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
+ }
let asref = if type_is_option(source_field.ty) {
Some(quote_spanned!(source.span()=> .as_ref()?))
} else {
@@ -238,6 +276,27 @@ fn impl_enum(input: Enum) -> TokenStream {
}
}
}
+ (Some(backtrace_field), Some(source_field))
+ if backtrace_field.member == source_field.member =>
+ {
+ let backtrace = &backtrace_field.member;
+ let varsource = quote!(source);
+ let source_backtrace = if type_is_option(source_field.ty) {
+ quote_spanned! {backtrace.span()=>
+ #varsource.as_ref().and_then(|source| source.as_dyn_error().backtrace())
+ }
+ } else {
+ quote_spanned! {backtrace.span()=>
+ #varsource.as_dyn_error().backtrace()
+ }
+ };
+ quote! {
+ #ty::#ident {#backtrace: #varsource, ..} => {
+ use thiserror::private::AsDynError;
+ #source_backtrace
+ }
+ }
+ }
(Some(backtrace_field), _) => {
let backtrace = &backtrace_field.member;
let body = if type_is_option(backtrace_field.ty) {
@@ -267,6 +326,7 @@ fn impl_enum(input: Enum) -> TokenStream {
};
let display_impl = if input.has_display() {
+ let mut display_inferred_bounds = InferredBounds::new();
let use_as_display = if input.variants.iter().any(|v| {
v.attrs
.display
@@ -286,34 +346,41 @@ fn impl_enum(input: Enum) -> TokenStream {
None
};
let arms = input.variants.iter().map(|variant| {
+ let mut display_implied_bounds = Set::new();
let display = match &variant.attrs.display {
- Some(display) => display.to_token_stream(),
+ Some(display) => {
+ display_implied_bounds = display.implied_bounds.clone();
+ display.to_token_stream()
+ }
None => {
let only_field = match &variant.fields[0].member {
Member::Named(ident) => ident.clone(),
Member::Unnamed(index) => format_ident!("_{}", index),
};
+ display_implied_bounds.insert((0, Trait::Display));
quote!(std::fmt::Display::fmt(#only_field, __formatter))
}
};
+ for (field, bound) in display_implied_bounds {
+ let field = &variant.fields[field];
+ if field.contains_generic {
+ display_inferred_bounds.insert(field.ty, bound);
+ }
+ }
let ident = &variant.ident;
let pat = fields_pat(&variant.fields);
quote! {
#ty::#ident #pat => #display
}
});
+ let arms = arms.collect::<Vec<_>>();
+ let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
Some(quote! {
#[allow(unused_qualifications)]
- impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause {
+ impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause {
fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
#use_as_display
- #[allow(
- unused_variables,
- deprecated,
- // Clippy bug: https://github.com/rust-lang/rust-clippy/issues/7422
- clippy::nonstandard_macro_braces,
- clippy::used_underscore_binding,
- )]
+ #[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
match #void_deref self {
#(#arms,)*
}
@@ -326,9 +393,9 @@ fn impl_enum(input: Enum) -> TokenStream {
let from_impls = input.variants.iter().filter_map(|variant| {
let from_field = variant.from_field()?;
- let backtrace_field = variant.backtrace_field();
+ let backtrace_field = variant.distinct_backtrace_field();
let variant = &variant.ident;
- let from = from_field.ty;
+ let from = unoptional_type(from_field.ty);
let body = from_initializer(from_field, backtrace_field);
Some(quote! {
#[allow(unused_qualifications)]
@@ -342,10 +409,16 @@ fn impl_enum(input: Enum) -> TokenStream {
});
let error_trait = spanned_error_trait(input.original);
+ if input.generics.type_params().next().is_some() {
+ let self_token = <Token![Self]>::default();
+ error_inferred_bounds.insert(self_token, Trait::Debug);
+ error_inferred_bounds.insert(self_token, Trait::Display);
+ }
+ let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);
quote! {
#[allow(unused_qualifications)]
- impl #impl_generics #error_trait for #ty #ty_generics #where_clause {
+ impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause {
#source_method
#backtrace_method
}
@@ -371,6 +444,11 @@ fn fields_pat(fields: &[Field]) -> TokenStream {
fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> TokenStream {
let from_member = &from_field.member;
+ let some_source = if type_is_option(from_field.ty) {
+ quote!(std::option::Option::Some(source))
+ } else {
+ quote!(source)
+ };
let backtrace = backtrace_field.map(|backtrace_field| {
let backtrace_member = &backtrace_field.member;
if type_is_option(backtrace_field.ty) {
@@ -384,25 +462,43 @@ fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> Toke
}
});
quote!({
- #from_member: source,
+ #from_member: #some_source,
#backtrace
})
}
fn type_is_option(ty: &Type) -> bool {
+ type_parameter_of_option(ty).is_some()
+}
+
+fn unoptional_type(ty: &Type) -> TokenStream {
+ let unoptional = type_parameter_of_option(ty).unwrap_or(ty);
+ quote!(#unoptional)
+}
+
+fn type_parameter_of_option(ty: &Type) -> Option<&Type> {
let path = match ty {
Type::Path(ty) => &ty.path,
- _ => return false,
+ _ => return None,
};
let last = path.segments.last().unwrap();
if last.ident != "Option" {
- return false;
+ return None;
+ }
+
+ let bracketed = match &last.arguments {
+ PathArguments::AngleBracketed(bracketed) => bracketed,
+ _ => return None,
+ };
+
+ if bracketed.args.len() != 1 {
+ return None;
}
- match &last.arguments {
- PathArguments::AngleBracketed(bracketed) => bracketed.args.len() == 1,
- _ => false,
+ match &bracketed.args[0] {
+ GenericArgument::Type(arg) => Some(arg),
+ _ => None,
}
}