From c0e8984ccc60b6294da67db448e2d994a140095d Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Sat, 2 Nov 2024 22:33:55 -0400 Subject: [PATCH] Use 'perfect derives' See https://smallcultfollowing.com/babysteps/blog/2022/04/12/implied-bounds-and-perfect-derive/ and the test case for more details. --- asn1_derive/src/lib.rs | 83 +++++++++++++++++++++++++++++++++++++++--- src/types.rs | 41 +++++++++++++++++++-- tests/derive_test.rs | 20 ++++++++++ 3 files changed, 135 insertions(+), 9 deletions(-) diff --git a/asn1_derive/src/lib.rs b/asn1_derive/src/lib.rs index ad7cba3..be093a6 100644 --- a/asn1_derive/src/lib.rs +++ b/asn1_derive/src/lib.rs @@ -14,7 +14,9 @@ pub fn derive_asn1_read(input: proc_macro::TokenStream) -> proc_macro::TokenStre let lifetime_name = add_lifetime_if_none(&mut generics); add_bounds( &mut generics, + all_field_types(&input.data), syn::parse_quote!(asn1::Asn1Readable<#lifetime_name>), + syn::parse_quote!(asn1::Asn1DefinedByReadable<#lifetime_name, asn1::ObjectIdentifier>), ); let (impl_generics, _, where_clause) = generics.split_for_impl(); @@ -58,7 +60,12 @@ pub fn derive_asn1_write(input: proc_macro::TokenStream) -> proc_macro::TokenStr let mut input = syn::parse_macro_input!(input as syn::DeriveInput); let name = input.ident; - add_bounds(&mut input.generics, syn::parse_quote!(asn1::Asn1Writable)); + add_bounds( + &mut input.generics, + all_field_types(&input.data), + syn::parse_quote!(asn1::Asn1Writable), + syn::parse_quote!(asn1::Asn1DefinedByWritable), + ); let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let expanded = match input.data { @@ -260,11 +267,77 @@ fn add_lifetime_if_none(generics: &mut syn::Generics) -> syn::Lifetime { generics.lifetimes().next().unwrap().lifetime.clone() } -fn add_bounds(generics: &mut syn::Generics, bound: syn::TypeParamBound) { - for param in &mut generics.params { - if let syn::GenericParam::Type(ref mut type_param) = param { - type_param.bounds.push(bound.clone()); +fn all_field_types(data: &syn::Data) -> Vec<(syn::Type, bool)> { + let mut field_types = vec![]; + match data { + syn::Data::Struct(v) => { + add_field_types(&mut field_types, &v.fields); + } + syn::Data::Enum(v) => { + for variant in &v.variants { + add_field_types(&mut field_types, &variant.fields); + } + } + syn::Data::Union(_) => panic!("Unions not supported"), + } + field_types +} + +fn add_field_types(field_types: &mut Vec<(syn::Type, bool)>, fields: &syn::Fields) { + match fields { + syn::Fields::Named(v) => { + for f in &v.named { + add_field_type(field_types, f); + } } + syn::Fields::Unnamed(v) => { + for f in &v.unnamed { + add_field_type(field_types, f); + } + } + syn::Fields::Unit => {} + } +} + +fn add_field_type(field_types: &mut Vec<(syn::Type, bool)>, f: &syn::Field) { + let (op_type, _) = extract_field_properties(&f.attrs); + field_types.push((f.ty.clone(), matches!(op_type, OpType::DefinedBy(_)))); +} + +fn add_bounds( + generics: &mut syn::Generics, + field_types: Vec<(syn::Type, bool)>, + bound: syn::TypeParamBound, + defined_by_bound: syn::TypeParamBound, +) { + let where_clause = if field_types.is_empty() { + return; + } else { + generics + .where_clause + .get_or_insert_with(|| syn::WhereClause { + where_token: Default::default(), + predicates: syn::punctuated::Punctuated::new(), + }) + }; + + for (f, is_defined_by) in field_types { + where_clause + .predicates + .push(syn::WherePredicate::Type(syn::PredicateType { + lifetimes: None, + bounded_ty: f, + colon_token: Default::default(), + bounds: { + let mut p = syn::punctuated::Punctuated::new(); + if is_defined_by { + p.push(defined_by_bound.clone()); + } else { + p.push(bound.clone()); + } + p + }, + })) } } diff --git a/src/types.rs b/src/types.rs index 76f9102..550df8f 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1842,13 +1842,28 @@ impl DefinedByMarker { } } +impl<'a, T: Asn1Readable<'a>> Asn1Readable<'a> for DefinedByMarker { + fn parse(_: &mut Parser<'a>) -> ParseResult { + panic!("parse() should never be called on a DefinedByMarker") + } + fn can_parse(_: Tag) -> bool { + panic!("can_parse() should never be called on a DefinedByMarker") + } +} + +impl Asn1Writable for DefinedByMarker { + fn write(&self, _: &mut Writer<'_>) -> WriteResult { + panic!("write() should never be called on a DefinedByMarker") + } +} + #[cfg(test)] mod tests { use crate::{ - parse_single, BigInt, BigUint, DateTime, DefinedByMarker, Enumerated, GeneralizedTime, - IA5String, ObjectIdentifier, OctetStringEncoded, OwnedBigInt, OwnedBigUint, ParseError, - ParseErrorKind, PrintableString, SequenceOf, SequenceOfWriter, SetOf, SetOfWriter, Tag, - Tlv, UtcTime, Utf8String, VisibleString, X509GeneralizedTime, + parse_single, Asn1Readable, Asn1Writable, BigInt, BigUint, DateTime, DefinedByMarker, + Enumerated, GeneralizedTime, IA5String, ObjectIdentifier, OctetStringEncoded, OwnedBigInt, + OwnedBigUint, ParseError, ParseErrorKind, PrintableString, SequenceOf, SequenceOfWriter, + SetOf, SetOfWriter, Tag, Tlv, UtcTime, Utf8String, VisibleString, X509GeneralizedTime, }; use crate::{Explicit, Implicit}; #[cfg(not(feature = "std"))] @@ -2193,4 +2208,22 @@ mod tests { fn test_const() { const _: DefinedByMarker = DefinedByMarker::marker(); } + + #[test] + #[should_panic] + fn test_defined_by_marker_parse() { + crate::parse(b"", DefinedByMarker::::parse).unwrap(); + } + + #[test] + #[should_panic] + fn test_defined_by_marker_can_parse() { + DefinedByMarker::::can_parse(Tag::primitive(2)); + } + + #[test] + #[should_panic] + fn test_defined_by_marker_write() { + crate::write(|w| DefinedByMarker::::marker().write(w)).unwrap(); + } } diff --git a/tests/derive_test.rs b/tests/derive_test.rs index 0db53bd..07538fd 100644 --- a/tests/derive_test.rs +++ b/tests/derive_test.rs @@ -720,3 +720,23 @@ fn test_generics() { b"\x30\x08\x30\x06\x01\x01\xff\x01\x01\xff" ) } + +#[test] +fn test_perfect_derive() { + trait X { + type Type: PartialEq + std::fmt::Debug; + } + + #[derive(PartialEq, Debug)] + struct Op; + impl X for Op { + type Type = u64; + } + + #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + struct S { + value: T::Type, + } + + assert_roundtrips::>(&[(Ok(S { value: 12 }), b"\x30\x03\x02\x01\x0c")]); +}