aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2023-07-07 07:12:44 +0000
committerAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2023-07-07 07:12:44 +0000
commit46a9a414b352ff56b0d9490c2f6820c6d5baa89b (patch)
tree71072e4f3b2bd823a6e356c31736bc837a2e6622
parent105719052399a131bf3fe79410c39fb068dfae94 (diff)
parent385bce15785129cd36bb1c346d0380f19d0c3ab1 (diff)
downloadderive_arbitrary-aml_odp_341610000.tar.gz
Snap for 10453938 from 385bce15785129cd36bb1c346d0380f19d0c3ab1 to mainline-odp-releaseaml_odp_341610000
Change-Id: I79a24d2b79c5be74c89e8bbe638284b3052c655d
-rw-r--r--.cargo_vcs_info.json2
-rw-r--r--Android.bp6
-rw-r--r--Cargo.toml13
-rw-r--r--Cargo.toml.orig7
-rw-r--r--METADATA14
-rw-r--r--src/container_attributes.rs72
-rw-r--r--src/field_attributes.rs115
-rw-r--r--src/lib.rs383
8 files changed, 509 insertions, 103 deletions
diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json
index 9712a81..5db9271 100644
--- a/.cargo_vcs_info.json
+++ b/.cargo_vcs_info.json
@@ -1,6 +1,6 @@
{
"git": {
- "sha1": "d0d238d880276fd617c38f7e4712bf40db58aad6"
+ "sha1": "c20a95029145c0eab249416f3d301c4bd21f33f6"
},
"path_in_vcs": "derive"
} \ No newline at end of file
diff --git a/Android.bp b/Android.bp
index b3caae1..20506a1 100644
--- a/Android.bp
+++ b/Android.bp
@@ -43,12 +43,14 @@ rust_proc_macro {
name: "libderive_arbitrary",
crate_name: "derive_arbitrary",
cargo_env_compat: true,
- cargo_pkg_version: "1.1.0",
+ cargo_pkg_version: "1.3.0",
srcs: ["src/lib.rs"],
- edition: "2018",
+ edition: "2021",
rustlibs: [
"libproc_macro2",
"libquote",
"libsyn",
],
+ product_available: true,
+ vendor_available: true,
}
diff --git a/Cargo.toml b/Cargo.toml
index d749034..fc076e8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -10,9 +10,10 @@
# See Cargo.toml.orig for the original contents.
[package]
-edition = "2018"
+edition = "2021"
+rust-version = "1.63.0"
name = "derive_arbitrary"
-version = "1.1.0"
+version = "1.3.0"
authors = [
"The Rust-Fuzz Project Developers",
"Nick Fitzgerald <fitzgen@gmail.com>",
@@ -32,6 +33,7 @@ keywords = [
categories = ["development-tools::testing"]
license = "MIT/Apache-2.0"
repository = "https://github.com/rust-fuzz/arbitrary"
+resolver = "1"
[lib]
proc_macro = true
@@ -43,5 +45,8 @@ version = "1.0"
version = "1.0"
[dependencies.syn]
-version = "1.0"
-features = ["derive"]
+version = "1.0.56"
+features = [
+ "derive",
+ "parsing",
+]
diff --git a/Cargo.toml.orig b/Cargo.toml.orig
index acdd379..570519f 100644
--- a/Cargo.toml.orig
+++ b/Cargo.toml.orig
@@ -1,6 +1,6 @@
[package]
name = "derive_arbitrary"
-version = "1.1.0" # Make sure it matches the version of the arbitrary crate itself.
+version = "1.3.0" # Make sure it matches the version of the arbitrary crate itself (not including the patch version)
authors = [
"The Rust-Fuzz Project Developers",
"Nick Fitzgerald <fitzgen@gmail.com>",
@@ -9,18 +9,19 @@ authors = [
"Corey Farwell <coreyf@rwell.org>",
]
categories = ["development-tools::testing"]
-edition = "2018"
+edition = "2021"
keywords = ["arbitrary", "testing", "derive", "macro"]
readme = "README.md"
description = "Derives arbitrary traits"
license = "MIT/Apache-2.0"
repository = "https://github.com/rust-fuzz/arbitrary"
documentation = "https://docs.rs/arbitrary/"
+rust-version = "1.63.0"
[dependencies]
proc-macro2 = "1.0"
quote = "1.0"
-syn = { version = "1.0", features = ['derive'] }
+syn = { version = "1.0.56", features = ['derive', 'parsing'] }
[lib]
proc_macro = true
diff --git a/METADATA b/METADATA
index f875682..c7c7071 100644
--- a/METADATA
+++ b/METADATA
@@ -1,3 +1,7 @@
+# This project was upgraded with external_updater.
+# Usage: tools/external_updater/updater.sh update rust/crates/derive_arbitrary
+# For more info, check https://cs.android.com/android/platform/superproject/+/master:tools/external_updater/README.md
+
name: "derive_arbitrary"
description: "Derives arbitrary traits"
third_party {
@@ -7,13 +11,13 @@ third_party {
}
url {
type: ARCHIVE
- value: "https://static.crates.io/crates/derive_arbitrary/derive_arbitrary-1.1.0.crate"
+ value: "https://static.crates.io/crates/derive_arbitrary/derive_arbitrary-1.3.0.crate"
}
- version: "1.1.0"
+ version: "1.3.0"
license_type: NOTICE
last_upgrade_date {
- year: 2022
- month: 3
- day: 1
+ year: 2023
+ month: 4
+ day: 3
}
}
diff --git a/src/container_attributes.rs b/src/container_attributes.rs
new file mode 100644
index 0000000..9a91ac8
--- /dev/null
+++ b/src/container_attributes.rs
@@ -0,0 +1,72 @@
+use crate::ARBITRARY_ATTRIBUTE_NAME;
+use syn::{
+ parse::Error, punctuated::Punctuated, DeriveInput, Lit, Meta, MetaNameValue, NestedMeta, Token,
+ TypeParam,
+};
+
+pub struct ContainerAttributes {
+ /// Specify type bounds to be applied to the derived `Arbitrary` implementation instead of the
+ /// default inferred bounds.
+ ///
+ /// ```ignore
+ /// #[arbitrary(bound = "T: Default, U: Debug")]
+ /// ```
+ ///
+ /// Multiple attributes will be combined as long as they don't conflict, e.g.
+ ///
+ /// ```ignore
+ /// #[arbitrary(bound = "T: Default")]
+ /// #[arbitrary(bound = "U: Default")]
+ /// ```
+ pub bounds: Option<Vec<Punctuated<TypeParam, Token![,]>>>,
+}
+
+impl ContainerAttributes {
+ pub fn from_derive_input(derive_input: &DeriveInput) -> Result<Self, Error> {
+ let mut bounds = None;
+
+ for attr in &derive_input.attrs {
+ if !attr.path.is_ident(ARBITRARY_ATTRIBUTE_NAME) {
+ continue;
+ }
+
+ let meta_list = match attr.parse_meta()? {
+ Meta::List(l) => l,
+ _ => {
+ return Err(Error::new_spanned(
+ attr,
+ format!(
+ "invalid `{}` attribute. expected list",
+ ARBITRARY_ATTRIBUTE_NAME
+ ),
+ ))
+ }
+ };
+
+ for nested_meta in meta_list.nested.iter() {
+ match nested_meta {
+ NestedMeta::Meta(Meta::NameValue(MetaNameValue {
+ path,
+ lit: Lit::Str(bound_str_lit),
+ ..
+ })) if path.is_ident("bound") => {
+ bounds
+ .get_or_insert_with(Vec::new)
+ .push(bound_str_lit.parse_with(Punctuated::parse_terminated)?);
+ }
+ _ => {
+ return Err(Error::new_spanned(
+ attr,
+ format!(
+ "invalid `{}` attribute. expected `bound = \"..\"`",
+ ARBITRARY_ATTRIBUTE_NAME,
+ ),
+ ))
+ }
+ }
+ }
+ }
+
+ Ok(Self { bounds })
+ }
+}
diff --git a/src/field_attributes.rs b/src/field_attributes.rs
new file mode 100644
index 0000000..2ca0f1c
--- /dev/null
+++ b/src/field_attributes.rs
@@ -0,0 +1,115 @@
+use crate::ARBITRARY_ATTRIBUTE_NAME;
+use proc_macro2::{Group, Span, TokenStream, TokenTree};
+use quote::quote;
+use syn::{spanned::Spanned, *};
+
+/// Determines how a value for a field should be constructed.
+#[cfg_attr(test, derive(Debug))]
+pub enum FieldConstructor {
+ /// Assume that Arbitrary is defined for the type of this field and use it (default)
+ Arbitrary,
+
+ /// Places `Default::default()` as a field value.
+ Default,
+
+ /// Use custom function or closure to generate a value for a field.
+ With(TokenStream),
+
+ /// Set a field always to the given value.
+ Value(TokenStream),
+}
+
+pub fn determine_field_constructor(field: &Field) -> Result<FieldConstructor> {
+ let opt_attr = fetch_attr_from_field(field)?;
+ let ctor = match opt_attr {
+ Some(attr) => parse_attribute(attr)?,
+ None => FieldConstructor::Arbitrary,
+ };
+ Ok(ctor)
+}
+
+fn fetch_attr_from_field(field: &Field) -> Result<Option<&Attribute>> {
+ let found_attributes: Vec<_> = field
+ .attrs
+ .iter()
+ .filter(|a| {
+ let path = &a.path;
+ let name = quote!(#path).to_string();
+ name == ARBITRARY_ATTRIBUTE_NAME
+ })
+ .collect();
+ if found_attributes.len() > 1 {
+ let name = field.ident.as_ref().unwrap();
+ let msg = format!(
+ "Multiple conflicting #[{ARBITRARY_ATTRIBUTE_NAME}] attributes found on field `{name}`"
+ );
+ return Err(syn::Error::new(field.span(), msg));
+ }
+ Ok(found_attributes.into_iter().next())
+}
+
+fn parse_attribute(attr: &Attribute) -> Result<FieldConstructor> {
+ let group = {
+ let mut tokens_iter = attr.clone().tokens.into_iter();
+ let token = tokens_iter.next().ok_or_else(|| {
+ let msg = format!("#[{ARBITRARY_ATTRIBUTE_NAME}] cannot be empty.");
+ syn::Error::new(attr.span(), msg)
+ })?;
+ match token {
+ TokenTree::Group(g) => g,
+ t => {
+ let msg = format!("#[{ARBITRARY_ATTRIBUTE_NAME}] must contain a group, got: {t})");
+ return Err(syn::Error::new(attr.span(), msg));
+ }
+ }
+ };
+ parse_attribute_internals(group)
+}
+
+fn parse_attribute_internals(group: Group) -> Result<FieldConstructor> {
+ let stream = group.stream();
+ let mut tokens_iter = stream.into_iter();
+ let token = tokens_iter.next().ok_or_else(|| {
+ let msg = format!("#[{ARBITRARY_ATTRIBUTE_NAME}] cannot be empty.");
+ syn::Error::new(group.span(), msg)
+ })?;
+ match token.to_string().as_ref() {
+ "default" => Ok(FieldConstructor::Default),
+ "with" => {
+ let func_path = parse_assigned_value("with", tokens_iter, group.span())?;
+ Ok(FieldConstructor::With(func_path))
+ }
+ "value" => {
+ let value = parse_assigned_value("value", tokens_iter, group.span())?;
+ Ok(FieldConstructor::Value(value))
+ }
+ _ => {
+ let msg = format!("Unknown option for #[{ARBITRARY_ATTRIBUTE_NAME}]: `{token}`");
+ Err(syn::Error::new(token.span(), msg))
+ }
+ }
+}
+
+// Input:
+// = 2 + 2
+// Output:
+// 2 + 2
+fn parse_assigned_value(
+ opt_name: &str,
+ mut tokens_iter: impl Iterator<Item = TokenTree>,
+ default_span: Span,
+) -> Result<TokenStream> {
+ let eq_sign = tokens_iter.next().ok_or_else(|| {
+ let msg = format!(
+ "Invalid syntax for #[{ARBITRARY_ATTRIBUTE_NAME}], `{opt_name}` is missing assignment."
+ );
+ syn::Error::new(default_span, msg)
+ })?;
+
+ if eq_sign.to_string() == "=" {
+ Ok(tokens_iter.collect())
+ } else {
+ let msg = format!("Invalid syntax for #[{ARBITRARY_ATTRIBUTE_NAME}], expected `=` after `{opt_name}`, got: `{eq_sign}`");
+ Err(syn::Error::new(eq_sign.span(), msg))
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 983bd68..5e05522 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -4,19 +4,44 @@ use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::*;
+mod container_attributes;
+mod field_attributes;
+use container_attributes::ContainerAttributes;
+use field_attributes::{determine_field_constructor, FieldConstructor};
+
+static ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";
static ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";
-#[proc_macro_derive(Arbitrary)]
+#[proc_macro_derive(Arbitrary, attributes(arbitrary))]
pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = syn::parse_macro_input!(tokens as syn::DeriveInput);
+ expand_derive_arbitrary(input)
+ .unwrap_or_else(syn::Error::into_compile_error)
+ .into()
+}
+
+fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
+ let container_attrs = ContainerAttributes::from_derive_input(&input)?;
+
let (lifetime_without_bounds, lifetime_with_bounds) =
build_arbitrary_lifetime(input.generics.clone());
- let arbitrary_method = gen_arbitrary_method(&input, lifetime_without_bounds.clone());
- let size_hint_method = gen_size_hint_method(&input);
+ let recursive_count = syn::Ident::new(
+ &format!("RECURSIVE_COUNT_{}", input.ident),
+ Span::call_site(),
+ );
+
+ let arbitrary_method =
+ gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?;
+ let size_hint_method = gen_size_hint_method(&input)?;
let name = input.ident;
- // Add a bound `T: Arbitrary` to every type parameter T.
- let generics = add_trait_bounds(input.generics, lifetime_without_bounds.clone());
+
+ // Apply user-supplied bounds or automatic `T: ArbitraryBounds`.
+ let generics = apply_trait_bounds(
+ input.generics,
+ lifetime_without_bounds.clone(),
+ &container_attrs,
+ )?;
// Build ImplGeneric with a lifetime (https://github.com/dtolnay/syn/issues/90)
let mut generics_with_lifetime = generics.clone();
@@ -28,13 +53,20 @@ pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStr
// Build TypeGenerics and WhereClause without a lifetime
let (_, ty_generics, where_clause) = generics.split_for_impl();
- (quote! {
- impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause {
- #arbitrary_method
- #size_hint_method
- }
+ Ok(quote! {
+ const _: () = {
+ std::thread_local! {
+ #[allow(non_upper_case_globals)]
+ static #recursive_count: std::cell::Cell<u32> = std::cell::Cell::new(0);
+ }
+
+ #[automatically_derived]
+ impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause {
+ #arbitrary_method
+ #size_hint_method
+ }
+ };
})
- .into()
}
// Returns: (lifetime without bounds, lifetime with bounds)
@@ -55,6 +87,51 @@ fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeDef, LifetimeDef) {
(lifetime_without_bounds, lifetime_with_bounds)
}
+fn apply_trait_bounds(
+ mut generics: Generics,
+ lifetime: LifetimeDef,
+ container_attrs: &ContainerAttributes,
+) -> Result<Generics> {
+ // If user-supplied bounds exist, apply them to their matching type parameters.
+ if let Some(config_bounds) = &container_attrs.bounds {
+ let mut config_bounds_applied = 0;
+ for param in generics.params.iter_mut() {
+ if let GenericParam::Type(type_param) = param {
+ if let Some(replacement) = config_bounds
+ .iter()
+ .flatten()
+ .find(|p| p.ident == type_param.ident)
+ {
+ *type_param = replacement.clone();
+ config_bounds_applied += 1;
+ } else {
+ // If no user-supplied bounds exist for this type, delete the original bounds.
+ // This mimics serde.
+ type_param.bounds = Default::default();
+ type_param.default = None;
+ }
+ }
+ }
+ let config_bounds_supplied = config_bounds
+ .iter()
+ .map(|bounds| bounds.len())
+ .sum::<usize>();
+ if config_bounds_applied != config_bounds_supplied {
+ return Err(Error::new(
+ Span::call_site(),
+ format!(
+ "invalid `{}` attribute. too many bounds, only {} out of {} are applicable",
+ ARBITRARY_ATTRIBUTE_NAME, config_bounds_applied, config_bounds_supplied,
+ ),
+ ));
+ }
+ Ok(generics)
+ } else {
+ // Otherwise, inject a `T: Arbitrary` bound for every parameter.
+ Ok(add_trait_bounds(generics, lifetime))
+ }
+}
+
// Add a bound `T: Arbitrary` to every type parameter T.
fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeDef) -> Generics {
for param in generics.params.iter_mut() {
@@ -67,42 +144,102 @@ fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeDef) -> Generics {
generics
}
-fn gen_arbitrary_method(input: &DeriveInput, lifetime: LifetimeDef) -> TokenStream {
- let ident = &input.ident;
- let arbitrary_structlike = |fields| {
- let arbitrary = construct(fields, |_, _| quote!(arbitrary::Arbitrary::arbitrary(u)?));
- let arbitrary_take_rest = construct_take_rest(fields);
- quote! {
+fn with_recursive_count_guard(
+ recursive_count: &syn::Ident,
+ expr: impl quote::ToTokens,
+) -> impl quote::ToTokens {
+ quote! {
+ let guard_against_recursion = u.is_empty();
+ if guard_against_recursion {
+ #recursive_count.with(|count| {
+ if count.get() > 0 {
+ return Err(arbitrary::Error::NotEnoughData);
+ }
+ count.set(count.get() + 1);
+ Ok(())
+ })?;
+ }
+
+ let result = (|| { #expr })();
+
+ if guard_against_recursion {
+ #recursive_count.with(|count| {
+ count.set(count.get() - 1);
+ });
+ }
+
+ result
+ }
+}
+
+fn gen_arbitrary_method(
+ input: &DeriveInput,
+ lifetime: LifetimeDef,
+ recursive_count: &syn::Ident,
+) -> Result<TokenStream> {
+ fn arbitrary_structlike(
+ fields: &Fields,
+ ident: &syn::Ident,
+ lifetime: LifetimeDef,
+ recursive_count: &syn::Ident,
+ ) -> Result<TokenStream> {
+ let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field))?;
+ let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) });
+
+ let arbitrary_take_rest = construct_take_rest(fields)?;
+ let take_rest_body =
+ with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) });
+
+ Ok(quote! {
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
- Ok(#ident #arbitrary)
+ #body
}
fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
- Ok(#ident #arbitrary_take_rest)
+ #take_rest_body
}
- }
- };
- match &input.data {
- Data::Struct(data) => arbitrary_structlike(&data.fields),
- Data::Union(data) => arbitrary_structlike(&Fields::Named(data.fields.clone())),
+ })
+ }
+
+ let ident = &input.ident;
+ let output = match &input.data {
+ Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count)?,
+ Data::Union(data) => arbitrary_structlike(
+ &Fields::Named(data.fields.clone()),
+ ident,
+ lifetime,
+ recursive_count,
+ )?,
Data::Enum(data) => {
- let variants = data.variants.iter().enumerate().map(|(i, variant)| {
- let idx = i as u64;
- let ctor = construct(&variant.fields, |_, _| {
- quote!(arbitrary::Arbitrary::arbitrary(u)?)
- });
- let variant_name = &variant.ident;
- quote! { #idx => #ident::#variant_name #ctor }
- });
- let variants_take_rest = data.variants.iter().enumerate().map(|(i, variant)| {
- let idx = i as u64;
- let ctor = construct_take_rest(&variant.fields);
- let variant_name = &variant.ident;
- quote! { #idx => #ident::#variant_name #ctor }
- });
+ let variants: Vec<TokenStream> = data
+ .variants
+ .iter()
+ .enumerate()
+ .map(|(i, variant)| {
+ let idx = i as u64;
+ let variant_name = &variant.ident;
+ construct(&variant.fields, |_, field| gen_constructor_for_field(field))
+ .map(|ctor| quote! { #idx => #ident::#variant_name #ctor })
+ })
+ .collect::<Result<_>>()?;
+
+ let variants_take_rest: Vec<TokenStream> = data
+ .variants
+ .iter()
+ .enumerate()
+ .map(|(i, variant)| {
+ let idx = i as u64;
+ let variant_name = &variant.ident;
+ construct_take_rest(&variant.fields)
+ .map(|ctor| quote! { #idx => #ident::#variant_name #ctor })
+ })
+ .collect::<Result<_>>()?;
+
let count = data.variants.len() as u64;
- quote! {
- fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
+
+ let arbitrary = with_recursive_count_guard(
+ recursive_count,
+ quote! {
// Use a multiply + shift to generate a ranged random number
// with slight bias. For details, see:
// https://lemire.me/blog/2016/06/30/fast-random-shuffling
@@ -110,9 +247,12 @@ fn gen_arbitrary_method(input: &DeriveInput, lifetime: LifetimeDef) -> TokenStre
#(#variants,)*
_ => unreachable!()
})
- }
+ },
+ );
- fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
+ let arbitrary_take_rest = with_recursive_count_guard(
+ recursive_count,
+ quote! {
// Use a multiply + shift to generate a ranged random number
// with slight bias. For details, see:
// https://lemire.me/blog/2016/06/30/fast-random-shuffling
@@ -120,77 +260,144 @@ fn gen_arbitrary_method(input: &DeriveInput, lifetime: LifetimeDef) -> TokenStre
#(#variants_take_rest,)*
_ => unreachable!()
})
+ },
+ );
+
+ quote! {
+ fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
+ #arbitrary
+ }
+
+ fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
+ #arbitrary_take_rest
}
}
}
- }
+ };
+ Ok(output)
}
-fn construct(fields: &Fields, ctor: impl Fn(usize, &Field) -> TokenStream) -> TokenStream {
- match fields {
+fn construct(
+ fields: &Fields,
+ ctor: impl Fn(usize, &Field) -> Result<TokenStream>,
+) -> Result<TokenStream> {
+ let output = match fields {
Fields::Named(names) => {
- let names = names.named.iter().enumerate().map(|(i, f)| {
- let name = f.ident.as_ref().unwrap();
- let ctor = ctor(i, f);
- quote! { #name: #ctor }
- });
+ let names: Vec<TokenStream> = names
+ .named
+ .iter()
+ .enumerate()
+ .map(|(i, f)| {
+ let name = f.ident.as_ref().unwrap();
+ ctor(i, f).map(|ctor| quote! { #name: #ctor })
+ })
+ .collect::<Result<_>>()?;
quote! { { #(#names,)* } }
}
Fields::Unnamed(names) => {
- let names = names.unnamed.iter().enumerate().map(|(i, f)| {
- let ctor = ctor(i, f);
- quote! { #ctor }
- });
+ let names: Vec<TokenStream> = names
+ .unnamed
+ .iter()
+ .enumerate()
+ .map(|(i, f)| ctor(i, f).map(|ctor| quote! { #ctor }))
+ .collect::<Result<_>>()?;
quote! { ( #(#names),* ) }
}
Fields::Unit => quote!(),
- }
+ };
+ Ok(output)
}
-fn construct_take_rest(fields: &Fields) -> TokenStream {
- construct(fields, |idx, _| {
- if idx + 1 == fields.len() {
- quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? }
- } else {
- quote! { arbitrary::Arbitrary::arbitrary(&mut u)? }
- }
+fn construct_take_rest(fields: &Fields) -> Result<TokenStream> {
+ construct(fields, |idx, field| {
+ determine_field_constructor(field).map(|field_constructor| match field_constructor {
+ FieldConstructor::Default => quote!(Default::default()),
+ FieldConstructor::Arbitrary => {
+ if idx + 1 == fields.len() {
+ quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? }
+ } else {
+ quote! { arbitrary::Arbitrary::arbitrary(&mut u)? }
+ }
+ }
+ FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(&mut u)?),
+ FieldConstructor::Value(value) => quote!(#value),
+ })
})
}
-fn gen_size_hint_method(input: &DeriveInput) -> TokenStream {
+fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
let size_hint_fields = |fields: &Fields| {
- let tys = fields.iter().map(|f| &f.ty);
- quote! {
- arbitrary::size_hint::and_all(&[
- #( <#tys as arbitrary::Arbitrary>::size_hint(depth) ),*
- ])
- }
+ fields
+ .iter()
+ .map(|f| {
+ let ty = &f.ty;
+ determine_field_constructor(f).map(|field_constructor| {
+ match field_constructor {
+ FieldConstructor::Default | FieldConstructor::Value(_) => {
+ quote!((0, Some(0)))
+ }
+ FieldConstructor::Arbitrary => {
+ quote! { <#ty as arbitrary::Arbitrary>::size_hint(depth) }
+ }
+
+ // Note that in this case it's hard to determine what size_hint must be, so size_of::<T>() is
+ // just an educated guess, although it's gonna be inaccurate for dynamically
+ // allocated types (Vec, HashMap, etc.).
+ FieldConstructor::With(_) => {
+ quote! { (::core::mem::size_of::<#ty>(), None) }
+ }
+ }
+ })
+ })
+ .collect::<Result<Vec<TokenStream>>>()
+ .map(|hints| {
+ quote! {
+ arbitrary::size_hint::and_all(&[
+ #( #hints ),*
+ ])
+ }
+ })
};
let size_hint_structlike = |fields: &Fields| {
- let hint = size_hint_fields(fields);
- quote! {
- #[inline]
- fn size_hint(depth: usize) -> (usize, Option<usize>) {
- arbitrary::size_hint::recursion_guard(depth, |depth| #hint)
+ size_hint_fields(fields).map(|hint| {
+ quote! {
+ #[inline]
+ fn size_hint(depth: usize) -> (usize, Option<usize>) {
+ arbitrary::size_hint::recursion_guard(depth, |depth| #hint)
+ }
}
- }
+ })
};
match &input.data {
Data::Struct(data) => size_hint_structlike(&data.fields),
Data::Union(data) => size_hint_structlike(&Fields::Named(data.fields.clone())),
- Data::Enum(data) => {
- let variants = data.variants.iter().map(|v| size_hint_fields(&v.fields));
- quote! {
- #[inline]
- fn size_hint(depth: usize) -> (usize, Option<usize>) {
- arbitrary::size_hint::and(
- <u32 as arbitrary::Arbitrary>::size_hint(depth),
- arbitrary::size_hint::recursion_guard(depth, |depth| {
- arbitrary::size_hint::or_all(&[ #( #variants ),* ])
- }),
- )
+ Data::Enum(data) => data
+ .variants
+ .iter()
+ .map(|v| size_hint_fields(&v.fields))
+ .collect::<Result<Vec<TokenStream>>>()
+ .map(|variants| {
+ quote! {
+ #[inline]
+ fn size_hint(depth: usize) -> (usize, Option<usize>) {
+ arbitrary::size_hint::and(
+ <u32 as arbitrary::Arbitrary>::size_hint(depth),
+ arbitrary::size_hint::recursion_guard(depth, |depth| {
+ arbitrary::size_hint::or_all(&[ #( #variants ),* ])
+ }),
+ )
+ }
}
- }
- }
+ }),
}
}
+
+fn gen_constructor_for_field(field: &Field) -> Result<TokenStream> {
+ let ctor = match determine_field_constructor(field)? {
+ FieldConstructor::Default => quote!(Default::default()),
+ FieldConstructor::Arbitrary => quote!(arbitrary::Arbitrary::arbitrary(u)?),
+ FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(u)?),
+ FieldConstructor::Value(value) => quote!(#value),
+ };
+ Ok(ctor)
+}