diff options
Diffstat (limited to 'src/lib.rs')
-rw-r--r-- | src/lib.rs | 373 |
1 files changed, 187 insertions, 186 deletions
@@ -2,26 +2,6 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -//! Derive macros for [zerocopy]'s traits. -//! -//! [zerocopy]: https://docs.rs/zerocopy - -// Sometimes we want to use lints which were added after our MSRV. -// `unknown_lints` is `warn` by default and we deny warnings in CI, so without -// this attribute, any unknown lint would cause a CI failure when testing with -// our MSRV. -#![allow(unknown_lints)] -#![deny(renamed_and_removed_lints)] -#![deny(clippy::all, clippy::missing_safety_doc, clippy::undocumented_unsafe_blocks)] -#![deny( - rustdoc::bare_urls, - rustdoc::broken_intra_doc_links, - rustdoc::invalid_codeblock_attributes, - rustdoc::invalid_html_tags, - rustdoc::invalid_rust_codeblocks, - rustdoc::missing_crate_level_docs, - rustdoc::private_intra_doc_links -)] #![recursion_limit = "128"] mod ext; @@ -30,9 +10,10 @@ mod repr; use { proc_macro2::Span, quote::quote, + syn::visit::{self, Visit}, syn::{ - parse_quote, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error, Expr, ExprLit, - GenericParam, Ident, Lit, + parse_quote, punctuated::Punctuated, token::Comma, Data, DataEnum, DataStruct, DataUnion, + DeriveInput, Error, GenericParam, Ident, Lifetime, Type, TypePath, }, }; @@ -49,23 +30,12 @@ use {crate::ext::*, crate::repr::*}; // help: required by the derive of FromBytes // // Instead, we have more verbose error messages like "unsupported representation -// for deriving FromZeroes, FromBytes, AsBytes, or Unaligned on an enum" +// for deriving FromBytes, AsBytes, or Unaligned on an enum" // // This will probably require Span::error // (https://doc.rust-lang.org/nightly/proc_macro/struct.Span.html#method.error), // which is currently unstable. Revisit this once it's stable. -#[proc_macro_derive(FromZeroes)] -pub fn derive_from_zeroes(ts: proc_macro::TokenStream) -> proc_macro::TokenStream { - let ast = syn::parse_macro_input!(ts as DeriveInput); - match &ast.data { - Data::Struct(strct) => derive_from_zeroes_struct(&ast, strct), - Data::Enum(enm) => derive_from_zeroes_enum(&ast, enm), - Data::Union(unn) => derive_from_zeroes_union(&ast, unn), - } - .into() -} - #[proc_macro_derive(FromBytes)] pub fn derive_from_bytes(ts: proc_macro::TokenStream) -> proc_macro::TokenStream { let ast = syn::parse_macro_input!(ts as DeriveInput); @@ -117,59 +87,11 @@ const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[ &[StructRepr::C, StructRepr::Packed], ]; -// A struct is `FromZeroes` if: -// - all fields are `FromZeroes` - -fn derive_from_zeroes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { - impl_block(ast, strct, "FromZeroes", true, None) -} - -// An enum is `FromZeroes` if: -// - all of its variants are fieldless -// - one of the variants has a discriminant of `0` - -fn derive_from_zeroes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::TokenStream { - if !enm.is_c_like() { - return Error::new_spanned(ast, "only C-like enums can implement FromZeroes") - .to_compile_error(); - } - - let has_explicit_zero_discriminant = - enm.variants.iter().filter_map(|v| v.discriminant.as_ref()).any(|(_, e)| { - if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = e { - i.base10_parse::<usize>().ok() == Some(0) - } else { - false - } - }); - // If the first variant of an enum does not specify its discriminant, it is set to zero: - // https://doc.rust-lang.org/reference/items/enumerations.html#custom-discriminant-values-for-fieldless-enumerations - let has_implicit_zero_discriminant = - enm.variants.iter().next().map(|v| v.discriminant.is_none()) == Some(true); - - if !has_explicit_zero_discriminant && !has_implicit_zero_discriminant { - return Error::new_spanned( - ast, - "FromZeroes only supported on enums with a variant that has a discriminant of `0`", - ) - .to_compile_error(); - } - - impl_block(ast, enm, "FromZeroes", true, None) -} - -// Like structs, unions are `FromZeroes` if -// - all fields are `FromZeroes` - -fn derive_from_zeroes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { - impl_block(ast, unn, "FromZeroes", true, None) -} - // A struct is `FromBytes` if: // - all fields are `FromBytes` fn derive_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { - impl_block(ast, strct, "FromBytes", true, None) + impl_block(ast, strct, "FromBytes", true, PaddingCheck::None) } // An enum is `FromBytes` if: @@ -212,7 +134,7 @@ fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Tok .to_compile_error(); } - impl_block(ast, enm, "FromBytes", true, None) + impl_block(ast, enm, "FromBytes", true, PaddingCheck::None) } #[rustfmt::skip] @@ -243,7 +165,7 @@ const ENUM_FROM_BYTES_CFG: Config<EnumRepr> = { // - all fields are `FromBytes` fn derive_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { - impl_block(ast, unn, "FromBytes", true, None) + impl_block(ast, unn, "FromBytes", true, PaddingCheck::None) } // A struct is `AsBytes` if: @@ -253,30 +175,16 @@ fn derive_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::T // - `repr(packed)` fn derive_as_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { - let reprs = try_or_print!(STRUCT_UNION_AS_BYTES_CFG.validate_reprs(ast)); - let is_transparent = reprs.contains(&StructRepr::Transparent); - let is_packed = reprs.contains(&StructRepr::Packed); - - // TODO(#10): Support type parameters for non-transparent, non-packed - // structs. - if !ast.generics.params.is_empty() && !is_transparent && !is_packed { - return Error::new( - Span::call_site(), - "unsupported on generic structs that are not repr(transparent) or repr(packed)", - ) - .to_compile_error(); + // TODO(#10): Support type parameters. + if !ast.generics.params.is_empty() { + return Error::new(Span::call_site(), "unsupported on types with type parameters") + .to_compile_error(); } - // We don't need a padding check if the struct is repr(transparent) or - // repr(packed). - // - repr(transparent): The layout and ABI of the whole struct is the same - // as its only non-ZST field (meaning there's no padding outside of that - // field) and we require that field to be `AsBytes` (meaning there's no - // padding in that field). - // - repr(packed): Any inter-field padding bytes are removed, meaning that - // any padding bytes would need to come from the fields, all of which - // we require to be `AsBytes` (meaning they don't have any padding). - let padding_check = if is_transparent || is_packed { None } else { Some(PaddingCheck::Struct) }; + let reprs = try_or_print!(STRUCT_UNION_AS_BYTES_CFG.validate_reprs(ast)); + let padding_check = + if reprs.contains(&StructRepr::Packed) { PaddingCheck::None } else { PaddingCheck::Struct }; + impl_block(ast, strct, "AsBytes", true, padding_check) } @@ -300,7 +208,7 @@ fn derive_as_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Token // We don't care what the repr is; we only care that it is one of the // allowed ones. let _: Vec<repr::EnumRepr> = try_or_print!(ENUM_AS_BYTES_CFG.validate_reprs(ast)); - impl_block(ast, enm, "AsBytes", false, None) + impl_block(ast, enm, "AsBytes", false, PaddingCheck::None) } #[rustfmt::skip] @@ -342,7 +250,7 @@ fn derive_as_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::Tok try_or_print!(STRUCT_UNION_AS_BYTES_CFG.validate_reprs(ast)); - impl_block(ast, unn, "AsBytes", true, Some(PaddingCheck::Union)) + impl_block(ast, unn, "AsBytes", true, PaddingCheck::Union) } // A struct is `Unaligned` if: @@ -355,7 +263,7 @@ fn derive_unaligned_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2 let reprs = try_or_print!(STRUCT_UNION_UNALIGNED_CFG.validate_reprs(ast)); let require_trait_bound = !reprs.contains(&StructRepr::Packed); - impl_block(ast, strct, "Unaligned", require_trait_bound, None) + impl_block(ast, strct, "Unaligned", require_trait_bound, PaddingCheck::None) } const STRUCT_UNION_UNALIGNED_CFG: Config<StructRepr> = Config { @@ -386,7 +294,7 @@ fn derive_unaligned_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Toke // for `require_trait_bounds` doesn't really do anything. But it's // marginally more future-proof in case that restriction is lifted in the // future. - impl_block(ast, enm, "Unaligned", true, None) + impl_block(ast, enm, "Unaligned", true, PaddingCheck::None) } #[rustfmt::skip] @@ -424,37 +332,26 @@ fn derive_unaligned_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::To let reprs = try_or_print!(STRUCT_UNION_UNALIGNED_CFG.validate_reprs(ast)); let require_trait_bound = !reprs.contains(&StructRepr::Packed); - impl_block(ast, unn, "Unaligned", require_trait_bound, None) + impl_block(ast, unn, "Unaligned", require_trait_bound, PaddingCheck::None) } // This enum describes what kind of padding check needs to be generated for the // associated impl. enum PaddingCheck { + // No additional padding check is required. + None, // Check that the sum of the fields' sizes exactly equals the struct's size. Struct, // Check that the size of each field exactly equals the union's size. Union, } -impl PaddingCheck { - /// Returns the ident of the macro to call in order to validate that a type - /// passes the padding check encoded by `PaddingCheck`. - fn validator_macro_ident(&self) -> Ident { - let s = match self { - PaddingCheck::Struct => "struct_has_padding", - PaddingCheck::Union => "union_has_padding", - }; - - Ident::new(s, Span::call_site()) - } -} - fn impl_block<D: DataExt>( input: &DeriveInput, data: &D, trait_name: &str, require_trait_bound: bool, - padding_check: Option<PaddingCheck>, + padding_check: PaddingCheck, ) -> proc_macro2::TokenStream { // In this documentation, we will refer to this hypothetical struct: // @@ -470,10 +367,22 @@ fn impl_block<D: DataExt>( // c: I::Item, // } // - // We extract the field types, which in this case are `u8`, `T`, and - // `I::Item`. We re-use the existing parameters and where clauses. If - // `require_trait_bound == true` (as it is for `FromBytes), we add where - // bounds for each field's type: + // First, we extract the field types, which in this case are `u8`, `T`, and + // `I::Item`. We use the names of the type parameters to split the field + // types into two sets - a set of types which are based on the type + // parameters, and a set of types which are not. First, we re-use the + // existing parameters and where clauses, generating an `impl` block like: + // + // impl<T, I: Iterator> FromBytes for Foo<T, I> + // where + // T: Copy, + // I: Clone, + // I::Item: Clone, + // { + // } + // + // Then, we use the list of types which are based on the type parameters to + // generate new entries in the `where` clause: // // impl<T, I: Iterator> FromBytes for Foo<T, I> // where @@ -485,6 +394,18 @@ fn impl_block<D: DataExt>( // { // } // + // Finally, we use a different technique to generate the bounds for the + // types which are not based on type parameters: + // + // + // fn only_derive_is_allowed_to_implement_this_trait() where Self: Sized { + // struct ImplementsFromBytes<F: ?Sized + FromBytes>(PhantomData<F>); + // let _: ImplementsFromBytes<u8>; + // } + // + // It would be easier to put all types in the where clause, but that won't + // work until the trivial_bounds feature is stabilized (#48214). + // // NOTE: It is standard practice to only emit bounds for the type parameters // themselves, not for field types based on those parameters (e.g., `T` vs // `T::Foo`). For a discussion of why this is standard practice, see @@ -495,9 +416,31 @@ fn impl_block<D: DataExt>( // `T::Foo: !FromBytes`. It would not be sound for us to accept a type with // a `T::Foo` field as `FromBytes` simply because `T: FromBytes`. // - // While there's no getting around this requirement for us, it does have the - // pretty serious downside that, when lifetimes are involved, the trait - // solver ties itself in knots: + // While there's no getting around this requirement for us, it does have + // some pretty serious downsides that are worth calling out: + // + // 1. You lose the ability to have fields of generic type with reduced visibility. + // + // #[derive(Unaligned)] + // #[repr(C)] + // pub struct Public<T>(Private<T>); + // + // #[derive(Unaligned)] + // #[repr(C)] + // struct Private<T>(T); + // + // + // warning: private type `Private<T>` in public interface (error E0446) + // --> src/main.rs:6:10 + // | + // 6 | #[derive(Unaligned)] + // | ^^^^^^^^^ + // | + // = note: #[warn(private_in_public)] on by default + // = warning: this was previously accepted by the compiler but is being phased out; it will become a hard error in a future release! + // = note: for more information, see issue #34537 <https://github.com/rust-lang/rust/issues/34537> + // + // 2. When lifetimes are involved, the trait solver ties itself in knots. // // #[derive(Unaligned)] // #[repr(C)] @@ -506,6 +449,7 @@ fn impl_block<D: DataExt>( // b: PhantomData<&'b u8>, // } // + // // error[E0283]: type annotations required: cannot resolve `core::marker::PhantomData<&'a u8>: zerocopy::Unaligned` // --> src/main.rs:6:10 // | @@ -514,37 +458,67 @@ fn impl_block<D: DataExt>( // | // = note: required by `zerocopy::Unaligned` - let type_ident = &input.ident; + // A visitor which is used to walk a field's type and determine whether any + // of its definition is based on the type or lifetime parameters on a type. + struct FromTypeParamVisit<'a, 'b>(&'a Punctuated<GenericParam, Comma>, &'b mut bool); + + impl<'a, 'b> Visit<'a> for FromTypeParamVisit<'a, 'b> { + fn visit_lifetime(&mut self, i: &'a Lifetime) { + visit::visit_lifetime(self, i); + if self.0.iter().any(|param| { + if let GenericParam::Lifetime(param) = param { + param.lifetime.ident == i.ident + } else { + false + } + }) { + *self.1 = true; + } + } + + fn visit_type_path(&mut self, i: &'a TypePath) { + visit::visit_type_path(self, i); + if self.0.iter().any(|param| { + if let GenericParam::Type(param) = param { + i.path.segments.first().unwrap().ident == param.ident + } else { + false + } + }) { + *self.1 = true; + } + } + } + + // Whether this type is based on one of the type parameters. E.g., given the + // type parameters `<T>`, `T`, `T::Foo`, and `(T::Foo, String)` are all + // based on the type parameters, while `String` and `(String, Box<()>)` are + // not. + let is_from_type_param = |ty: &Type| { + let mut ret = false; + FromTypeParamVisit(&input.generics.params, &mut ret).visit_type(ty); + ret + }; + let trait_ident = Ident::new(trait_name, Span::call_site()); - let field_types = data.field_types(); - - let field_type_bounds = require_trait_bound - .then(|| field_types.iter().map(|ty| parse_quote!(#ty: zerocopy::#trait_ident))) - .into_iter() - .flatten() - .collect::<Vec<_>>(); - - // Don't bother emitting a padding check if there are no fields. - #[allow(unstable_name_collisions)] // See `BoolExt` below - let padding_check_bound = padding_check.and_then(|check| (!field_types.is_empty()).then_some(check)).map(|check| { - let fields = field_types.iter(); - let validator_macro = check.validator_macro_ident(); - parse_quote!( - zerocopy::derive_util::HasPadding<#type_ident, {zerocopy::#validator_macro!(#type_ident, #(#fields),*)}>: - zerocopy::derive_util::ShouldBe<false> - ) - }); - let bounds = input - .generics - .where_clause - .as_ref() - .map(|where_clause| where_clause.predicates.iter()) - .into_iter() - .flatten() - .chain(field_type_bounds.iter()) - .chain(padding_check_bound.iter()); + let field_types = data.nested_types(); + let type_param_field_types = field_types.iter().filter(|ty| is_from_type_param(ty)); + let non_type_param_field_types = field_types.iter().filter(|ty| !is_from_type_param(ty)); + + // Add a new set of where clause predicates of the form `T: Trait` for each + // of the types of the struct's fields (but only the ones whose types are + // based on one of the type parameters). + let mut generics = input.generics.clone(); + let where_clause = generics.make_where_clause(); + if require_trait_bound { + for ty in type_param_field_types { + let bound = parse_quote!(#ty: zerocopy::#trait_ident); + where_clause.predicates.push(bound); + } + } + let type_ident = &input.ident; // The parameters with trait bounds, but without type defaults. let params = input.generics.params.clone().into_iter().map(|mut param| { match &mut param { @@ -564,12 +538,56 @@ fn impl_block<D: DataExt>( GenericParam::Const(cnst) => quote!(#cnst), }); + let trait_bound_body = if require_trait_bound { + let implements_type_ident = + Ident::new(format!("Implements{}", trait_ident).as_str(), Span::call_site()); + let implements_type_tokens = quote!(#implements_type_ident); + let types = non_type_param_field_types.map(|ty| quote!(#implements_type_tokens<#ty>)); + quote!( + // A type with a type parameter that must implement `#trait_ident`. + struct #implements_type_ident<F: ?Sized + zerocopy::#trait_ident>(::core::marker::PhantomData<F>); + // For each field type, an instantiation that won't type check if + // that type doesn't implement `#trait_ident`. + #(let _: #types;)* + ) + } else { + quote!() + }; + + let size_check_body = match (field_types.is_empty(), padding_check) { + (true, _) | (false, PaddingCheck::None) => quote!(), + (false, PaddingCheck::Struct) => quote!( + const _: () = { + trait HasPadding<const HAS_PADDING: bool> {} + fn assert_no_padding<T: HasPadding<false>>() {} + + const COMPOSITE_TYPE_SIZE: usize = ::core::mem::size_of::<#type_ident>(); + const SUM_FIELD_SIZES: usize = 0 #(+ ::core::mem::size_of::<#field_types>())*; + const HAS_PADDING: bool = COMPOSITE_TYPE_SIZE > SUM_FIELD_SIZES; + impl HasPadding<HAS_PADDING> for #type_ident {} + let _ = assert_no_padding::<#type_ident>; + }; + ), + (false, PaddingCheck::Union) => quote!( + const _: () = { + trait FieldsAreSameSize<const FIELDS_ARE_SAME_SIZE: bool> {} + fn assert_fields_are_same_size<T: FieldsAreSameSize<true>>() {} + + const COMPOSITE_TYPE_SIZE: usize = ::core::mem::size_of::<#type_ident>(); + const FIELDS_ARE_SAME_SIZE: bool = true + #(&& (::core::mem::size_of::<#field_types>() == COMPOSITE_TYPE_SIZE))*; + impl FieldsAreSameSize<FIELDS_ARE_SAME_SIZE> for #type_ident {} + let _ = assert_fields_are_same_size::<#type_ident>; + }; + ), + }; + quote! { - unsafe impl < #(#params),* > zerocopy::#trait_ident for #type_ident < #(#param_idents),* > - where - #(#bounds,)* - { - fn only_derive_is_allowed_to_implement_this_trait() {} + unsafe impl < #(#params),* > zerocopy::#trait_ident for #type_ident < #(#param_idents),* > #where_clause { + fn only_derive_is_allowed_to_implement_this_trait() where Self: Sized { + #trait_bound_body + #size_check_body + } } } } @@ -578,23 +596,6 @@ fn print_all_errors(errors: Vec<Error>) -> proc_macro2::TokenStream { errors.iter().map(Error::to_compile_error).collect() } -// A polyfill for `Option::then_some`, which was added after our MSRV. -// -// TODO(#67): Remove this once our MSRV is >= 1.62. -trait BoolExt { - fn then_some<T>(self, t: T) -> Option<T>; -} - -impl BoolExt for bool { - fn then_some<T>(self, t: T) -> Option<T> { - if self { - Some(t) - } else { - None - } - } -} - #[cfg(test)] mod tests { use super::*; @@ -615,12 +616,12 @@ mod tests { } fn elements_are_sorted_and_deduped<T: Clone + Ord>(lists: &[&[T]]) -> bool { - lists.iter().all(|list| is_sorted_and_deduped(list)) + lists.iter().all(|list| is_sorted_and_deduped(*list)) } fn config_is_sorted<T: KindRepr + Clone>(config: &Config<T>) -> bool { - elements_are_sorted_and_deduped(config.allowed_combinations) - && elements_are_sorted_and_deduped(config.disallowed_but_legal_combinations) + elements_are_sorted_and_deduped(&config.allowed_combinations) + && elements_are_sorted_and_deduped(&config.disallowed_but_legal_combinations) } assert!(config_is_sorted(&STRUCT_UNION_UNALIGNED_CFG)); |