Skip to content

Commit

Permalink
Use 'perfect derives' (#496)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex authored Nov 17, 2024
1 parent 18ef6ea commit 663d616
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 9 deletions.
83 changes: 78 additions & 5 deletions asn1_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ pub fn derive_asn1_read(input: proc_macro::TokenStream) -> proc_macro::TokenStre
let lifetime_name = add_lifetime_if_none(&mut generics);
add_bounds(
&mut generics,
all_field_types(&input.data),
syn::parse_quote!(asn1::Asn1Readable<#lifetime_name>),
syn::parse_quote!(asn1::Asn1DefinedByReadable<#lifetime_name, asn1::ObjectIdentifier>),
);
let (impl_generics, _, where_clause) = generics.split_for_impl();

Expand Down Expand Up @@ -58,7 +60,12 @@ pub fn derive_asn1_write(input: proc_macro::TokenStream) -> proc_macro::TokenStr
let mut input = syn::parse_macro_input!(input as syn::DeriveInput);

let name = input.ident;
add_bounds(&mut input.generics, syn::parse_quote!(asn1::Asn1Writable));
add_bounds(
&mut input.generics,
all_field_types(&input.data),
syn::parse_quote!(asn1::Asn1Writable),
syn::parse_quote!(asn1::Asn1DefinedByWritable<asn1::ObjectIdentifier>),
);
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

let expanded = match input.data {
Expand Down Expand Up @@ -260,11 +267,77 @@ fn add_lifetime_if_none(generics: &mut syn::Generics) -> syn::Lifetime {
generics.lifetimes().next().unwrap().lifetime.clone()
}

fn add_bounds(generics: &mut syn::Generics, bound: syn::TypeParamBound) {
for param in &mut generics.params {
if let syn::GenericParam::Type(ref mut type_param) = param {
type_param.bounds.push(bound.clone());
fn all_field_types(data: &syn::Data) -> Vec<(syn::Type, bool)> {
let mut field_types = vec![];
match data {
syn::Data::Struct(v) => {
add_field_types(&mut field_types, &v.fields);
}
syn::Data::Enum(v) => {
for variant in &v.variants {
add_field_types(&mut field_types, &variant.fields);
}
}
syn::Data::Union(_) => panic!("Unions not supported"),
}
field_types
}

fn add_field_types(field_types: &mut Vec<(syn::Type, bool)>, fields: &syn::Fields) {
match fields {
syn::Fields::Named(v) => {
for f in &v.named {
add_field_type(field_types, f);
}
}
syn::Fields::Unnamed(v) => {
for f in &v.unnamed {
add_field_type(field_types, f);
}
}
syn::Fields::Unit => {}
}
}

fn add_field_type(field_types: &mut Vec<(syn::Type, bool)>, f: &syn::Field) {
let (op_type, _) = extract_field_properties(&f.attrs);
field_types.push((f.ty.clone(), matches!(op_type, OpType::DefinedBy(_))));
}

fn add_bounds(
generics: &mut syn::Generics,
field_types: Vec<(syn::Type, bool)>,
bound: syn::TypeParamBound,
defined_by_bound: syn::TypeParamBound,
) {
let where_clause = if field_types.is_empty() {
return;
} else {
generics
.where_clause
.get_or_insert_with(|| syn::WhereClause {
where_token: Default::default(),
predicates: syn::punctuated::Punctuated::new(),
})
};

for (f, is_defined_by) in field_types {
where_clause
.predicates
.push(syn::WherePredicate::Type(syn::PredicateType {
lifetimes: None,
bounded_ty: f,
colon_token: Default::default(),
bounds: {
let mut p = syn::punctuated::Punctuated::new();
if is_defined_by {
p.push(defined_by_bound.clone());
} else {
p.push(bound.clone());
}
p
},
}))
}
}

Expand Down
41 changes: 37 additions & 4 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1842,13 +1842,28 @@ impl<T> DefinedByMarker<T> {
}
}

impl<'a, T: Asn1Readable<'a>> Asn1Readable<'a> for DefinedByMarker<T> {
fn parse(_: &mut Parser<'a>) -> ParseResult<Self> {
panic!("parse() should never be called on a DefinedByMarker")
}
fn can_parse(_: Tag) -> bool {
panic!("can_parse() should never be called on a DefinedByMarker")
}
}

impl<T: Asn1Writable> Asn1Writable for DefinedByMarker<T> {
fn write(&self, _: &mut Writer<'_>) -> WriteResult {
panic!("write() should never be called on a DefinedByMarker")
}
}

#[cfg(test)]
mod tests {
use crate::{
parse_single, BigInt, BigUint, DateTime, DefinedByMarker, Enumerated, GeneralizedTime,
IA5String, ObjectIdentifier, OctetStringEncoded, OwnedBigInt, OwnedBigUint, ParseError,
ParseErrorKind, PrintableString, SequenceOf, SequenceOfWriter, SetOf, SetOfWriter, Tag,
Tlv, UtcTime, Utf8String, VisibleString, X509GeneralizedTime,
parse_single, Asn1Readable, Asn1Writable, BigInt, BigUint, DateTime, DefinedByMarker,
Enumerated, GeneralizedTime, IA5String, ObjectIdentifier, OctetStringEncoded, OwnedBigInt,
OwnedBigUint, ParseError, ParseErrorKind, PrintableString, SequenceOf, SequenceOfWriter,
SetOf, SetOfWriter, Tag, Tlv, UtcTime, Utf8String, VisibleString, X509GeneralizedTime,
};
use crate::{Explicit, Implicit};
#[cfg(not(feature = "std"))]
Expand Down Expand Up @@ -2193,4 +2208,22 @@ mod tests {
fn test_const() {
const _: DefinedByMarker<ObjectIdentifier> = DefinedByMarker::marker();
}

#[test]
#[should_panic]
fn test_defined_by_marker_parse() {
crate::parse(b"", DefinedByMarker::<ObjectIdentifier>::parse).unwrap();
}

#[test]
#[should_panic]
fn test_defined_by_marker_can_parse() {
DefinedByMarker::<ObjectIdentifier>::can_parse(Tag::primitive(2));
}

#[test]
#[should_panic]
fn test_defined_by_marker_write() {
crate::write(|w| DefinedByMarker::<ObjectIdentifier>::marker().write(w)).unwrap();
}
}
20 changes: 20 additions & 0 deletions tests/derive_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -720,3 +720,23 @@ fn test_generics() {
b"\x30\x08\x30\x06\x01\x01\xff\x01\x01\xff"
)
}

#[test]
fn test_perfect_derive() {
trait X {
type Type: PartialEq + std::fmt::Debug;
}

#[derive(PartialEq, Debug)]
struct Op;
impl X for Op {
type Type = u64;
}

#[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)]
struct S<T: X> {
value: T::Type,
}

assert_roundtrips::<S<Op>>(&[(Ok(S { value: 12 }), b"\x30\x03\x02\x01\x0c")]);
}

0 comments on commit 663d616

Please sign in to comment.