diff options
Diffstat (limited to 'src/expand.rs')
-rw-r--r-- | src/expand.rs | 162 |
1 files changed, 129 insertions, 33 deletions
diff --git a/src/expand.rs b/src/expand.rs index 789eee6..435ad48 100644 --- a/src/expand.rs +++ b/src/expand.rs @@ -1,8 +1,13 @@ use crate::ast::{Enum, Field, Input, Struct}; +use crate::attr::Trait; +use crate::generics::InferredBounds; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned, ToTokens}; +use std::collections::BTreeSet as Set; use syn::spanned::Spanned; -use syn::{Data, DeriveInput, Member, PathArguments, Result, Type, Visibility}; +use syn::{ + Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Token, Type, Visibility, +}; pub fn derive(node: &DeriveInput) -> Result<TokenStream> { let input = Input::from_syn(node)?; @@ -16,14 +21,23 @@ pub fn derive(node: &DeriveInput) -> Result<TokenStream> { fn impl_struct(input: Struct) -> TokenStream { let ty = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let mut error_inferred_bounds = InferredBounds::new(); let source_body = if input.attrs.transparent.is_some() { - let only_field = &input.fields[0].member; + let only_field = &input.fields[0]; + if only_field.contains_generic { + error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error)); + } + let member = &only_field.member; Some(quote! { - std::error::Error::source(self.#only_field.as_dyn_error()) + std::error::Error::source(self.#member.as_dyn_error()) }) } else if let Some(source_field) = input.source_field() { let source = &source_field.member; + if source_field.contains_generic { + let ty = unoptional_type(source_field.ty); + error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static)); + } let asref = if type_is_option(source_field.ty) { Some(quote_spanned!(source.span()=> .as_ref()?)) } else { @@ -58,7 +72,9 @@ fn impl_struct(input: Struct) -> TokenStream { self.#source.as_dyn_error().backtrace() } }; - let combinator = if type_is_option(backtrace_field.ty) { + let combinator = if source == backtrace { + source_backtrace + } else if type_is_option(backtrace_field.ty) { quote! { #source_backtrace.or(self.#backtrace.as_ref()) } @@ -87,12 +103,15 @@ fn impl_struct(input: Struct) -> TokenStream { } }); + let mut display_implied_bounds = Set::new(); let display_body = if input.attrs.transparent.is_some() { let only_field = &input.fields[0].member; + display_implied_bounds.insert((0, Trait::Display)); Some(quote! { std::fmt::Display::fmt(&self.#only_field, __formatter) }) } else if let Some(display) = &input.attrs.display { + display_implied_bounds = display.implied_bounds.clone(); let use_as_display = if display.has_bonus_display { Some(quote! { #[allow(unused_imports)] @@ -112,14 +131,18 @@ fn impl_struct(input: Struct) -> TokenStream { None }; let display_impl = display_body.map(|body| { + let mut display_inferred_bounds = InferredBounds::new(); + for (field, bound) in display_implied_bounds { + let field = &input.fields[field]; + if field.contains_generic { + display_inferred_bounds.insert(field.ty, bound); + } + } + let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics); quote! { #[allow(unused_qualifications)] - impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause { - #[allow( - // Clippy bug: https://github.com/rust-lang/rust-clippy/issues/7422 - clippy::nonstandard_macro_braces, - clippy::used_underscore_binding, - )] + impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause { + #[allow(clippy::used_underscore_binding)] fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result { #body } @@ -128,8 +151,8 @@ fn impl_struct(input: Struct) -> TokenStream { }); let from_impl = input.from_field().map(|from_field| { - let backtrace_field = input.backtrace_field(); - let from = from_field.ty; + let backtrace_field = input.distinct_backtrace_field(); + let from = unoptional_type(from_field.ty); let body = from_initializer(from_field, backtrace_field); quote! { #[allow(unused_qualifications)] @@ -143,10 +166,16 @@ fn impl_struct(input: Struct) -> TokenStream { }); let error_trait = spanned_error_trait(input.original); + if input.generics.type_params().next().is_some() { + let self_token = <Token![Self]>::default(); + error_inferred_bounds.insert(self_token, Trait::Debug); + error_inferred_bounds.insert(self_token, Trait::Display); + } + let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics); quote! { #[allow(unused_qualifications)] - impl #impl_generics #error_trait for #ty #ty_generics #where_clause { + impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause { #source_method #backtrace_method } @@ -158,18 +187,27 @@ fn impl_struct(input: Struct) -> TokenStream { fn impl_enum(input: Enum) -> TokenStream { let ty = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let mut error_inferred_bounds = InferredBounds::new(); let source_method = if input.has_source() { let arms = input.variants.iter().map(|variant| { let ident = &variant.ident; if variant.attrs.transparent.is_some() { - let only_field = &variant.fields[0].member; + let only_field = &variant.fields[0]; + if only_field.contains_generic { + error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error)); + } + let member = &only_field.member; let source = quote!(std::error::Error::source(transparent.as_dyn_error())); quote! { - #ty::#ident {#only_field: transparent} => #source, + #ty::#ident {#member: transparent} => #source, } } else if let Some(source_field) = variant.source_field() { let source = &source_field.member; + if source_field.contains_generic { + let ty = unoptional_type(source_field.ty); + error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static)); + } let asref = if type_is_option(source_field.ty) { Some(quote_spanned!(source.span()=> .as_ref()?)) } else { @@ -238,6 +276,27 @@ fn impl_enum(input: Enum) -> TokenStream { } } } + (Some(backtrace_field), Some(source_field)) + if backtrace_field.member == source_field.member => + { + let backtrace = &backtrace_field.member; + let varsource = quote!(source); + let source_backtrace = if type_is_option(source_field.ty) { + quote_spanned! {backtrace.span()=> + #varsource.as_ref().and_then(|source| source.as_dyn_error().backtrace()) + } + } else { + quote_spanned! {backtrace.span()=> + #varsource.as_dyn_error().backtrace() + } + }; + quote! { + #ty::#ident {#backtrace: #varsource, ..} => { + use thiserror::private::AsDynError; + #source_backtrace + } + } + } (Some(backtrace_field), _) => { let backtrace = &backtrace_field.member; let body = if type_is_option(backtrace_field.ty) { @@ -267,6 +326,7 @@ fn impl_enum(input: Enum) -> TokenStream { }; let display_impl = if input.has_display() { + let mut display_inferred_bounds = InferredBounds::new(); let use_as_display = if input.variants.iter().any(|v| { v.attrs .display @@ -286,34 +346,41 @@ fn impl_enum(input: Enum) -> TokenStream { None }; let arms = input.variants.iter().map(|variant| { + let mut display_implied_bounds = Set::new(); let display = match &variant.attrs.display { - Some(display) => display.to_token_stream(), + Some(display) => { + display_implied_bounds = display.implied_bounds.clone(); + display.to_token_stream() + } None => { let only_field = match &variant.fields[0].member { Member::Named(ident) => ident.clone(), Member::Unnamed(index) => format_ident!("_{}", index), }; + display_implied_bounds.insert((0, Trait::Display)); quote!(std::fmt::Display::fmt(#only_field, __formatter)) } }; + for (field, bound) in display_implied_bounds { + let field = &variant.fields[field]; + if field.contains_generic { + display_inferred_bounds.insert(field.ty, bound); + } + } let ident = &variant.ident; let pat = fields_pat(&variant.fields); quote! { #ty::#ident #pat => #display } }); + let arms = arms.collect::<Vec<_>>(); + let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics); Some(quote! { #[allow(unused_qualifications)] - impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause { + impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause { fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result { #use_as_display - #[allow( - unused_variables, - deprecated, - // Clippy bug: https://github.com/rust-lang/rust-clippy/issues/7422 - clippy::nonstandard_macro_braces, - clippy::used_underscore_binding, - )] + #[allow(unused_variables, deprecated, clippy::used_underscore_binding)] match #void_deref self { #(#arms,)* } @@ -326,9 +393,9 @@ fn impl_enum(input: Enum) -> TokenStream { let from_impls = input.variants.iter().filter_map(|variant| { let from_field = variant.from_field()?; - let backtrace_field = variant.backtrace_field(); + let backtrace_field = variant.distinct_backtrace_field(); let variant = &variant.ident; - let from = from_field.ty; + let from = unoptional_type(from_field.ty); let body = from_initializer(from_field, backtrace_field); Some(quote! { #[allow(unused_qualifications)] @@ -342,10 +409,16 @@ fn impl_enum(input: Enum) -> TokenStream { }); let error_trait = spanned_error_trait(input.original); + if input.generics.type_params().next().is_some() { + let self_token = <Token![Self]>::default(); + error_inferred_bounds.insert(self_token, Trait::Debug); + error_inferred_bounds.insert(self_token, Trait::Display); + } + let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics); quote! { #[allow(unused_qualifications)] - impl #impl_generics #error_trait for #ty #ty_generics #where_clause { + impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause { #source_method #backtrace_method } @@ -371,6 +444,11 @@ fn fields_pat(fields: &[Field]) -> TokenStream { fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> TokenStream { let from_member = &from_field.member; + let some_source = if type_is_option(from_field.ty) { + quote!(std::option::Option::Some(source)) + } else { + quote!(source) + }; let backtrace = backtrace_field.map(|backtrace_field| { let backtrace_member = &backtrace_field.member; if type_is_option(backtrace_field.ty) { @@ -384,25 +462,43 @@ fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> Toke } }); quote!({ - #from_member: source, + #from_member: #some_source, #backtrace }) } fn type_is_option(ty: &Type) -> bool { + type_parameter_of_option(ty).is_some() +} + +fn unoptional_type(ty: &Type) -> TokenStream { + let unoptional = type_parameter_of_option(ty).unwrap_or(ty); + quote!(#unoptional) +} + +fn type_parameter_of_option(ty: &Type) -> Option<&Type> { let path = match ty { Type::Path(ty) => &ty.path, - _ => return false, + _ => return None, }; let last = path.segments.last().unwrap(); if last.ident != "Option" { - return false; + return None; + } + + let bracketed = match &last.arguments { + PathArguments::AngleBracketed(bracketed) => bracketed, + _ => return None, + }; + + if bracketed.args.len() != 1 { + return None; } - match &last.arguments { - PathArguments::AngleBracketed(bracketed) => bracketed.args.len() == 1, - _ => false, + match &bracketed.args[0] { + GenericArgument::Type(arg) => Some(arg), + _ => None, } } |