diff options
author | Hasini Gunasinghe <hasinitg@google.com> | 2022-09-08 07:29:42 +0000 |
---|---|---|
committer | Hasini Gunasinghe <hasinitg@google.com> | 2022-10-03 20:55:06 +0000 |
commit | 53d28e16ca18b4cd5f252a07bcfd61632d82ea2d (patch) | |
tree | db1c061f70cfe6d66c7c003afe14c96e6d1724d2 /src | |
parent | 39d10f0a1cdbccb33fb73fb15d253a3a82a17a2a (diff) | |
download | der_derive-53d28e16ca18b4cd5f252a07bcfd61632d82ea2d.tar.gz |
Import platform/external/rust/crates/der_derive
Bug: 239549209
Test: N/A
Change-Id: I677515ecfd59690611865098ea797ca03af8e523
Diffstat (limited to 'src')
-rw-r--r-- | src/asn1_type.rs | 121 | ||||
-rw-r--r-- | src/attributes.rs | 341 | ||||
-rw-r--r-- | src/choice.rs | 251 | ||||
-rw-r--r-- | src/choice/variant.rs | 417 | ||||
-rw-r--r-- | src/enumerated.rs | 244 | ||||
-rw-r--r-- | src/lib.rs | 284 | ||||
-rw-r--r-- | src/sequence.rs | 331 | ||||
-rw-r--r-- | src/sequence/field.rs | 358 | ||||
-rw-r--r-- | src/tag.rs | 176 | ||||
-rw-r--r-- | src/value_ord.rs | 144 |
10 files changed, 2667 insertions, 0 deletions
diff --git a/src/asn1_type.rs b/src/asn1_type.rs new file mode 100644 index 0000000..787e054 --- /dev/null +++ b/src/asn1_type.rs @@ -0,0 +1,121 @@ +//! ASN.1 types supported by the proc macro + +use proc_macro2::TokenStream; +use quote::quote; +use std::{fmt, str::FromStr}; + +/// ASN.1 built-in types supported by the `#[asn1(type = "...")]` attribute +// TODO(tarcieri): support all ASN.1 types specified in `der::Tag` +#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] +pub(crate) enum Asn1Type { + /// ASN.1 `BIT STRING`. + BitString, + + /// ASN.1 `IA5String`. + Ia5String, + + /// ASN.1 `GeneralizedTime`. + GeneralizedTime, + + /// ASN.1 `OCTET STRING`. + OctetString, + + /// ASN.1 `PrintableString`. + PrintableString, + + /// ASN.1 `UTCTime`. + UtcTime, + + /// ASN.1 `UTF8String`. + Utf8String, +} + +impl Asn1Type { + /// Get the `::der::Tag` for this ASN.1 type + pub fn tag(self) -> TokenStream { + match self { + Asn1Type::BitString => quote!(::der::Tag::BitString), + Asn1Type::Ia5String => quote!(::der::Tag::Ia5String), + Asn1Type::GeneralizedTime => quote!(::der::Tag::GeneralizedTime), + Asn1Type::OctetString => quote!(::der::Tag::OctetString), + Asn1Type::PrintableString => quote!(::der::Tag::PrintableString), + Asn1Type::UtcTime => quote!(::der::Tag::UtcTime), + Asn1Type::Utf8String => quote!(::der::Tag::Utf8String), + } + } + + /// Get a `der::Decoder` object for a particular ASN.1 type + pub fn decoder(self) -> TokenStream { + match self { + Asn1Type::BitString => quote!(::der::asn1::BitStringRef::decode(reader)?), + Asn1Type::Ia5String => quote!(::der::asn1::Ia5StringRef::decode(reader)?), + Asn1Type::GeneralizedTime => quote!(::der::asn1::GeneralizedTime::decode(reader)?), + Asn1Type::OctetString => quote!(::der::asn1::OctetStringRef::decode(reader)?), + Asn1Type::PrintableString => quote!(::der::asn1::PrintableStringRef::decode(reader)?), + Asn1Type::UtcTime => quote!(::der::asn1::UtcTime::decode(reader)?), + Asn1Type::Utf8String => quote!(::der::asn1::Utf8StringRef::decode(reader)?), + } + } + + /// Get a `der::Encoder` object for a particular ASN.1 type + pub fn encoder(self, binding: &TokenStream) -> TokenStream { + let type_path = self.type_path(); + + match self { + Asn1Type::Ia5String + | Asn1Type::OctetString + | Asn1Type::PrintableString + | Asn1Type::Utf8String => quote!(#type_path::new(#binding)?), + _ => quote!(#type_path::try_from(#binding)?), + } + } + + /// Get the Rust type path for a particular ASN.1 type. + /// Get a `der::Encoder` object for a particular ASN.1 type + pub fn type_path(self) -> TokenStream { + match self { + Asn1Type::BitString => quote!(::der::asn1::BitStringRef), + Asn1Type::Ia5String => quote!(::der::asn1::Ia5StringRef), + Asn1Type::GeneralizedTime => quote!(::der::asn1::GeneralizedTime), + Asn1Type::OctetString => quote!(::der::asn1::OctetStringRef), + Asn1Type::PrintableString => quote!(::der::asn1::PrintableStringRef), + Asn1Type::UtcTime => quote!(::der::asn1::UtcTime), + Asn1Type::Utf8String => quote!(::der::asn1::Utf8StringRef), + } + } +} + +impl FromStr for Asn1Type { + type Err = ParseError; + + fn from_str(s: &str) -> Result<Self, ParseError> { + match s { + "BIT STRING" => Ok(Self::BitString), + "IA5String" => Ok(Self::Ia5String), + "GeneralizedTime" => Ok(Self::GeneralizedTime), + "OCTET STRING" => Ok(Self::OctetString), + "PrintableString" => Ok(Self::PrintableString), + "UTCTime" => Ok(Self::UtcTime), + "UTF8String" => Ok(Self::Utf8String), + _ => Err(ParseError), + } + } +} + +impl fmt::Display for Asn1Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Asn1Type::BitString => "BIT STRING", + Asn1Type::Ia5String => "IA5String", + Asn1Type::GeneralizedTime => "GeneralizedTime", + Asn1Type::OctetString => "OCTET STRING", + Asn1Type::PrintableString => "PrintableString", + Asn1Type::UtcTime => "UTCTime", + Asn1Type::Utf8String => "UTF8String", + }) + } +} + +/// Error type +#[derive(Debug)] +pub(crate) struct ParseError; diff --git a/src/attributes.rs b/src/attributes.rs new file mode 100644 index 0000000..c765e19 --- /dev/null +++ b/src/attributes.rs @@ -0,0 +1,341 @@ +//! Attribute-related types used by the proc macro + +use crate::{Asn1Type, Tag, TagMode, TagNumber}; +use proc_macro2::TokenStream; +use proc_macro_error::{abort, abort_call_site}; +use quote::quote; +use std::{fmt::Debug, str::FromStr}; +use syn::{Attribute, Lit, LitStr, Meta, MetaList, MetaNameValue, NestedMeta, Path}; + +/// Attribute name. +pub(crate) const ATTR_NAME: &str = "asn1"; + +/// Parsing error message. +const PARSE_ERR_MSG: &str = "error parsing `asn1` attribute"; + +/// Attributes on a `struct` or `enum` type. +#[derive(Clone, Debug, Default)] +pub(crate) struct TypeAttrs { + /// Tagging mode for this type: `EXPLICIT` or `IMPLICIT`, supplied as + /// `#[asn1(tag_mode = "...")]`. + /// + /// The default value is `EXPLICIT`. + pub tag_mode: TagMode, +} + +impl TypeAttrs { + /// Parse attributes from a struct field or enum variant. + pub fn parse(attrs: &[Attribute]) -> Self { + let mut tag_mode = None; + + let mut parsed_attrs = Vec::new(); + AttrNameValue::from_attributes(attrs, &mut parsed_attrs); + + for attr in parsed_attrs { + // `tag_mode = "..."` attribute + if let Some(mode) = attr.parse_value("tag_mode") { + if tag_mode.is_some() { + abort!(attr.name, "duplicate ASN.1 `tag_mode` attribute"); + } + + tag_mode = Some(mode); + } else { + abort!( + attr.name, + "invalid `asn1` attribute (valid options are `tag_mode`)", + ); + } + } + + Self { + tag_mode: tag_mode.unwrap_or_default(), + } + } +} + +/// Field-level attributes. +#[derive(Clone, Debug, Default)] +pub(crate) struct FieldAttrs { + /// Value of the `#[asn1(type = "...")]` attribute if provided. + pub asn1_type: Option<Asn1Type>, + + /// Value of the `#[asn1(context_specific = "...")] attribute if provided. + pub context_specific: Option<TagNumber>, + + /// Indicates name of function that supplies the default value, which will be used in cases + /// where encoding is omitted per DER and to omit the encoding per DER + pub default: Option<Path>, + + /// Is this field "extensible", i.e. preceded by the `...` extensibility marker? + pub extensible: bool, + + /// Is this field `OPTIONAL`? + pub optional: bool, + + /// Tagging mode for this type: `EXPLICIT` or `IMPLICIT`, supplied as + /// `#[asn1(tag_mode = "...")]`. + /// + /// Inherits from the type-level tagging mode if specified, or otherwise + /// defaults to `EXPLICIT`. + pub tag_mode: TagMode, + + /// Is the inner type constructed? + pub constructed: bool, +} + +impl FieldAttrs { + /// Return true when either an optional or default ASN.1 attribute is associated + /// with a field. Default signifies optionality due to omission of default values in + /// DER encodings. + fn is_optional(&self) -> bool { + self.optional || self.default.is_some() + } + + /// Parse attributes from a struct field or enum variant. + pub fn parse(attrs: &[Attribute], type_attrs: &TypeAttrs) -> Self { + let mut asn1_type = None; + let mut context_specific = None; + let mut default = None; + let mut extensible = None; + let mut optional = None; + let mut tag_mode = None; + let mut constructed = None; + + let mut parsed_attrs = Vec::new(); + AttrNameValue::from_attributes(attrs, &mut parsed_attrs); + + for attr in parsed_attrs { + // `context_specific = "..."` attribute + if let Some(tag_number) = attr.parse_value("context_specific") { + if context_specific.is_some() { + abort!(attr.name, "duplicate ASN.1 `context_specific` attribute"); + } + + context_specific = Some(tag_number); + // `default` attribute + } else if attr.parse_value::<String>("default").is_some() { + if default.is_some() { + abort!(attr.name, "duplicate ASN.1 `default` attribute"); + } + + default = Some(attr.value.parse().unwrap_or_else(|e| { + abort!(attr.value, "error parsing ASN.1 `default` attribute: {}", e) + })); + // `extensible` attribute + } else if let Some(ext) = attr.parse_value("extensible") { + if extensible.is_some() { + abort!(attr.name, "duplicate ASN.1 `extensible` attribute"); + } + + extensible = Some(ext); + // `optional` attribute + } else if let Some(opt) = attr.parse_value("optional") { + if optional.is_some() { + abort!(attr.name, "duplicate ASN.1 `optional` attribute"); + } + + optional = Some(opt); + // `tag_mode` attribute + } else if let Some(mode) = attr.parse_value("tag_mode") { + if tag_mode.is_some() { + abort!(attr.name, "duplicate ASN.1 `tag_mode` attribute"); + } + + tag_mode = Some(mode); + // `type = "..."` attribute + } else if let Some(ty) = attr.parse_value("type") { + if asn1_type.is_some() { + abort!(attr.name, "duplicate ASN.1 `type` attribute: {}"); + } + + asn1_type = Some(ty); + // `constructed = "..."` attribute + } else if let Some(ty) = attr.parse_value("constructed") { + if constructed.is_some() { + abort!(attr.name, "duplicate ASN.1 `constructed` attribute: {}"); + } + + constructed = Some(ty); + } else { + abort!( + attr.name, + "unknown field-level `asn1` attribute \ + (valid options are `context_specific`, `type`)", + ); + } + } + + Self { + asn1_type, + context_specific, + default, + extensible: extensible.unwrap_or_default(), + optional: optional.unwrap_or_default(), + tag_mode: tag_mode.unwrap_or(type_attrs.tag_mode), + constructed: constructed.unwrap_or_default(), + } + } + + /// Get the expected [`Tag`] for this field. + pub fn tag(&self) -> Option<Tag> { + match self.context_specific { + Some(tag_number) => Some(Tag::ContextSpecific { + constructed: self.constructed, + number: tag_number, + }), + + None => match self.tag_mode { + TagMode::Explicit => self.asn1_type.map(Tag::Universal), + TagMode::Implicit => abort_call_site!("implicit tagging requires a `tag_number`"), + }, + } + } + + /// Get a `der::Decoder` object which respects these field attributes. + pub fn decoder(&self) -> TokenStream { + if let Some(tag_number) = self.context_specific { + let type_params = self.asn1_type.map(|ty| ty.type_path()).unwrap_or_default(); + let tag_number = tag_number.to_tokens(); + + let context_specific = match self.tag_mode { + TagMode::Explicit => { + if self.extensible || self.is_optional() { + quote! { + ::der::asn1::ContextSpecific::<#type_params>::decode_explicit( + reader, + #tag_number + )? + } + } else { + quote! { + match ::der::asn1::ContextSpecific::<#type_params>::decode(reader)? { + field if field.tag_number == #tag_number => Some(field), + _ => None + } + } + } + } + TagMode::Implicit => { + quote! { + ::der::asn1::ContextSpecific::<#type_params>::decode_implicit( + reader, + #tag_number + )? + } + } + }; + + if self.is_optional() { + if let Some(default) = &self.default { + quote!(#context_specific.map(|cs| cs.value).unwrap_or_else(#default)) + } else { + quote!(#context_specific.map(|cs| cs.value)) + } + } else { + // TODO(tarcieri): better error handling? + let constructed = self.constructed; + quote! { + #context_specific.ok_or_else(|| { + der::Tag::ContextSpecific { + number: #tag_number, + constructed: #constructed + }.value_error() + })?.value + } + } + } else if let Some(default) = &self.default { + let type_params = self.asn1_type.map(|ty| ty.type_path()).unwrap_or_default(); + self.asn1_type.map(|ty| ty.decoder()).unwrap_or_else(|| { + quote! { + Option::<#type_params>::decode(reader)?.unwrap_or_else(#default), + } + }) + } else { + self.asn1_type + .map(|ty| ty.decoder()) + .unwrap_or_else(|| quote!(reader.decode()?)) + } + } + + /// Get tokens to encode the binding using `::der::EncodeValue`. + pub fn value_encode(&self, binding: &TokenStream) -> TokenStream { + match self.context_specific { + Some(tag_number) => { + let tag_number = tag_number.to_tokens(); + let tag_mode = self.tag_mode.to_tokens(); + quote! { + ::der::asn1::ContextSpecificRef { + tag_number: #tag_number, + tag_mode: #tag_mode, + value: #binding, + }.encode_value(encoder) + } + } + + None => self + .asn1_type + .map(|ty| { + let encoder_obj = ty.encoder(binding); + quote!(#encoder_obj.encode_value(encoder)) + }) + .unwrap_or_else(|| quote!(#binding.encode_value(encoder))), + } + } +} + +/// Name/value pair attribute. +struct AttrNameValue { + /// Attribute name. + pub name: Path, + + /// Attribute value. + pub value: LitStr, +} + +impl AttrNameValue { + /// Parse a slice of attributes. + pub fn from_attributes(attrs: &[Attribute], out: &mut Vec<Self>) { + for attr in attrs { + if !attr.path.is_ident(ATTR_NAME) { + continue; + } + + let nested = match attr.parse_meta().expect(PARSE_ERR_MSG) { + Meta::List(MetaList { nested, .. }) => nested, + other => abort!(other, "malformed `asn1` attribute"), + }; + + for meta in &nested { + match meta { + NestedMeta::Meta(Meta::NameValue(MetaNameValue { + path, + lit: Lit::Str(lit_str), + .. + })) => out.push(Self { + name: path.clone(), + value: lit_str.clone(), + }), + _ => abort!(nested, "malformed `asn1` attribute"), + } + } + } + } + + /// Parse an attribute value if the name matches the specified one. + pub fn parse_value<T>(&self, name: &str) -> Option<T> + where + T: FromStr + Debug, + T::Err: Debug, + { + if self.name.is_ident(name) { + Some( + self.value + .value() + .parse() + .unwrap_or_else(|_| abort!(self.name, "error parsing attribute")), + ) + } else { + None + } + } +} diff --git a/src/choice.rs b/src/choice.rs new file mode 100644 index 0000000..ac0dc37 --- /dev/null +++ b/src/choice.rs @@ -0,0 +1,251 @@ +//! Support for deriving the `Decode` and `Encode` traits on enums for +//! the purposes of decoding/encoding ASN.1 `CHOICE` types as mapped to +//! enum variants. + +mod variant; + +use self::variant::ChoiceVariant; +use crate::{default_lifetime, TypeAttrs}; +use proc_macro2::TokenStream; +use proc_macro_error::abort; +use quote::quote; +use syn::{DeriveInput, Ident, Lifetime}; + +/// Derive the `Choice` trait for an enum. +pub(crate) struct DeriveChoice { + /// Name of the enum type. + ident: Ident, + + /// Lifetime of the type. + lifetime: Option<Lifetime>, + + /// Variants of this `Choice`. + variants: Vec<ChoiceVariant>, +} + +impl DeriveChoice { + /// Parse [`DeriveInput`]. + pub fn new(input: DeriveInput) -> Self { + let data = match input.data { + syn::Data::Enum(data) => data, + _ => abort!( + input.ident, + "can't derive `Choice` on this type: only `enum` types are allowed", + ), + }; + + // TODO(tarcieri): properly handle multiple lifetimes + let lifetime = input + .generics + .lifetimes() + .next() + .map(|lt| lt.lifetime.clone()); + + let type_attrs = TypeAttrs::parse(&input.attrs); + let variants = data + .variants + .iter() + .map(|variant| ChoiceVariant::new(variant, &type_attrs)) + .collect(); + + Self { + ident: input.ident, + lifetime, + variants, + } + } + + /// Lower the derived output into a [`TokenStream`]. + pub fn to_tokens(&self) -> TokenStream { + let ident = &self.ident; + + let lifetime = match self.lifetime { + Some(ref lifetime) => quote!(#lifetime), + None => default_lifetime(), + }; + + // Lifetime parameters + // TODO(tarcieri): support multiple lifetimes + let lt_params = self + .lifetime + .as_ref() + .map(|_| lifetime.clone()) + .unwrap_or_default(); + + let mut can_decode_body = Vec::new(); + let mut decode_body = Vec::new(); + let mut encode_body = Vec::new(); + let mut value_len_body = Vec::new(); + let mut tagged_body = Vec::new(); + + for variant in &self.variants { + can_decode_body.push(variant.tag.to_tokens()); + decode_body.push(variant.to_decode_tokens()); + encode_body.push(variant.to_encode_value_tokens()); + value_len_body.push(variant.to_value_len_tokens()); + tagged_body.push(variant.to_tagged_tokens()); + } + + quote! { + impl<#lifetime> ::der::Choice<#lifetime> for #ident<#lt_params> { + fn can_decode(tag: ::der::Tag) -> bool { + matches!(tag, #(#can_decode_body)|*) + } + } + + impl<#lifetime> ::der::Decode<#lifetime> for #ident<#lt_params> { + fn decode<R: ::der::Reader<#lifetime>>(reader: &mut R) -> ::der::Result<Self> { + use der::Reader as _; + match reader.peek_tag()? { + #(#decode_body)* + actual => Err(der::ErrorKind::TagUnexpected { + expected: None, + actual + } + .into()), + } + } + } + + impl<#lt_params> ::der::EncodeValue for #ident<#lt_params> { + fn encode_value(&self, encoder: &mut dyn ::der::Writer) -> ::der::Result<()> { + match self { + #(#encode_body)* + } + } + + fn value_len(&self) -> ::der::Result<::der::Length> { + match self { + #(#value_len_body)* + } + } + } + + impl<#lt_params> ::der::Tagged for #ident<#lt_params> { + fn tag(&self) -> ::der::Tag { + match self { + #(#tagged_body)* + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::DeriveChoice; + use crate::{Asn1Type, Tag, TagMode}; + use syn::parse_quote; + + /// Based on `Time` as defined in RFC 5280: + /// <https://tools.ietf.org/html/rfc5280#page-117> + /// + /// ```text + /// Time ::= CHOICE { + /// utcTime UTCTime, + /// generalTime GeneralizedTime } + /// ``` + #[test] + fn time_example() { + let input = parse_quote! { + pub enum Time { + #[asn1(type = "UTCTime")] + UtcTime(UtcTime), + + #[asn1(type = "GeneralizedTime")] + GeneralTime(GeneralizedTime), + } + }; + + let ir = DeriveChoice::new(input); + assert_eq!(ir.ident, "Time"); + assert_eq!(ir.lifetime, None); + assert_eq!(ir.variants.len(), 2); + + let utc_time = &ir.variants[0]; + assert_eq!(utc_time.ident, "UtcTime"); + assert_eq!(utc_time.attrs.asn1_type, Some(Asn1Type::UtcTime)); + assert_eq!(utc_time.attrs.context_specific, None); + assert_eq!(utc_time.attrs.tag_mode, TagMode::Explicit); + assert_eq!(utc_time.tag, Tag::Universal(Asn1Type::UtcTime)); + + let general_time = &ir.variants[1]; + assert_eq!(general_time.ident, "GeneralTime"); + assert_eq!( + general_time.attrs.asn1_type, + Some(Asn1Type::GeneralizedTime) + ); + assert_eq!(general_time.attrs.context_specific, None); + assert_eq!(general_time.attrs.tag_mode, TagMode::Explicit); + assert_eq!(general_time.tag, Tag::Universal(Asn1Type::GeneralizedTime)); + } + + /// `IMPLICIT` tagged example + #[test] + fn implicit_example() { + let input = parse_quote! { + #[asn1(tag_mode = "IMPLICIT")] + pub enum ImplicitChoice<'a> { + #[asn1(context_specific = "0", type = "BIT STRING")] + BitString(BitString<'a>), + + #[asn1(context_specific = "1", type = "GeneralizedTime")] + Time(GeneralizedTime), + + #[asn1(context_specific = "2", type = "UTF8String")] + Utf8String(String), + } + }; + + let ir = DeriveChoice::new(input); + assert_eq!(ir.ident, "ImplicitChoice"); + assert_eq!(ir.lifetime.unwrap().to_string(), "'a"); + assert_eq!(ir.variants.len(), 3); + + let bit_string = &ir.variants[0]; + assert_eq!(bit_string.ident, "BitString"); + assert_eq!(bit_string.attrs.asn1_type, Some(Asn1Type::BitString)); + assert_eq!( + bit_string.attrs.context_specific, + Some("0".parse().unwrap()) + ); + assert_eq!(bit_string.attrs.tag_mode, TagMode::Implicit); + assert_eq!( + bit_string.tag, + Tag::ContextSpecific { + constructed: false, + number: "0".parse().unwrap() + } + ); + + let time = &ir.variants[1]; + assert_eq!(time.ident, "Time"); + assert_eq!(time.attrs.asn1_type, Some(Asn1Type::GeneralizedTime)); + assert_eq!(time.attrs.context_specific, Some("1".parse().unwrap())); + assert_eq!(time.attrs.tag_mode, TagMode::Implicit); + assert_eq!( + time.tag, + Tag::ContextSpecific { + constructed: false, + number: "1".parse().unwrap() + } + ); + + let utf8_string = &ir.variants[2]; + assert_eq!(utf8_string.ident, "Utf8String"); + assert_eq!(utf8_string.attrs.asn1_type, Some(Asn1Type::Utf8String)); + assert_eq!( + utf8_string.attrs.context_specific, + Some("2".parse().unwrap()) + ); + assert_eq!(utf8_string.attrs.tag_mode, TagMode::Implicit); + assert_eq!( + utf8_string.tag, + Tag::ContextSpecific { + constructed: false, + number: "2".parse().unwrap() + } + ); + } +} diff --git a/src/choice/variant.rs b/src/choice/variant.rs new file mode 100644 index 0000000..d74be61 --- /dev/null +++ b/src/choice/variant.rs @@ -0,0 +1,417 @@ +//! Choice variant IR and lowerings + +use crate::{FieldAttrs, Tag, TypeAttrs}; +use proc_macro2::TokenStream; +use proc_macro_error::abort; +use quote::quote; +use syn::{Fields, Ident, Path, Type, Variant}; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub(super) enum TagOrPath { + Tag(Tag), + Path(Path), +} + +impl PartialEq<Tag> for TagOrPath { + fn eq(&self, rhs: &Tag) -> bool { + match self { + Self::Tag(lhs) => lhs == rhs, + _ => false, + } + } +} + +impl From<Tag> for TagOrPath { + fn from(tag: Tag) -> Self { + Self::Tag(tag) + } +} + +impl From<Path> for TagOrPath { + fn from(path: Path) -> Self { + Self::Path(path) + } +} + +impl From<&Variant> for TagOrPath { + fn from(input: &Variant) -> Self { + if let Fields::Unnamed(fields) = &input.fields { + if fields.unnamed.len() == 1 { + if let Type::Path(path) = &fields.unnamed[0].ty { + return path.path.clone().into(); + } + } + } + + abort!( + &input.ident, + "no #[asn1(type=...)] specified for enum variant" + ) + } +} + +impl TagOrPath { + pub fn to_tokens(&self) -> TokenStream { + match self { + Self::Tag(tag) => tag.to_tokens(), + Self::Path(path) => quote! { <#path as ::der::FixedTag>::TAG }, + } + } +} + +/// "IR" for a variant of a derived `Choice`. +pub(super) struct ChoiceVariant { + /// Variant name. + pub(super) ident: Ident, + + /// "Field" (in this case variant)-level attributes. + pub(super) attrs: FieldAttrs, + + /// Tag for the ASN.1 type. + pub(super) tag: TagOrPath, +} + +impl ChoiceVariant { + /// Create a new [`ChoiceVariant`] from the input [`Variant`]. + pub(super) fn new(input: &Variant, type_attrs: &TypeAttrs) -> Self { + let ident = input.ident.clone(); + let attrs = FieldAttrs::parse(&input.attrs, type_attrs); + + if attrs.extensible { + abort!(&ident, "`extensible` is not allowed on CHOICE"); + } + + // Validate that variant is a 1-element tuple struct + match &input.fields { + // TODO(tarcieri): handle 0 bindings for ASN.1 NULL + Fields::Unnamed(fields) if fields.unnamed.len() == 1 => (), + _ => abort!(&ident, "enum variant must be a 1-element tuple struct"), + } + + let tag = attrs + .tag() + .map(TagOrPath::from) + .unwrap_or_else(|| TagOrPath::from(input)); + + Self { ident, attrs, tag } + } + + /// Derive a match arm of the impl body for `TryFrom<der::asn1::Any<'_>>`. + pub(super) fn to_decode_tokens(&self) -> TokenStream { + let tag = self.tag.to_tokens(); + let ident = &self.ident; + let decoder = self.attrs.decoder(); + + match self.attrs.asn1_type { + Some(..) => quote! { #tag => Ok(Self::#ident(#decoder.try_into()?)), }, + None => quote! { #tag => Ok(Self::#ident(#decoder)), }, + } + } + + /// Derive a match arm for the impl body for `der::EncodeValue::encode_value`. + pub(super) fn to_encode_value_tokens(&self) -> TokenStream { + let ident = &self.ident; + let binding = quote!(variant); + let encoder = self.attrs.value_encode(&binding); + quote! { + Self::#ident(#binding) => #encoder, + } + } + + /// Derive a match arm for the impl body for `der::EncodeValue::value_len`. + pub(super) fn to_value_len_tokens(&self) -> TokenStream { + let ident = &self.ident; + + match self.attrs.context_specific { + Some(tag_number) => { + let tag_number = tag_number.to_tokens(); + let tag_mode = self.attrs.tag_mode.to_tokens(); + + quote! { + Self::#ident(variant) => ::der::asn1::ContextSpecificRef { + tag_number: #tag_number, + tag_mode: #tag_mode, + value: variant, + }.value_len(), + } + } + + _ => quote! { Self::#ident(variant) => variant.value_len(), }, + } + } + + /// Derive a match arm for the impl body for `der::Tagged::tag`. + pub(super) fn to_tagged_tokens(&self) -> TokenStream { + let ident = &self.ident; + let tag = self.tag.to_tokens(); + quote! { + Self::#ident(_) => #tag, + } + } +} + +#[cfg(test)] +mod tests { + use super::ChoiceVariant; + use crate::{choice::variant::TagOrPath, Asn1Type, FieldAttrs, Tag, TagMode, TagNumber}; + use proc_macro2::Span; + use quote::quote; + use syn::Ident; + + #[test] + fn simple() { + let ident = Ident::new("ExampleVariant", Span::call_site()); + let attrs = FieldAttrs::default(); + let tag = Tag::Universal(Asn1Type::Utf8String).into(); + let variant = ChoiceVariant { ident, attrs, tag }; + + assert_eq!( + variant.to_decode_tokens().to_string(), + quote! { + ::der::Tag::Utf8String => Ok(Self::ExampleVariant( + reader.decode()? + )), + } + .to_string() + ); + + assert_eq!( + variant.to_encode_value_tokens().to_string(), + quote! { + Self::ExampleVariant(variant) => variant.encode_value(encoder), + } + .to_string() + ); + + assert_eq!( + variant.to_value_len_tokens().to_string(), + quote! { + Self::ExampleVariant(variant) => variant.value_len(), + } + .to_string() + ); + + assert_eq!( + variant.to_tagged_tokens().to_string(), + quote! { + Self::ExampleVariant(_) => ::der::Tag::Utf8String, + } + .to_string() + ) + } + + #[test] + fn utf8string() { + let ident = Ident::new("ExampleVariant", Span::call_site()); + let attrs = FieldAttrs { + asn1_type: Some(Asn1Type::Utf8String), + ..Default::default() + }; + let tag = Tag::Universal(Asn1Type::Utf8String).into(); + let variant = ChoiceVariant { ident, attrs, tag }; + + assert_eq!( + variant.to_decode_tokens().to_string(), + quote! { + ::der::Tag::Utf8String => Ok(Self::ExampleVariant( + ::der::asn1::Utf8StringRef::decode(reader)? + .try_into()? + )), + } + .to_string() + ); + + assert_eq!( + variant.to_encode_value_tokens().to_string(), + quote! { + Self::ExampleVariant(variant) => ::der::asn1::Utf8StringRef::new(variant)?.encode_value(encoder), + } + .to_string() + ); + + assert_eq!( + variant.to_value_len_tokens().to_string(), + quote! { + Self::ExampleVariant(variant) => variant.value_len(), + } + .to_string() + ); + + assert_eq!( + variant.to_tagged_tokens().to_string(), + quote! { + Self::ExampleVariant(_) => ::der::Tag::Utf8String, + } + .to_string() + ) + } + + #[test] + fn explicit() { + for tag_number in [0, 1, 2, 3] { + for constructed in [false, true] { + let ident = Ident::new("ExplicitVariant", Span::call_site()); + let attrs = FieldAttrs { + constructed, + context_specific: Some(TagNumber(tag_number)), + ..Default::default() + }; + assert_eq!(attrs.tag_mode, TagMode::Explicit); + + let tag = TagOrPath::Tag(Tag::ContextSpecific { + constructed, + number: TagNumber(tag_number), + }); + + let variant = ChoiceVariant { ident, attrs, tag }; + let tag_number = TagNumber(tag_number).to_tokens(); + + assert_eq!( + variant.to_decode_tokens().to_string(), + quote! { + ::der::Tag::ContextSpecific { + constructed: #constructed, + number: #tag_number, + } => Ok(Self::ExplicitVariant( + match ::der::asn1::ContextSpecific::<>::decode(reader)? { + field if field.tag_number == #tag_number => Some(field), + _ => None + } + .ok_or_else(|| { + der::Tag::ContextSpecific { + number: #tag_number, + constructed: #constructed + } + .value_error() + })? + .value + )), + } + .to_string() + ); + + assert_eq!( + variant.to_encode_value_tokens().to_string(), + quote! { + Self::ExplicitVariant(variant) => ::der::asn1::ContextSpecificRef { + tag_number: #tag_number, + tag_mode: ::der::TagMode::Explicit, + value: variant, + } + .encode_value(encoder), + } + .to_string() + ); + + assert_eq!( + variant.to_value_len_tokens().to_string(), + quote! { + Self::ExplicitVariant(variant) => ::der::asn1::ContextSpecificRef { + tag_number: #tag_number, + tag_mode: ::der::TagMode::Explicit, + value: variant, + } + .value_len(), + } + .to_string() + ); + + assert_eq!( + variant.to_tagged_tokens().to_string(), + quote! { + Self::ExplicitVariant(_) => ::der::Tag::ContextSpecific { + constructed: #constructed, + number: #tag_number, + }, + } + .to_string() + ) + } + } + } + + #[test] + fn implicit() { + for tag_number in [0, 1, 2, 3] { + for constructed in [false, true] { + let ident = Ident::new("ImplicitVariant", Span::call_site()); + + let attrs = FieldAttrs { + constructed, + context_specific: Some(TagNumber(tag_number)), + tag_mode: TagMode::Implicit, + ..Default::default() + }; + + let tag = TagOrPath::Tag(Tag::ContextSpecific { + constructed, + number: TagNumber(tag_number), + }); + + let variant = ChoiceVariant { ident, attrs, tag }; + let tag_number = TagNumber(tag_number).to_tokens(); + + assert_eq!( + variant.to_decode_tokens().to_string(), + quote! { + ::der::Tag::ContextSpecific { + constructed: #constructed, + number: #tag_number, + } => Ok(Self::ImplicitVariant( + ::der::asn1::ContextSpecific::<>::decode_implicit( + reader, + #tag_number + )? + .ok_or_else(|| { + der::Tag::ContextSpecific { + number: #tag_number, + constructed: #constructed + } + .value_error() + })? + .value + )), + } + .to_string() + ); + + assert_eq!( + variant.to_encode_value_tokens().to_string(), + quote! { + Self::ImplicitVariant(variant) => ::der::asn1::ContextSpecificRef { + tag_number: #tag_number, + tag_mode: ::der::TagMode::Implicit, + value: variant, + } + .encode_value(encoder), + } + .to_string() + ); + + assert_eq!( + variant.to_value_len_tokens().to_string(), + quote! { + Self::ImplicitVariant(variant) => ::der::asn1::ContextSpecificRef { + tag_number: #tag_number, + tag_mode: ::der::TagMode::Implicit, + value: variant, + } + .value_len(), + } + .to_string() + ); + + assert_eq!( + variant.to_tagged_tokens().to_string(), + quote! { + Self::ImplicitVariant(_) => ::der::Tag::ContextSpecific { + constructed: #constructed, + number: #tag_number, + }, + } + .to_string() + ) + } + } + } +} diff --git a/src/enumerated.rs b/src/enumerated.rs new file mode 100644 index 0000000..64db64a --- /dev/null +++ b/src/enumerated.rs @@ -0,0 +1,244 @@ +//! Support for deriving the `Decode` and `Encode` traits on enums for +//! the purposes of decoding/encoding ASN.1 `ENUMERATED` types as mapped to +//! enum variants. + +use crate::{default_lifetime, ATTR_NAME}; +use proc_macro2::TokenStream; +use proc_macro_error::abort; +use quote::quote; +use syn::{DeriveInput, Expr, ExprLit, Ident, Lit, LitInt, Meta, MetaList, NestedMeta, Variant}; + +/// Valid options for the `#[repr]` attribute on `Enumerated` types. +const REPR_TYPES: &[&str] = &["u8", "u16", "u32"]; + +/// Derive the `Enumerated` trait for an enum. +pub(crate) struct DeriveEnumerated { + /// Name of the enum type. + ident: Ident, + + /// Value of the `repr` attribute. + repr: Ident, + + /// Whether or not to tag the enum as an integer + integer: bool, + + /// Variants of this enum. + variants: Vec<EnumeratedVariant>, +} + +impl DeriveEnumerated { + /// Parse [`DeriveInput`]. + pub fn new(input: DeriveInput) -> Self { + let data = match input.data { + syn::Data::Enum(data) => data, + _ => abort!( + input.ident, + "can't derive `Enumerated` on this type: only `enum` types are allowed", + ), + }; + + // Reject `asn1` attributes, parse the `repr` attribute + let mut repr: Option<Ident> = None; + let mut integer = false; + + for attr in &input.attrs { + if attr.path.is_ident(ATTR_NAME) { + if let Ok(Meta::List(MetaList { nested, .. })) = attr.parse_meta() { + for meta in nested { + if let NestedMeta::Meta(Meta::NameValue(nv)) = meta { + if nv.path.is_ident("type") { + if let Lit::Str(lit) = nv.lit { + match lit.value().as_str() { + "ENUMERATED" => integer = false, + "INTEGER" => integer = true, + s => abort!(lit, "`type = \"{}\"` is unsupported", s), + } + } + } + } + } + } + } else if attr.path.is_ident("repr") { + if repr.is_some() { + abort!( + attr, + "multiple `#[repr]` attributes encountered on `Enumerated`", + ); + } + + let r = attr + .parse_args::<Ident>() + .unwrap_or_else(|_| abort!(attr, "error parsing `#[repr]` attribute")); + + // Validate + if !REPR_TYPES.contains(&r.to_string().as_str()) { + abort!( + attr, + "invalid `#[repr]` type: allowed types are {:?}", + REPR_TYPES + ); + } + + repr = Some(r); + } + } + + // Parse enum variants + let variants = data.variants.iter().map(EnumeratedVariant::new).collect(); + + Self { + ident: input.ident.clone(), + repr: repr.unwrap_or_else(|| { + abort!( + &input.ident, + "no `#[repr]` attribute on enum: must be one of {:?}", + REPR_TYPES + ) + }), + variants, + integer, + } + } + + /// Lower the derived output into a [`TokenStream`]. + pub fn to_tokens(&self) -> TokenStream { + let default_lifetime = default_lifetime(); + let ident = &self.ident; + let repr = &self.repr; + let tag = match self.integer { + false => quote! { ::der::Tag::Enumerated }, + true => quote! { ::der::Tag::Integer }, + }; + + let mut try_from_body = Vec::new(); + for variant in &self.variants { + try_from_body.push(variant.to_try_from_tokens()); + } + + quote! { + impl<#default_lifetime> ::der::DecodeValue<#default_lifetime> for #ident { + fn decode_value<R: ::der::Reader<#default_lifetime>>( + reader: &mut R, + header: ::der::Header + ) -> ::der::Result<Self> { + <#repr as ::der::DecodeValue>::decode_value(reader, header)?.try_into() + } + } + + impl ::der::EncodeValue for #ident { + fn value_len(&self) -> ::der::Result<::der::Length> { + ::der::EncodeValue::value_len(&(*self as #repr)) + } + + fn encode_value(&self, encoder: &mut dyn ::der::Writer) -> ::der::Result<()> { + ::der::EncodeValue::encode_value(&(*self as #repr), encoder) + } + } + + impl ::der::FixedTag for #ident { + const TAG: ::der::Tag = #tag; + } + + impl TryFrom<#repr> for #ident { + type Error = ::der::Error; + + fn try_from(n: #repr) -> ::der::Result<Self> { + match n { + #(#try_from_body)* + _ => Err(#tag.value_error()) + } + } + } + } + } +} + +/// "IR" for a variant of a derived `Enumerated`. +pub struct EnumeratedVariant { + /// Variant name. + ident: Ident, + + /// Integer value that this variant corresponds to. + discriminant: LitInt, +} + +impl EnumeratedVariant { + /// Create a new [`ChoiceVariant`] from the input [`Variant`]. + fn new(input: &Variant) -> Self { + for attr in &input.attrs { + if attr.path.is_ident(ATTR_NAME) { + abort!( + attr, + "`asn1` attribute is not allowed on fields of `Enumerated` types" + ); + } + } + + match &input.discriminant { + Some(( + _, + Expr::Lit(ExprLit { + lit: Lit::Int(discriminant), + .. + }), + )) => Self { + ident: input.ident.clone(), + discriminant: discriminant.clone(), + }, + Some((_, other)) => abort!(other, "invalid discriminant for `Enumerated`"), + None => abort!(input, "`Enumerated` variant has no discriminant"), + } + } + + /// Write the body for the derived [`TryFrom`] impl. + pub fn to_try_from_tokens(&self) -> TokenStream { + let ident = &self.ident; + let discriminant = &self.discriminant; + quote! { + #discriminant => Ok(Self::#ident), + } + } +} + +#[cfg(test)] +mod tests { + use super::DeriveEnumerated; + use syn::parse_quote; + + /// X.509 `CRLReason`. + #[test] + fn crlreason_example() { + let input = parse_quote! { + #[repr(u32)] + pub enum CrlReason { + Unspecified = 0, + KeyCompromise = 1, + CaCompromise = 2, + AffiliationChanged = 3, + Superseded = 4, + CessationOfOperation = 5, + CertificateHold = 6, + RemoveFromCrl = 8, + PrivilegeWithdrawn = 9, + AaCompromised = 10, + } + }; + + let ir = DeriveEnumerated::new(input); + assert_eq!(ir.ident, "CrlReason"); + assert_eq!(ir.repr, "u32"); + assert_eq!(ir.variants.len(), 10); + + let unspecified = &ir.variants[0]; + assert_eq!(unspecified.ident, "Unspecified"); + assert_eq!(unspecified.discriminant.to_string(), "0"); + + let key_compromise = &ir.variants[1]; + assert_eq!(key_compromise.ident, "KeyCompromise"); + assert_eq!(key_compromise.discriminant.to_string(), "1"); + + let key_compromise = &ir.variants[2]; + assert_eq!(key_compromise.ident, "CaCompromise"); + assert_eq!(key_compromise.discriminant.to_string(), "2"); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..b7aef1d --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,284 @@ +#![doc = include_str!("../README.md")] + +//! ## About +//! Custom derive support for the [`der`] crate. +//! +//! This crate contains custom derive macros intended to be used in the +//! following way: +//! +//! - [`Choice`][`derive@Choice`]: map ASN.1 `CHOICE` to a Rust enum. +//! - [`Enumerated`][`derive@Enumerated`]: map ASN.1 `ENUMERATED` to a C-like Rust enum. +//! - [`Sequence`][`derive@Sequence`]: map ASN.1 `SEQUENCE` to a Rust struct. +//! - [`ValueOrd`][`derive@ValueOrd`]: determine DER ordering for ASN.1 `SET OF`. +//! +//! Note that this crate shouldn't be used directly, but instead accessed +//! by using the `derive` feature of the `der` crate, which re-exports the +//! above macros from the toplevel. +//! +//! ## Why not `serde`? +//! The `der` crate is designed to be easily usable in embedded environments, +//! including ones where code size comes at a premium. +//! +//! This crate (i.e. `der_derive`) is able to generate code which is +//! significantly smaller than `serde_derive`. This is because the `der` +//! crate has been designed with high-level abstractions which reduce +//! code size, including trait object-based encoders which allow encoding +//! logic which is duplicated in `serde` serializers to be implemented in +//! a single place in the `der` crate. +//! +//! This is a deliberate tradeoff in terms of performance, flexibility, and +//! code size. At least for now, the `der` crate is optimizing for leveraging +//! as many abstractions as it can to minimize code size. +//! +//! ## Toplevel attributes +//! +//! The following attributes can be added to an `enum` or `struct` when +//! deriving either [`Choice`] or [`Sequence`] respectively: +//! +//! ### `#[asn1(tag_mode = "...")]` attribute: `EXPLICIT` vs `IMPLICIT` +//! +//! This attribute can be used to declare the tagging mode used by a particular +//! ASN.1 module. +//! +//! It's used when parsing `CONTEXT-SENSITIVE` fields. +//! +//! The default is `EXPLICIT`, so the attribute only needs to be added when +//! a particular module is declared `IMPLICIT`. +//! +//! ## Field-level attributes +//! +//! The following attributes can be added to either the fields of a particular +//! `struct` or the variants of a particular `enum`: +//! +//! ### `#[asn1(context_specific = "...")]` attribute: `CONTEXT-SPECIFIC` support +//! +//! This attribute can be added to associate a particular `CONTEXT-SPECIFIC` +//! tag number with a given enum variant or struct field. +//! +//! The value must be quoted and contain a number, e.g. `#[asn1(context_specific = "42"]`. +//! +//! ### `#[asn1(default = "...")]` attribute: `DEFAULT` support +//! +//! This behaves like `serde_derive`'s `default` attribute, allowing you to +//! specify the path to a function which returns a default value. +//! +//! ### `#[asn1(extensible = "true")]` attribute: support for `...` extensibility operator +//! +//! This attribute can be applied to the fields of `struct` types, and will +//! skip over unrecognized lower-numbered `CONTEXT-SPECIFIC` fields when +//! looking for a particular field of a struct. +//! +//! ### `#[asn1(optional = "true")]` attribute: support for `OPTIONAL` fields +//! +//! This attribute explicitly annotates a field as `OPTIONAL`. +//! +//! ### `#[asn1(type = "...")]` attribute: ASN.1 type declaration +//! +//! This attribute can be used to specify the ASN.1 type for a particular +//! `enum` variant or `struct` field. +//! +//! It's presently mandatory for all `enum` variants, even when using one of +//! the ASN.1 types defined by this crate. +//! +//! For structs, placing this attribute on a field makes it possible to +//! decode/encode types which don't directly implement the `Decode`/`Encode` +//! traits but do impl `From` and `TryInto` and `From` for one of the ASN.1 types +//! listed below (use the ASN.1 type keywords as the `type`): +//! +//! - `BIT STRING`: performs an intermediate conversion to [`der::asn1::BitString`] +//! - `IA5String`: performs an intermediate conversion to [`der::asn1::IA5String`] +//! - `GeneralizedTime`: performs an intermediate conversion to [`der::asn1::GeneralizedTime`] +//! - `OCTET STRING`: performs an intermediate conversion to [`der::asn1::OctetString`] +//! - `PrintableString`: performs an intermediate conversion to [`der::asn1::PrintableString`] +//! - `UTCTime`: performs an intermediate conversion to [`der::asn1::UtcTime`] +//! - `UTF8String`: performs an intermediate conversion to [`der::asn1::Utf8String`] +//! +//! ### `#[asn1(constructed = "...")]` attribute: support for constructed inner types +//! +//! This attribute can be used to specify that an "inner" type is constructed. It is most +//! commonly used when a `CHOICE` has a constructed inner type. +//! +//! Note: please open a GitHub Issue if you would like to request support +//! for additional ASN.1 types. +//! +//! [`der`]: https://docs.rs/der/ +//! [`Choice`]: derive@Choice +//! [`Sequence`]: derive@Sequence +//! [`der::asn1::BitString`]: https://docs.rs/der/latest/der/asn1/struct.BitString.html +//! [`der::asn1::Ia5String`]: https://docs.rs/der/latest/der/asn1/struct.Ia5String.html +//! [`der::asn1::GeneralizedTime`]: https://docs.rs/der/latest/der/asn1/struct.GeneralizedTime.html +//! [`der::asn1::OctetString`]: https://docs.rs/der/latest/der/asn1/struct.OctetString.html +//! [`der::asn1::PrintableString`]: https://docs.rs/der/latest/der/asn1/struct.PrintableString.html +//! [`der::asn1::UtcTime`]: https://docs.rs/der/latest/der/asn1/struct.UtcTime.html +//! [`der::asn1::Utf8String`]: https://docs.rs/der/latest/der/asn1/struct.Utf8String.html + +#![crate_type = "proc-macro"] +#![forbid(unsafe_code)] +#![warn( + clippy::unwrap_used, + rust_2018_idioms, + trivial_casts, + unused_qualifications +)] + +mod asn1_type; +mod attributes; +mod choice; +mod enumerated; +mod sequence; +mod tag; +mod value_ord; + +use crate::{ + asn1_type::Asn1Type, + attributes::{FieldAttrs, TypeAttrs, ATTR_NAME}, + choice::DeriveChoice, + enumerated::DeriveEnumerated, + sequence::DeriveSequence, + tag::{Tag, TagMode, TagNumber}, + value_ord::DeriveValueOrd, +}; +use proc_macro::TokenStream; +use proc_macro2::Span; +use proc_macro_error::proc_macro_error; +use quote::quote; +use syn::{parse_macro_input, DeriveInput, Lifetime}; + +/// Get the default lifetime. +fn default_lifetime() -> proc_macro2::TokenStream { + let lifetime = Lifetime::new("'__der_lifetime", Span::call_site()); + quote!(#lifetime) +} + +/// Derive the [`Choice`][1] trait on an `enum`. +/// +/// This custom derive macro can be used to automatically impl the +/// [`Decode`][2] and [`Encode`][3] traits along with the +/// [`Choice`][1] supertrait for any enum representing an ASN.1 `CHOICE`. +/// +/// The enum must consist entirely of 1-tuple variants wrapping inner +/// types which must also impl the [`Decode`][2] and [`Encode`][3] +/// traits. It will will also generate [`From`] impls for each of the +/// inner types of the variants into the enum that wraps them. +/// +/// # Usage +/// +/// ```ignore +/// // NOTE: requires the `derive` feature of `der` +/// use der::Choice; +/// +/// /// `Time` as defined in RFC 5280 +/// #[derive(Choice)] +/// pub enum Time { +/// #[asn1(type = "UTCTime")] +/// UtcTime(UtcTime), +/// +/// #[asn1(type = "GeneralizedTime")] +/// GeneralTime(GeneralizedTime), +/// } +/// ``` +/// +/// # `#[asn1(type = "...")]` attribute +/// +/// See [toplevel documentation for the `der_derive` crate][4] for more +/// information about the `#[asn1]` attribute. +/// +/// [1]: https://docs.rs/der/latest/der/trait.Choice.html +/// [2]: https://docs.rs/der/latest/der/trait.Decode.html +/// [3]: https://docs.rs/der/latest/der/trait.Encode.html +/// [4]: https://docs.rs/der_derive/ +#[proc_macro_derive(Choice, attributes(asn1))] +#[proc_macro_error] +pub fn derive_choice(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + DeriveChoice::new(input).to_tokens().into() +} + +/// Derive decoders and encoders for ASN.1 [`Enumerated`] types on a +/// C-like `enum` type. +/// +/// # Usage +/// +/// The `Enumerated` proc macro requires a C-like enum which impls `Copy` +/// and has a `#[repr]` of `u8`, `u16`, or `u32`: +/// +/// ```ignore +/// use der::Enumerated; +/// +/// #[derive(Enumerated, Copy, Clone, Debug, Eq, PartialEq)] +/// #[repr(u32)] +/// pub enum CrlReason { +/// Unspecified = 0, +/// KeyCompromise = 1, +/// CaCompromise = 2, +/// AffiliationChanged = 3, +/// Superseded = 4, +/// CessationOfOperation = 5, +/// CertificateHold = 6, +/// RemoveFromCrl = 8, +/// PrivilegeWithdrawn = 9, +/// AaCompromised = 10 +/// } +/// ``` +/// +/// Note that the derive macro will write a `TryFrom<...>` impl for the +/// provided `#[repr]`, which is used by the decoder. +#[proc_macro_derive(Enumerated, attributes(asn1))] +#[proc_macro_error] +pub fn derive_enumerated(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + DeriveEnumerated::new(input).to_tokens().into() +} + +/// Derive the [`Sequence`][1] trait on a `struct`. +/// +/// This custom derive macro can be used to automatically impl the +/// `Sequence` trait for any struct which can be decoded/encoded as an +/// ASN.1 `SEQUENCE`. +/// +/// # Usage +/// +/// ```ignore +/// use der::{ +/// asn1::{Any, ObjectIdentifier}, +/// Sequence +/// }; +/// +/// /// X.509 `AlgorithmIdentifier` +/// #[derive(Sequence)] +/// pub struct AlgorithmIdentifier<'a> { +/// /// This field contains an ASN.1 `OBJECT IDENTIFIER`, a.k.a. OID. +/// pub algorithm: ObjectIdentifier, +/// +/// /// This field is `OPTIONAL` and contains the ASN.1 `ANY` type, which +/// /// in this example allows arbitrary algorithm-defined parameters. +/// pub parameters: Option<Any<'a>> +/// } +/// ``` +/// +/// # `#[asn1(type = "...")]` attribute +/// +/// See [toplevel documentation for the `der_derive` crate][2] for more +/// information about the `#[asn1]` attribute. +/// +/// [1]: https://docs.rs/der/latest/der/trait.Sequence.html +/// [2]: https://docs.rs/der_derive/ +#[proc_macro_derive(Sequence, attributes(asn1))] +#[proc_macro_error] +pub fn derive_sequence(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + DeriveSequence::new(input).to_tokens().into() +} + +/// Derive the [`ValueOrd`][1] trait on a `struct`. +/// +/// This trait is used in conjunction with ASN.1 `SET OF` types to determine +/// the lexicographical order of their DER encodings. +/// +/// [1]: https://docs.rs/der/latest/der/trait.ValueOrd.html +#[proc_macro_derive(ValueOrd, attributes(asn1))] +#[proc_macro_error] +pub fn derive_value_ord(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + DeriveValueOrd::new(input).to_tokens().into() +} diff --git a/src/sequence.rs b/src/sequence.rs new file mode 100644 index 0000000..3e47246 --- /dev/null +++ b/src/sequence.rs @@ -0,0 +1,331 @@ +//! Support for deriving the `Sequence` trait on structs for the purposes of +//! decoding/encoding ASN.1 `SEQUENCE` types as mapped to struct fields. + +mod field; + +use crate::{default_lifetime, TypeAttrs}; +use field::SequenceField; +use proc_macro2::TokenStream; +use proc_macro_error::abort; +use quote::quote; +use syn::{DeriveInput, Ident, Lifetime}; + +/// Derive the `Sequence` trait for a struct +pub(crate) struct DeriveSequence { + /// Name of the sequence struct. + ident: Ident, + + /// Lifetime of the struct. + lifetime: Option<Lifetime>, + + /// Fields of the struct. + fields: Vec<SequenceField>, +} + +impl DeriveSequence { + /// Parse [`DeriveInput`]. + pub fn new(input: DeriveInput) -> Self { + let data = match input.data { + syn::Data::Struct(data) => data, + _ => abort!( + input.ident, + "can't derive `Sequence` on this type: only `struct` types are allowed", + ), + }; + + // TODO(tarcieri): properly handle multiple lifetimes + let lifetime = input + .generics + .lifetimes() + .next() + .map(|lt| lt.lifetime.clone()); + + let type_attrs = TypeAttrs::parse(&input.attrs); + + let fields = data + .fields + .iter() + .map(|field| SequenceField::new(field, &type_attrs)) + .collect(); + + Self { + ident: input.ident, + lifetime, + fields, + } + } + + /// Lower the derived output into a [`TokenStream`]. + pub fn to_tokens(&self) -> TokenStream { + let ident = &self.ident; + + let lifetime = match self.lifetime { + Some(ref lifetime) => quote!(#lifetime), + None => default_lifetime(), + }; + + // Lifetime parameters + // TODO(tarcieri): support multiple lifetimes + let lt_params = self + .lifetime + .as_ref() + .map(|_| lifetime.clone()) + .unwrap_or_default(); + + let mut decode_body = Vec::new(); + let mut decode_result = Vec::new(); + let mut encode_body = Vec::new(); + + for field in &self.fields { + decode_body.push(field.to_decode_tokens()); + decode_result.push(&field.ident); + encode_body.push(field.to_encode_tokens()); + } + + quote! { + impl<#lifetime> ::der::DecodeValue<#lifetime> for #ident<#lt_params> { + fn decode_value<R: ::der::Reader<#lifetime>>( + reader: &mut R, + header: ::der::Header, + ) -> ::der::Result<Self> { + use ::der::{Decode as _, DecodeValue as _, Reader as _}; + + reader.read_nested(header.length, |reader| { + #(#decode_body)* + + Ok(Self { + #(#decode_result),* + }) + }) + } + } + + impl<#lifetime> ::der::Sequence<#lifetime> for #ident<#lt_params> { + fn fields<F, T>(&self, f: F) -> ::der::Result<T> + where + F: FnOnce(&[&dyn der::Encode]) -> ::der::Result<T>, + { + f(&[ + #(#encode_body),* + ]) + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::DeriveSequence; + use crate::{Asn1Type, TagMode}; + use syn::parse_quote; + + /// X.509 SPKI `AlgorithmIdentifier`. + #[test] + fn algorithm_identifier_example() { + let input = parse_quote! { + #[derive(Sequence)] + pub struct AlgorithmIdentifier<'a> { + pub algorithm: ObjectIdentifier, + pub parameters: Option<Any<'a>>, + } + }; + + let ir = DeriveSequence::new(input); + assert_eq!(ir.ident, "AlgorithmIdentifier"); + assert_eq!(ir.lifetime.unwrap().to_string(), "'a"); + assert_eq!(ir.fields.len(), 2); + + let algorithm_field = &ir.fields[0]; + assert_eq!(algorithm_field.ident, "algorithm"); + assert_eq!(algorithm_field.attrs.asn1_type, None); + assert_eq!(algorithm_field.attrs.context_specific, None); + assert_eq!(algorithm_field.attrs.tag_mode, TagMode::Explicit); + + let parameters_field = &ir.fields[1]; + assert_eq!(parameters_field.ident, "parameters"); + assert_eq!(parameters_field.attrs.asn1_type, None); + assert_eq!(parameters_field.attrs.context_specific, None); + assert_eq!(parameters_field.attrs.tag_mode, TagMode::Explicit); + } + + /// X.509 `SubjectPublicKeyInfo`. + #[test] + fn spki_example() { + let input = parse_quote! { + #[derive(Sequence)] + pub struct SubjectPublicKeyInfo<'a> { + pub algorithm: AlgorithmIdentifier<'a>, + + #[asn1(type = "BIT STRING")] + pub subject_public_key: &'a [u8], + } + }; + + let ir = DeriveSequence::new(input); + assert_eq!(ir.ident, "SubjectPublicKeyInfo"); + assert_eq!(ir.lifetime.unwrap().to_string(), "'a"); + assert_eq!(ir.fields.len(), 2); + + let algorithm_field = &ir.fields[0]; + assert_eq!(algorithm_field.ident, "algorithm"); + assert_eq!(algorithm_field.attrs.asn1_type, None); + assert_eq!(algorithm_field.attrs.context_specific, None); + assert_eq!(algorithm_field.attrs.tag_mode, TagMode::Explicit); + + let subject_public_key_field = &ir.fields[1]; + assert_eq!(subject_public_key_field.ident, "subject_public_key"); + assert_eq!( + subject_public_key_field.attrs.asn1_type, + Some(Asn1Type::BitString) + ); + assert_eq!(subject_public_key_field.attrs.context_specific, None); + assert_eq!(subject_public_key_field.attrs.tag_mode, TagMode::Explicit); + } + + /// PKCS#8v2 `OneAsymmetricKey`. + /// + /// ```text + /// OneAsymmetricKey ::= SEQUENCE { + /// version Version, + /// privateKeyAlgorithm PrivateKeyAlgorithmIdentifier, + /// privateKey PrivateKey, + /// attributes [0] Attributes OPTIONAL, + /// ..., + /// [[2: publicKey [1] PublicKey OPTIONAL ]], + /// ... + /// } + /// + /// Version ::= INTEGER { v1(0), v2(1) } (v1, ..., v2) + /// + /// PrivateKeyAlgorithmIdentifier ::= AlgorithmIdentifier + /// + /// PrivateKey ::= OCTET STRING + /// + /// Attributes ::= SET OF Attribute + /// + /// PublicKey ::= BIT STRING + /// ``` + #[test] + fn pkcs8_example() { + let input = parse_quote! { + #[derive(Sequence)] + pub struct OneAsymmetricKey<'a> { + pub version: u8, + pub private_key_algorithm: AlgorithmIdentifier<'a>, + #[asn1(type = "OCTET STRING")] + pub private_key: &'a [u8], + #[asn1(context_specific = "0", extensible = "true", optional = "true")] + pub attributes: Option<SetOf<Any<'a>, 1>>, + #[asn1( + context_specific = "1", + extensible = "true", + optional = "true", + type = "BIT STRING" + )] + pub public_key: Option<&'a [u8]>, + } + }; + + let ir = DeriveSequence::new(input); + assert_eq!(ir.ident, "OneAsymmetricKey"); + assert_eq!(ir.lifetime.unwrap().to_string(), "'a"); + assert_eq!(ir.fields.len(), 5); + + let version_field = &ir.fields[0]; + assert_eq!(version_field.ident, "version"); + assert_eq!(version_field.attrs.asn1_type, None); + assert_eq!(version_field.attrs.context_specific, None); + assert_eq!(version_field.attrs.extensible, false); + assert_eq!(version_field.attrs.optional, false); + assert_eq!(version_field.attrs.tag_mode, TagMode::Explicit); + + let algorithm_field = &ir.fields[1]; + assert_eq!(algorithm_field.ident, "private_key_algorithm"); + assert_eq!(algorithm_field.attrs.asn1_type, None); + assert_eq!(algorithm_field.attrs.context_specific, None); + assert_eq!(algorithm_field.attrs.extensible, false); + assert_eq!(algorithm_field.attrs.optional, false); + assert_eq!(algorithm_field.attrs.tag_mode, TagMode::Explicit); + + let private_key_field = &ir.fields[2]; + assert_eq!(private_key_field.ident, "private_key"); + assert_eq!( + private_key_field.attrs.asn1_type, + Some(Asn1Type::OctetString) + ); + assert_eq!(private_key_field.attrs.context_specific, None); + assert_eq!(private_key_field.attrs.extensible, false); + assert_eq!(private_key_field.attrs.optional, false); + assert_eq!(private_key_field.attrs.tag_mode, TagMode::Explicit); + + let attributes_field = &ir.fields[3]; + assert_eq!(attributes_field.ident, "attributes"); + assert_eq!(attributes_field.attrs.asn1_type, None); + assert_eq!( + attributes_field.attrs.context_specific, + Some("0".parse().unwrap()) + ); + assert_eq!(attributes_field.attrs.extensible, true); + assert_eq!(attributes_field.attrs.optional, true); + assert_eq!(attributes_field.attrs.tag_mode, TagMode::Explicit); + + let public_key_field = &ir.fields[4]; + assert_eq!(public_key_field.ident, "public_key"); + assert_eq!(public_key_field.attrs.asn1_type, Some(Asn1Type::BitString)); + assert_eq!( + public_key_field.attrs.context_specific, + Some("1".parse().unwrap()) + ); + assert_eq!(public_key_field.attrs.extensible, true); + assert_eq!(public_key_field.attrs.optional, true); + assert_eq!(public_key_field.attrs.tag_mode, TagMode::Explicit); + } + + /// `IMPLICIT` tagged example + #[test] + fn implicit_example() { + let input = parse_quote! { + #[asn1(tag_mode = "IMPLICIT")] + pub struct ImplicitSequence<'a> { + #[asn1(context_specific = "0", type = "BIT STRING")] + bit_string: BitString<'a>, + + #[asn1(context_specific = "1", type = "GeneralizedTime")] + time: GeneralizedTime, + + #[asn1(context_specific = "2", type = "UTF8String")] + utf8_string: String, + } + }; + + let ir = DeriveSequence::new(input); + assert_eq!(ir.ident, "ImplicitSequence"); + assert_eq!(ir.lifetime.unwrap().to_string(), "'a"); + assert_eq!(ir.fields.len(), 3); + + let bit_string = &ir.fields[0]; + assert_eq!(bit_string.ident, "bit_string"); + assert_eq!(bit_string.attrs.asn1_type, Some(Asn1Type::BitString)); + assert_eq!( + bit_string.attrs.context_specific, + Some("0".parse().unwrap()) + ); + assert_eq!(bit_string.attrs.tag_mode, TagMode::Implicit); + + let time = &ir.fields[1]; + assert_eq!(time.ident, "time"); + assert_eq!(time.attrs.asn1_type, Some(Asn1Type::GeneralizedTime)); + assert_eq!(time.attrs.context_specific, Some("1".parse().unwrap())); + assert_eq!(time.attrs.tag_mode, TagMode::Implicit); + + let utf8_string = &ir.fields[2]; + assert_eq!(utf8_string.ident, "utf8_string"); + assert_eq!(utf8_string.attrs.asn1_type, Some(Asn1Type::Utf8String)); + assert_eq!( + utf8_string.attrs.context_specific, + Some("2".parse().unwrap()) + ); + assert_eq!(utf8_string.attrs.tag_mode, TagMode::Implicit); + } +} diff --git a/src/sequence/field.rs b/src/sequence/field.rs new file mode 100644 index 0000000..4f478f0 --- /dev/null +++ b/src/sequence/field.rs @@ -0,0 +1,358 @@ +//! Sequence field IR and lowerings + +use crate::{Asn1Type, FieldAttrs, TagMode, TagNumber, TypeAttrs}; +use proc_macro2::TokenStream; +use proc_macro_error::abort; +use quote::quote; +use syn::{Field, Ident, Path, Type}; + +/// "IR" for a field of a derived `Sequence`. +pub(super) struct SequenceField { + /// Variant name. + pub(super) ident: Ident, + + /// Field-level attributes. + pub(super) attrs: FieldAttrs, + + /// Field type + pub(super) field_type: Type, +} + +impl SequenceField { + /// Create a new [`SequenceField`] from the input [`Field`]. + pub(super) fn new(field: &Field, type_attrs: &TypeAttrs) -> Self { + let ident = field.ident.as_ref().cloned().unwrap_or_else(|| { + abort!( + field, + "no name on struct field i.e. tuple structs unsupported" + ) + }); + + let attrs = FieldAttrs::parse(&field.attrs, type_attrs); + + if attrs.asn1_type.is_some() && attrs.default.is_some() { + abort!( + ident, + "ASN.1 `type` and `default` options cannot be combined" + ); + } + + if attrs.default.is_some() && attrs.optional { + abort!( + ident, + "`optional` and `default` field qualifiers are mutually exclusive" + ); + } + + Self { + ident, + attrs, + field_type: field.ty.clone(), + } + } + + /// Derive code for decoding a field of a sequence. + pub(super) fn to_decode_tokens(&self) -> TokenStream { + let mut lowerer = LowerFieldDecoder::new(&self.attrs); + + if self.attrs.asn1_type.is_some() { + lowerer.apply_asn1_type(self.attrs.optional); + } + + if let Some(default) = &self.attrs.default { + // TODO(tarcieri): default in conjunction with ASN.1 types? + debug_assert!( + self.attrs.asn1_type.is_none(), + "`type` and `default` are mutually exclusive" + ); + + // TODO(tarcieri): support for context-specific fields with defaults? + if self.attrs.context_specific.is_none() { + lowerer.apply_default(default, &self.field_type); + } + } + + lowerer.into_tokens(&self.ident) + } + + /// Derive code for encoding a field of a sequence. + pub(super) fn to_encode_tokens(&self) -> TokenStream { + let mut lowerer = LowerFieldEncoder::new(&self.ident); + let attrs = &self.attrs; + + if let Some(ty) = &attrs.asn1_type { + // TODO(tarcieri): default in conjunction with ASN.1 types? + debug_assert!( + attrs.default.is_none(), + "`type` and `default` are mutually exclusive" + ); + lowerer.apply_asn1_type(ty, attrs.optional); + } + + if let Some(tag_number) = &attrs.context_specific { + lowerer.apply_context_specific(tag_number, &attrs.tag_mode, attrs.optional); + } + + if let Some(default) = &attrs.default { + debug_assert!( + !attrs.optional, + "`default`, and `optional` are mutually exclusive" + ); + lowerer.apply_default(&self.ident, default); + } + + lowerer.into_tokens() + } +} + +/// AST lowerer for field decoders. +struct LowerFieldDecoder { + /// Decoder-in-progress. + decoder: TokenStream, +} + +impl LowerFieldDecoder { + /// Create a new field decoder lowerer. + fn new(attrs: &FieldAttrs) -> Self { + Self { + decoder: attrs.decoder(), + } + } + + /// the field decoder to tokens. + fn into_tokens(self, ident: &Ident) -> TokenStream { + let decoder = self.decoder; + + quote! { + let #ident = #decoder; + } + } + + /// Apply the ASN.1 type (if defined). + fn apply_asn1_type(&mut self, optional: bool) { + let decoder = &self.decoder; + + self.decoder = if optional { + quote! { + #decoder.map(TryInto::try_into).transpose()? + } + } else { + quote! { + #decoder.try_into()? + } + } + } + + /// Handle default value for a type. + fn apply_default(&mut self, default: &Path, field_type: &Type) { + self.decoder = quote! { + Option::<#field_type>::decode(reader)?.unwrap_or_else(#default); + }; + } +} + +/// AST lowerer for field encoders. +struct LowerFieldEncoder { + /// Encoder-in-progress. + encoder: TokenStream, +} + +impl LowerFieldEncoder { + /// Create a new field encoder lowerer. + fn new(ident: &Ident) -> Self { + Self { + encoder: quote!(self.#ident), + } + } + + /// the field encoder to tokens. + fn into_tokens(self) -> TokenStream { + let encoder = self.encoder; + quote! { &#encoder } + } + + /// Apply the ASN.1 type (if defined). + fn apply_asn1_type(&mut self, asn1_type: &Asn1Type, optional: bool) { + let binding = &self.encoder; + + self.encoder = if optional { + let map_arg = quote!(field); + let encoder = asn1_type.encoder(&map_arg); + + quote! { + #binding.as_ref().map(|#map_arg| { + der::Result::Ok(#encoder) + }).transpose()? + } + } else { + let encoder = asn1_type.encoder(binding); + quote!(#encoder) + }; + } + + /// Handle default value for a type. + fn apply_default(&mut self, ident: &Ident, default: &Path) { + let encoder = &self.encoder; + + self.encoder = quote! { + if &self.#ident == &#default() { + None + } else { + Some(#encoder) + } + }; + } + + /// Make this field context-specific. + fn apply_context_specific( + &mut self, + tag_number: &TagNumber, + tag_mode: &TagMode, + optional: bool, + ) { + let encoder = &self.encoder; + let number_tokens = tag_number.to_tokens(); + let mode_tokens = tag_mode.to_tokens(); + + if optional { + self.encoder = quote! { + #encoder.as_ref().map(|field| { + ::der::asn1::ContextSpecificRef { + tag_number: #number_tokens, + tag_mode: #mode_tokens, + value: field, + } + }) + }; + } else { + self.encoder = quote! { + ::der::asn1::ContextSpecificRef { + tag_number: #number_tokens, + tag_mode: #mode_tokens, + value: &#encoder, + } + }; + } + } +} + +#[cfg(test)] +mod tests { + use super::SequenceField; + use crate::{FieldAttrs, TagMode, TagNumber}; + use proc_macro2::Span; + use quote::quote; + use syn::{punctuated::Punctuated, Ident, Path, PathSegment, Type, TypePath}; + + /// Create a [`Type::Path`]. + pub fn type_path(ident: Ident) -> Type { + let mut segments = Punctuated::new(); + segments.push_value(PathSegment { + ident, + arguments: Default::default(), + }); + + Type::Path(TypePath { + qself: None, + path: Path { + leading_colon: None, + segments, + }, + }) + } + + #[test] + fn simple() { + let span = Span::call_site(); + let ident = Ident::new("example_field", span); + + let attrs = FieldAttrs { + asn1_type: None, + context_specific: None, + default: None, + extensible: false, + optional: false, + tag_mode: TagMode::Explicit, + constructed: false, + }; + + let field_type = Ident::new("String", span); + + let field = SequenceField { + ident, + attrs, + field_type: type_path(field_type), + }; + + assert_eq!( + field.to_decode_tokens().to_string(), + quote! { + let example_field = reader.decode()?; + } + .to_string() + ); + + assert_eq!( + field.to_encode_tokens().to_string(), + quote! { + &self.example_field + } + .to_string() + ); + } + + #[test] + fn implicit() { + let span = Span::call_site(); + let ident = Ident::new("implicit_field", span); + + let attrs = FieldAttrs { + asn1_type: None, + context_specific: Some(TagNumber(0)), + default: None, + extensible: false, + optional: false, + tag_mode: TagMode::Implicit, + constructed: false, + }; + + let field_type = Ident::new("String", span); + + let field = SequenceField { + ident, + attrs, + field_type: type_path(field_type), + }; + + assert_eq!( + field.to_decode_tokens().to_string(), + quote! { + let implicit_field = ::der::asn1::ContextSpecific::<>::decode_implicit( + reader, + ::der::TagNumber::N0 + )? + .ok_or_else(|| { + der::Tag::ContextSpecific { + number: ::der::TagNumber::N0, + constructed: false + } + .value_error() + })? + .value; + } + .to_string() + ); + + assert_eq!( + field.to_encode_tokens().to_string(), + quote! { + &::der::asn1::ContextSpecificRef { + tag_number: ::der::TagNumber::N0, + tag_mode: ::der::TagMode::Implicit, + value: &self.implicit_field, + } + } + .to_string() + ); + } +} diff --git a/src/tag.rs b/src/tag.rs new file mode 100644 index 0000000..f2e39ec --- /dev/null +++ b/src/tag.rs @@ -0,0 +1,176 @@ +//! Tag-related functionality. + +use crate::Asn1Type; +use proc_macro2::TokenStream; +use quote::quote; +use std::{ + fmt::{self, Display}, + str::FromStr, +}; + +/// Tag "IR" type. +#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] +pub(crate) enum Tag { + /// Universal tags with an associated [`Asn1Type`]. + Universal(Asn1Type), + + /// Context-specific tags with an associated [`TagNumber`]. + ContextSpecific { + /// Is the inner ASN.1 type constructed? + constructed: bool, + + /// Context-specific tag number + number: TagNumber, + }, +} + +impl Tag { + /// Lower this [`Tag`] to a [`TokenStream`]. + pub fn to_tokens(self) -> TokenStream { + match self { + Tag::Universal(ty) => ty.tag(), + Tag::ContextSpecific { + constructed, + number, + } => { + let constructed = if constructed { + quote!(true) + } else { + quote!(false) + }; + + let number = number.to_tokens(); + + quote! { + ::der::Tag::ContextSpecific { + constructed: #constructed, + number: #number, + } + } + } + } + } +} + +/// Tagging modes: `EXPLICIT` versus `IMPLICIT`. +#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] +pub(crate) enum TagMode { + /// `EXPLICIT` tagging. + /// + /// Tag is added in addition to the inner tag of the type. + Explicit, + + /// `IMPLICIT` tagging. + /// + /// Tag replaces the existing tag of the inner type. + Implicit, +} + +impl TagMode { + /// Lower this [`TagMode`] to a [`TokenStream`] with the `der` + /// crate's corresponding enum variant for this tag mode. + pub fn to_tokens(self) -> TokenStream { + match self { + TagMode::Explicit => quote!(::der::TagMode::Explicit), + TagMode::Implicit => quote!(::der::TagMode::Implicit), + } + } +} + +impl FromStr for TagMode { + type Err = ParseError; + + fn from_str(s: &str) -> Result<Self, ParseError> { + match s { + "EXPLICIT" | "explicit" => Ok(TagMode::Explicit), + "IMPLICIT" | "implicit" => Ok(TagMode::Implicit), + _ => Err(ParseError), + } + } +} + +impl Default for TagMode { + fn default() -> TagMode { + TagMode::Explicit + } +} + +impl Display for TagMode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TagMode::Explicit => f.write_str("EXPLICIT"), + TagMode::Implicit => f.write_str("IMPLICIT"), + } + } +} + +/// ASN.1 tag numbers (i.e. lower 5 bits of a [`Tag`]). +#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] +pub(crate) struct TagNumber(pub u8); + +impl TagNumber { + /// Maximum tag number supported (inclusive). + pub const MAX: u8 = 30; + + /// Get tokens describing this tag. + pub fn to_tokens(self) -> TokenStream { + match self.0 { + 0 => quote!(::der::TagNumber::N0), + 1 => quote!(::der::TagNumber::N1), + 2 => quote!(::der::TagNumber::N2), + 3 => quote!(::der::TagNumber::N3), + 4 => quote!(::der::TagNumber::N4), + 5 => quote!(::der::TagNumber::N5), + 6 => quote!(::der::TagNumber::N6), + 7 => quote!(::der::TagNumber::N7), + 8 => quote!(::der::TagNumber::N8), + 9 => quote!(::der::TagNumber::N9), + 10 => quote!(::der::TagNumber::N10), + 11 => quote!(::der::TagNumber::N11), + 12 => quote!(::der::TagNumber::N12), + 13 => quote!(::der::TagNumber::N13), + 14 => quote!(::der::TagNumber::N14), + 15 => quote!(::der::TagNumber::N15), + 16 => quote!(::der::TagNumber::N16), + 17 => quote!(::der::TagNumber::N17), + 18 => quote!(::der::TagNumber::N18), + 19 => quote!(::der::TagNumber::N19), + 20 => quote!(::der::TagNumber::N20), + 21 => quote!(::der::TagNumber::N21), + 22 => quote!(::der::TagNumber::N22), + 23 => quote!(::der::TagNumber::N23), + 24 => quote!(::der::TagNumber::N24), + 25 => quote!(::der::TagNumber::N25), + 26 => quote!(::der::TagNumber::N26), + 27 => quote!(::der::TagNumber::N27), + 28 => quote!(::der::TagNumber::N28), + 29 => quote!(::der::TagNumber::N29), + 30 => quote!(::der::TagNumber::N30), + _ => unreachable!("tag number out of range: {}", self), + } + } +} + +impl FromStr for TagNumber { + type Err = ParseError; + + fn from_str(s: &str) -> Result<Self, ParseError> { + let n = s.parse::<u8>().map_err(|_| ParseError)?; + + if n <= Self::MAX { + Ok(Self(n)) + } else { + Err(ParseError) + } + } +} + +impl Display for TagNumber { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +/// Error type +#[derive(Debug)] +pub(crate) struct ParseError; diff --git a/src/value_ord.rs b/src/value_ord.rs new file mode 100644 index 0000000..6b08f7f --- /dev/null +++ b/src/value_ord.rs @@ -0,0 +1,144 @@ +//! Support for deriving the `ValueOrd` trait on enums and structs. +//! +//! This trait is used in conjunction with ASN.1 `SET OF` types to determine +//! the lexicographical order of their DER encodings. + +// TODO(tarcieri): enum support + +use crate::{FieldAttrs, TypeAttrs}; +use proc_macro2::TokenStream; +use proc_macro_error::abort; +use quote::quote; +use syn::{DeriveInput, Field, Ident, Lifetime, Variant}; + +/// Derive the `Enumerated` trait for an enum. +pub(crate) struct DeriveValueOrd { + /// Name of the enum. + ident: Ident, + + /// Lifetime of the struct. + lifetime: Option<Lifetime>, + + /// Fields of structs or enum variants. + fields: Vec<ValueField>, +} + +impl DeriveValueOrd { + /// Parse [`DeriveInput`]. + pub fn new(input: DeriveInput) -> Self { + let ident = input.ident; + let type_attrs = TypeAttrs::parse(&input.attrs); + + // TODO(tarcieri): properly handle multiple lifetimes + let lifetime = input + .generics + .lifetimes() + .next() + .map(|lt| lt.lifetime.clone()); + + let fields = match input.data { + syn::Data::Enum(data) => data + .variants + .into_iter() + .map(|variant| ValueField::new_enum(variant, &type_attrs)) + .collect(), + syn::Data::Struct(data) => data + .fields + .into_iter() + .map(|field| ValueField::new_struct(field, &type_attrs)) + .collect(), + _ => abort!( + ident, + "can't derive `ValueOrd` on this type: \ + only `enum` and `struct` types are allowed", + ), + }; + + Self { + ident, + lifetime, + fields, + } + } + + /// Lower the derived output into a [`TokenStream`]. + pub fn to_tokens(&self) -> TokenStream { + let ident = &self.ident; + + // Lifetime parameters + // TODO(tarcieri): support multiple lifetimes + let lt_params = self + .lifetime + .as_ref() + .map(|lt| vec![lt.clone()]) + .unwrap_or_default(); + + let mut body = Vec::new(); + + for field in &self.fields { + body.push(field.to_tokens()); + } + + quote! { + impl<#(#lt_params)*> ::der::ValueOrd for #ident<#(#lt_params)*> { + fn value_cmp(&self, other: &Self) -> ::der::Result<::core::cmp::Ordering> { + #[allow(unused_imports)] + use ::der::DerOrd; + + #(#body)* + + Ok(::core::cmp::Ordering::Equal) + } + } + } + } +} + +struct ValueField { + /// Name of the field + ident: Ident, + + /// Field-level attributes. + attrs: FieldAttrs, +} + +impl ValueField { + /// Create from an `enum` variant. + fn new_enum(variant: Variant, _: &TypeAttrs) -> Self { + abort!( + variant, + "deriving `ValueOrd` only presently supported for structs" + ); + } + + /// Create from a `struct` field. + fn new_struct(field: Field, type_attrs: &TypeAttrs) -> Self { + let ident = field + .ident + .as_ref() + .cloned() + .unwrap_or_else(|| abort!(&field, "tuple structs are not supported")); + + let attrs = FieldAttrs::parse(&field.attrs, type_attrs); + Self { ident, attrs } + } + + /// Lower to [`TokenStream`]. + fn to_tokens(&self) -> TokenStream { + let ident = &self.ident; + let mut binding1 = quote!(self.#ident); + let mut binding2 = quote!(other.#ident); + + if let Some(ty) = &self.attrs.asn1_type { + binding1 = ty.encoder(&binding1); + binding2 = ty.encoder(&binding2); + } + + quote! { + match #binding1.der_cmp(&#binding2)? { + ::core::cmp::Ordering::Equal => (), + other => return Ok(other), + } + } + } +} |