diff options
author | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2023-07-07 07:12:44 +0000 |
---|---|---|
committer | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2023-07-07 07:12:44 +0000 |
commit | 46a9a414b352ff56b0d9490c2f6820c6d5baa89b (patch) | |
tree | 71072e4f3b2bd823a6e356c31736bc837a2e6622 | |
parent | 105719052399a131bf3fe79410c39fb068dfae94 (diff) | |
parent | 385bce15785129cd36bb1c346d0380f19d0c3ab1 (diff) | |
download | derive_arbitrary-aml_odp_341610000.tar.gz |
Snap for 10453938 from 385bce15785129cd36bb1c346d0380f19d0c3ab1 to mainline-odp-releaseaml_odp_341610000
Change-Id: I79a24d2b79c5be74c89e8bbe638284b3052c655d
-rw-r--r-- | .cargo_vcs_info.json | 2 | ||||
-rw-r--r-- | Android.bp | 6 | ||||
-rw-r--r-- | Cargo.toml | 13 | ||||
-rw-r--r-- | Cargo.toml.orig | 7 | ||||
-rw-r--r-- | METADATA | 14 | ||||
-rw-r--r-- | src/container_attributes.rs | 72 | ||||
-rw-r--r-- | src/field_attributes.rs | 115 | ||||
-rw-r--r-- | src/lib.rs | 383 |
8 files changed, 509 insertions, 103 deletions
diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json index 9712a81..5db9271 100644 --- a/.cargo_vcs_info.json +++ b/.cargo_vcs_info.json @@ -1,6 +1,6 @@ { "git": { - "sha1": "d0d238d880276fd617c38f7e4712bf40db58aad6" + "sha1": "c20a95029145c0eab249416f3d301c4bd21f33f6" }, "path_in_vcs": "derive" }
\ No newline at end of file @@ -43,12 +43,14 @@ rust_proc_macro { name: "libderive_arbitrary", crate_name: "derive_arbitrary", cargo_env_compat: true, - cargo_pkg_version: "1.1.0", + cargo_pkg_version: "1.3.0", srcs: ["src/lib.rs"], - edition: "2018", + edition: "2021", rustlibs: [ "libproc_macro2", "libquote", "libsyn", ], + product_available: true, + vendor_available: true, } @@ -10,9 +10,10 @@ # See Cargo.toml.orig for the original contents. [package] -edition = "2018" +edition = "2021" +rust-version = "1.63.0" name = "derive_arbitrary" -version = "1.1.0" +version = "1.3.0" authors = [ "The Rust-Fuzz Project Developers", "Nick Fitzgerald <fitzgen@gmail.com>", @@ -32,6 +33,7 @@ keywords = [ categories = ["development-tools::testing"] license = "MIT/Apache-2.0" repository = "https://github.com/rust-fuzz/arbitrary" +resolver = "1" [lib] proc_macro = true @@ -43,5 +45,8 @@ version = "1.0" version = "1.0" [dependencies.syn] -version = "1.0" -features = ["derive"] +version = "1.0.56" +features = [ + "derive", + "parsing", +] diff --git a/Cargo.toml.orig b/Cargo.toml.orig index acdd379..570519f 100644 --- a/Cargo.toml.orig +++ b/Cargo.toml.orig @@ -1,6 +1,6 @@ [package] name = "derive_arbitrary" -version = "1.1.0" # Make sure it matches the version of the arbitrary crate itself. +version = "1.3.0" # Make sure it matches the version of the arbitrary crate itself (not including the patch version) authors = [ "The Rust-Fuzz Project Developers", "Nick Fitzgerald <fitzgen@gmail.com>", @@ -9,18 +9,19 @@ authors = [ "Corey Farwell <coreyf@rwell.org>", ] categories = ["development-tools::testing"] -edition = "2018" +edition = "2021" keywords = ["arbitrary", "testing", "derive", "macro"] readme = "README.md" description = "Derives arbitrary traits" license = "MIT/Apache-2.0" repository = "https://github.com/rust-fuzz/arbitrary" documentation = "https://docs.rs/arbitrary/" +rust-version = "1.63.0" [dependencies] proc-macro2 = "1.0" quote = "1.0" -syn = { version = "1.0", features = ['derive'] } +syn = { version = "1.0.56", features = ['derive', 'parsing'] } [lib] proc_macro = true @@ -1,3 +1,7 @@ +# This project was upgraded with external_updater. +# Usage: tools/external_updater/updater.sh update rust/crates/derive_arbitrary +# For more info, check https://cs.android.com/android/platform/superproject/+/master:tools/external_updater/README.md + name: "derive_arbitrary" description: "Derives arbitrary traits" third_party { @@ -7,13 +11,13 @@ third_party { } url { type: ARCHIVE - value: "https://static.crates.io/crates/derive_arbitrary/derive_arbitrary-1.1.0.crate" + value: "https://static.crates.io/crates/derive_arbitrary/derive_arbitrary-1.3.0.crate" } - version: "1.1.0" + version: "1.3.0" license_type: NOTICE last_upgrade_date { - year: 2022 - month: 3 - day: 1 + year: 2023 + month: 4 + day: 3 } } diff --git a/src/container_attributes.rs b/src/container_attributes.rs new file mode 100644 index 0000000..9a91ac8 --- /dev/null +++ b/src/container_attributes.rs @@ -0,0 +1,72 @@ +use crate::ARBITRARY_ATTRIBUTE_NAME; +use syn::{ + parse::Error, punctuated::Punctuated, DeriveInput, Lit, Meta, MetaNameValue, NestedMeta, Token, + TypeParam, +}; + +pub struct ContainerAttributes { + /// Specify type bounds to be applied to the derived `Arbitrary` implementation instead of the + /// default inferred bounds. + /// + /// ```ignore + /// #[arbitrary(bound = "T: Default, U: Debug")] + /// ``` + /// + /// Multiple attributes will be combined as long as they don't conflict, e.g. + /// + /// ```ignore + /// #[arbitrary(bound = "T: Default")] + /// #[arbitrary(bound = "U: Default")] + /// ``` + pub bounds: Option<Vec<Punctuated<TypeParam, Token![,]>>>, +} + +impl ContainerAttributes { + pub fn from_derive_input(derive_input: &DeriveInput) -> Result<Self, Error> { + let mut bounds = None; + + for attr in &derive_input.attrs { + if !attr.path.is_ident(ARBITRARY_ATTRIBUTE_NAME) { + continue; + } + + let meta_list = match attr.parse_meta()? { + Meta::List(l) => l, + _ => { + return Err(Error::new_spanned( + attr, + format!( + "invalid `{}` attribute. expected list", + ARBITRARY_ATTRIBUTE_NAME + ), + )) + } + }; + + for nested_meta in meta_list.nested.iter() { + match nested_meta { + NestedMeta::Meta(Meta::NameValue(MetaNameValue { + path, + lit: Lit::Str(bound_str_lit), + .. + })) if path.is_ident("bound") => { + bounds + .get_or_insert_with(Vec::new) + .push(bound_str_lit.parse_with(Punctuated::parse_terminated)?); + } + _ => { + return Err(Error::new_spanned( + attr, + format!( + "invalid `{}` attribute. expected `bound = \"..\"`", + ARBITRARY_ATTRIBUTE_NAME, + ), + )) + } + } + } + } + + Ok(Self { bounds }) + } +} diff --git a/src/field_attributes.rs b/src/field_attributes.rs new file mode 100644 index 0000000..2ca0f1c --- /dev/null +++ b/src/field_attributes.rs @@ -0,0 +1,115 @@ +use crate::ARBITRARY_ATTRIBUTE_NAME; +use proc_macro2::{Group, Span, TokenStream, TokenTree}; +use quote::quote; +use syn::{spanned::Spanned, *}; + +/// Determines how a value for a field should be constructed. +#[cfg_attr(test, derive(Debug))] +pub enum FieldConstructor { + /// Assume that Arbitrary is defined for the type of this field and use it (default) + Arbitrary, + + /// Places `Default::default()` as a field value. + Default, + + /// Use custom function or closure to generate a value for a field. + With(TokenStream), + + /// Set a field always to the given value. + Value(TokenStream), +} + +pub fn determine_field_constructor(field: &Field) -> Result<FieldConstructor> { + let opt_attr = fetch_attr_from_field(field)?; + let ctor = match opt_attr { + Some(attr) => parse_attribute(attr)?, + None => FieldConstructor::Arbitrary, + }; + Ok(ctor) +} + +fn fetch_attr_from_field(field: &Field) -> Result<Option<&Attribute>> { + let found_attributes: Vec<_> = field + .attrs + .iter() + .filter(|a| { + let path = &a.path; + let name = quote!(#path).to_string(); + name == ARBITRARY_ATTRIBUTE_NAME + }) + .collect(); + if found_attributes.len() > 1 { + let name = field.ident.as_ref().unwrap(); + let msg = format!( + "Multiple conflicting #[{ARBITRARY_ATTRIBUTE_NAME}] attributes found on field `{name}`" + ); + return Err(syn::Error::new(field.span(), msg)); + } + Ok(found_attributes.into_iter().next()) +} + +fn parse_attribute(attr: &Attribute) -> Result<FieldConstructor> { + let group = { + let mut tokens_iter = attr.clone().tokens.into_iter(); + let token = tokens_iter.next().ok_or_else(|| { + let msg = format!("#[{ARBITRARY_ATTRIBUTE_NAME}] cannot be empty."); + syn::Error::new(attr.span(), msg) + })?; + match token { + TokenTree::Group(g) => g, + t => { + let msg = format!("#[{ARBITRARY_ATTRIBUTE_NAME}] must contain a group, got: {t})"); + return Err(syn::Error::new(attr.span(), msg)); + } + } + }; + parse_attribute_internals(group) +} + +fn parse_attribute_internals(group: Group) -> Result<FieldConstructor> { + let stream = group.stream(); + let mut tokens_iter = stream.into_iter(); + let token = tokens_iter.next().ok_or_else(|| { + let msg = format!("#[{ARBITRARY_ATTRIBUTE_NAME}] cannot be empty."); + syn::Error::new(group.span(), msg) + })?; + match token.to_string().as_ref() { + "default" => Ok(FieldConstructor::Default), + "with" => { + let func_path = parse_assigned_value("with", tokens_iter, group.span())?; + Ok(FieldConstructor::With(func_path)) + } + "value" => { + let value = parse_assigned_value("value", tokens_iter, group.span())?; + Ok(FieldConstructor::Value(value)) + } + _ => { + let msg = format!("Unknown option for #[{ARBITRARY_ATTRIBUTE_NAME}]: `{token}`"); + Err(syn::Error::new(token.span(), msg)) + } + } +} + +// Input: +// = 2 + 2 +// Output: +// 2 + 2 +fn parse_assigned_value( + opt_name: &str, + mut tokens_iter: impl Iterator<Item = TokenTree>, + default_span: Span, +) -> Result<TokenStream> { + let eq_sign = tokens_iter.next().ok_or_else(|| { + let msg = format!( + "Invalid syntax for #[{ARBITRARY_ATTRIBUTE_NAME}], `{opt_name}` is missing assignment." + ); + syn::Error::new(default_span, msg) + })?; + + if eq_sign.to_string() == "=" { + Ok(tokens_iter.collect()) + } else { + let msg = format!("Invalid syntax for #[{ARBITRARY_ATTRIBUTE_NAME}], expected `=` after `{opt_name}`, got: `{eq_sign}`"); + Err(syn::Error::new(eq_sign.span(), msg)) + } +} @@ -4,19 +4,44 @@ use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::*; +mod container_attributes; +mod field_attributes; +use container_attributes::ContainerAttributes; +use field_attributes::{determine_field_constructor, FieldConstructor}; + +static ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary"; static ARBITRARY_LIFETIME_NAME: &str = "'arbitrary"; -#[proc_macro_derive(Arbitrary)] +#[proc_macro_derive(Arbitrary, attributes(arbitrary))] pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = syn::parse_macro_input!(tokens as syn::DeriveInput); + expand_derive_arbitrary(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + +fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> { + let container_attrs = ContainerAttributes::from_derive_input(&input)?; + let (lifetime_without_bounds, lifetime_with_bounds) = build_arbitrary_lifetime(input.generics.clone()); - let arbitrary_method = gen_arbitrary_method(&input, lifetime_without_bounds.clone()); - let size_hint_method = gen_size_hint_method(&input); + let recursive_count = syn::Ident::new( + &format!("RECURSIVE_COUNT_{}", input.ident), + Span::call_site(), + ); + + let arbitrary_method = + gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?; + let size_hint_method = gen_size_hint_method(&input)?; let name = input.ident; - // Add a bound `T: Arbitrary` to every type parameter T. - let generics = add_trait_bounds(input.generics, lifetime_without_bounds.clone()); + + // Apply user-supplied bounds or automatic `T: ArbitraryBounds`. + let generics = apply_trait_bounds( + input.generics, + lifetime_without_bounds.clone(), + &container_attrs, + )?; // Build ImplGeneric with a lifetime (https://github.com/dtolnay/syn/issues/90) let mut generics_with_lifetime = generics.clone(); @@ -28,13 +53,20 @@ pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStr // Build TypeGenerics and WhereClause without a lifetime let (_, ty_generics, where_clause) = generics.split_for_impl(); - (quote! { - impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause { - #arbitrary_method - #size_hint_method - } + Ok(quote! { + const _: () = { + std::thread_local! { + #[allow(non_upper_case_globals)] + static #recursive_count: std::cell::Cell<u32> = std::cell::Cell::new(0); + } + + #[automatically_derived] + impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause { + #arbitrary_method + #size_hint_method + } + }; }) - .into() } // Returns: (lifetime without bounds, lifetime with bounds) @@ -55,6 +87,51 @@ fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeDef, LifetimeDef) { (lifetime_without_bounds, lifetime_with_bounds) } +fn apply_trait_bounds( + mut generics: Generics, + lifetime: LifetimeDef, + container_attrs: &ContainerAttributes, +) -> Result<Generics> { + // If user-supplied bounds exist, apply them to their matching type parameters. + if let Some(config_bounds) = &container_attrs.bounds { + let mut config_bounds_applied = 0; + for param in generics.params.iter_mut() { + if let GenericParam::Type(type_param) = param { + if let Some(replacement) = config_bounds + .iter() + .flatten() + .find(|p| p.ident == type_param.ident) + { + *type_param = replacement.clone(); + config_bounds_applied += 1; + } else { + // If no user-supplied bounds exist for this type, delete the original bounds. + // This mimics serde. + type_param.bounds = Default::default(); + type_param.default = None; + } + } + } + let config_bounds_supplied = config_bounds + .iter() + .map(|bounds| bounds.len()) + .sum::<usize>(); + if config_bounds_applied != config_bounds_supplied { + return Err(Error::new( + Span::call_site(), + format!( + "invalid `{}` attribute. too many bounds, only {} out of {} are applicable", + ARBITRARY_ATTRIBUTE_NAME, config_bounds_applied, config_bounds_supplied, + ), + )); + } + Ok(generics) + } else { + // Otherwise, inject a `T: Arbitrary` bound for every parameter. + Ok(add_trait_bounds(generics, lifetime)) + } +} + // Add a bound `T: Arbitrary` to every type parameter T. fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeDef) -> Generics { for param in generics.params.iter_mut() { @@ -67,42 +144,102 @@ fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeDef) -> Generics { generics } -fn gen_arbitrary_method(input: &DeriveInput, lifetime: LifetimeDef) -> TokenStream { - let ident = &input.ident; - let arbitrary_structlike = |fields| { - let arbitrary = construct(fields, |_, _| quote!(arbitrary::Arbitrary::arbitrary(u)?)); - let arbitrary_take_rest = construct_take_rest(fields); - quote! { +fn with_recursive_count_guard( + recursive_count: &syn::Ident, + expr: impl quote::ToTokens, +) -> impl quote::ToTokens { + quote! { + let guard_against_recursion = u.is_empty(); + if guard_against_recursion { + #recursive_count.with(|count| { + if count.get() > 0 { + return Err(arbitrary::Error::NotEnoughData); + } + count.set(count.get() + 1); + Ok(()) + })?; + } + + let result = (|| { #expr })(); + + if guard_against_recursion { + #recursive_count.with(|count| { + count.set(count.get() - 1); + }); + } + + result + } +} + +fn gen_arbitrary_method( + input: &DeriveInput, + lifetime: LifetimeDef, + recursive_count: &syn::Ident, +) -> Result<TokenStream> { + fn arbitrary_structlike( + fields: &Fields, + ident: &syn::Ident, + lifetime: LifetimeDef, + recursive_count: &syn::Ident, + ) -> Result<TokenStream> { + let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field))?; + let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) }); + + let arbitrary_take_rest = construct_take_rest(fields)?; + let take_rest_body = + with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) }); + + Ok(quote! { fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> { - Ok(#ident #arbitrary) + #body } fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> { - Ok(#ident #arbitrary_take_rest) + #take_rest_body } - } - }; - match &input.data { - Data::Struct(data) => arbitrary_structlike(&data.fields), - Data::Union(data) => arbitrary_structlike(&Fields::Named(data.fields.clone())), + }) + } + + let ident = &input.ident; + let output = match &input.data { + Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count)?, + Data::Union(data) => arbitrary_structlike( + &Fields::Named(data.fields.clone()), + ident, + lifetime, + recursive_count, + )?, Data::Enum(data) => { - let variants = data.variants.iter().enumerate().map(|(i, variant)| { - let idx = i as u64; - let ctor = construct(&variant.fields, |_, _| { - quote!(arbitrary::Arbitrary::arbitrary(u)?) - }); - let variant_name = &variant.ident; - quote! { #idx => #ident::#variant_name #ctor } - }); - let variants_take_rest = data.variants.iter().enumerate().map(|(i, variant)| { - let idx = i as u64; - let ctor = construct_take_rest(&variant.fields); - let variant_name = &variant.ident; - quote! { #idx => #ident::#variant_name #ctor } - }); + let variants: Vec<TokenStream> = data + .variants + .iter() + .enumerate() + .map(|(i, variant)| { + let idx = i as u64; + let variant_name = &variant.ident; + construct(&variant.fields, |_, field| gen_constructor_for_field(field)) + .map(|ctor| quote! { #idx => #ident::#variant_name #ctor }) + }) + .collect::<Result<_>>()?; + + let variants_take_rest: Vec<TokenStream> = data + .variants + .iter() + .enumerate() + .map(|(i, variant)| { + let idx = i as u64; + let variant_name = &variant.ident; + construct_take_rest(&variant.fields) + .map(|ctor| quote! { #idx => #ident::#variant_name #ctor }) + }) + .collect::<Result<_>>()?; + let count = data.variants.len() as u64; - quote! { - fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> { + + let arbitrary = with_recursive_count_guard( + recursive_count, + quote! { // Use a multiply + shift to generate a ranged random number // with slight bias. For details, see: // https://lemire.me/blog/2016/06/30/fast-random-shuffling @@ -110,9 +247,12 @@ fn gen_arbitrary_method(input: &DeriveInput, lifetime: LifetimeDef) -> TokenStre #(#variants,)* _ => unreachable!() }) - } + }, + ); - fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> { + let arbitrary_take_rest = with_recursive_count_guard( + recursive_count, + quote! { // Use a multiply + shift to generate a ranged random number // with slight bias. For details, see: // https://lemire.me/blog/2016/06/30/fast-random-shuffling @@ -120,77 +260,144 @@ fn gen_arbitrary_method(input: &DeriveInput, lifetime: LifetimeDef) -> TokenStre #(#variants_take_rest,)* _ => unreachable!() }) + }, + ); + + quote! { + fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> { + #arbitrary + } + + fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> { + #arbitrary_take_rest } } } - } + }; + Ok(output) } -fn construct(fields: &Fields, ctor: impl Fn(usize, &Field) -> TokenStream) -> TokenStream { - match fields { +fn construct( + fields: &Fields, + ctor: impl Fn(usize, &Field) -> Result<TokenStream>, +) -> Result<TokenStream> { + let output = match fields { Fields::Named(names) => { - let names = names.named.iter().enumerate().map(|(i, f)| { - let name = f.ident.as_ref().unwrap(); - let ctor = ctor(i, f); - quote! { #name: #ctor } - }); + let names: Vec<TokenStream> = names + .named + .iter() + .enumerate() + .map(|(i, f)| { + let name = f.ident.as_ref().unwrap(); + ctor(i, f).map(|ctor| quote! { #name: #ctor }) + }) + .collect::<Result<_>>()?; quote! { { #(#names,)* } } } Fields::Unnamed(names) => { - let names = names.unnamed.iter().enumerate().map(|(i, f)| { - let ctor = ctor(i, f); - quote! { #ctor } - }); + let names: Vec<TokenStream> = names + .unnamed + .iter() + .enumerate() + .map(|(i, f)| ctor(i, f).map(|ctor| quote! { #ctor })) + .collect::<Result<_>>()?; quote! { ( #(#names),* ) } } Fields::Unit => quote!(), - } + }; + Ok(output) } -fn construct_take_rest(fields: &Fields) -> TokenStream { - construct(fields, |idx, _| { - if idx + 1 == fields.len() { - quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? } - } else { - quote! { arbitrary::Arbitrary::arbitrary(&mut u)? } - } +fn construct_take_rest(fields: &Fields) -> Result<TokenStream> { + construct(fields, |idx, field| { + determine_field_constructor(field).map(|field_constructor| match field_constructor { + FieldConstructor::Default => quote!(Default::default()), + FieldConstructor::Arbitrary => { + if idx + 1 == fields.len() { + quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? } + } else { + quote! { arbitrary::Arbitrary::arbitrary(&mut u)? } + } + } + FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(&mut u)?), + FieldConstructor::Value(value) => quote!(#value), + }) }) } -fn gen_size_hint_method(input: &DeriveInput) -> TokenStream { +fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> { let size_hint_fields = |fields: &Fields| { - let tys = fields.iter().map(|f| &f.ty); - quote! { - arbitrary::size_hint::and_all(&[ - #( <#tys as arbitrary::Arbitrary>::size_hint(depth) ),* - ]) - } + fields + .iter() + .map(|f| { + let ty = &f.ty; + determine_field_constructor(f).map(|field_constructor| { + match field_constructor { + FieldConstructor::Default | FieldConstructor::Value(_) => { + quote!((0, Some(0))) + } + FieldConstructor::Arbitrary => { + quote! { <#ty as arbitrary::Arbitrary>::size_hint(depth) } + } + + // Note that in this case it's hard to determine what size_hint must be, so size_of::<T>() is + // just an educated guess, although it's gonna be inaccurate for dynamically + // allocated types (Vec, HashMap, etc.). + FieldConstructor::With(_) => { + quote! { (::core::mem::size_of::<#ty>(), None) } + } + } + }) + }) + .collect::<Result<Vec<TokenStream>>>() + .map(|hints| { + quote! { + arbitrary::size_hint::and_all(&[ + #( #hints ),* + ]) + } + }) }; let size_hint_structlike = |fields: &Fields| { - let hint = size_hint_fields(fields); - quote! { - #[inline] - fn size_hint(depth: usize) -> (usize, Option<usize>) { - arbitrary::size_hint::recursion_guard(depth, |depth| #hint) + size_hint_fields(fields).map(|hint| { + quote! { + #[inline] + fn size_hint(depth: usize) -> (usize, Option<usize>) { + arbitrary::size_hint::recursion_guard(depth, |depth| #hint) + } } - } + }) }; match &input.data { Data::Struct(data) => size_hint_structlike(&data.fields), Data::Union(data) => size_hint_structlike(&Fields::Named(data.fields.clone())), - Data::Enum(data) => { - let variants = data.variants.iter().map(|v| size_hint_fields(&v.fields)); - quote! { - #[inline] - fn size_hint(depth: usize) -> (usize, Option<usize>) { - arbitrary::size_hint::and( - <u32 as arbitrary::Arbitrary>::size_hint(depth), - arbitrary::size_hint::recursion_guard(depth, |depth| { - arbitrary::size_hint::or_all(&[ #( #variants ),* ]) - }), - ) + Data::Enum(data) => data + .variants + .iter() + .map(|v| size_hint_fields(&v.fields)) + .collect::<Result<Vec<TokenStream>>>() + .map(|variants| { + quote! { + #[inline] + fn size_hint(depth: usize) -> (usize, Option<usize>) { + arbitrary::size_hint::and( + <u32 as arbitrary::Arbitrary>::size_hint(depth), + arbitrary::size_hint::recursion_guard(depth, |depth| { + arbitrary::size_hint::or_all(&[ #( #variants ),* ]) + }), + ) + } } - } - } + }), } } + +fn gen_constructor_for_field(field: &Field) -> Result<TokenStream> { + let ctor = match determine_field_constructor(field)? { + FieldConstructor::Default => quote!(Default::default()), + FieldConstructor::Arbitrary => quote!(arbitrary::Arbitrary::arbitrary(u)?), + FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(u)?), + FieldConstructor::Value(value) => quote!(#value), + }; + Ok(ctor) +} |