Skip to content

der: use Reader<'a> as input for Decode::decode #633

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions der/derive/src/asn1_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ impl Asn1Type {
/// Get a `der::Decoder` object for a particular ASN.1 type
pub fn decoder(self) -> TokenStream {
match self {
Asn1Type::BitString => quote!(decoder.bit_string()?),
Asn1Type::Ia5String => quote!(decoder.ia5_string()?),
Asn1Type::GeneralizedTime => quote!(decoder.generalized_time()?),
Asn1Type::OctetString => quote!(decoder.octet_string()?),
Asn1Type::PrintableString => quote!(decoder.printable_string()?),
Asn1Type::UtcTime => quote!(decoder.utc_time()?),
Asn1Type::Utf8String => quote!(decoder.utf8_string()?),
Asn1Type::BitString => quote!(::der::asn1::BitString::decode(reader)?),
Asn1Type::Ia5String => quote!(::der::asn1::Ia5String::decode(reader)?),
Asn1Type::GeneralizedTime => quote!(::der::asn1::GeneralizedTime::decode(reader)?),
Asn1Type::OctetString => quote!(::der::asn1::OctetString::decode(reader)?),
Asn1Type::PrintableString => quote!(::der::asn1::PrintableString::decode(reader)?),
Asn1Type::UtcTime => quote!(::der::asn1::UtcTime::decode(reader)?),
Asn1Type::Utf8String => quote!(::der::asn1::Utf8String::decode(reader)?),
}
}

Expand Down
17 changes: 9 additions & 8 deletions der/derive/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ impl FieldAttrs {
pub fn parse(attrs: &[Attribute], type_attrs: &TypeAttrs) -> Self {
let mut asn1_type = None;
let mut context_specific = None;

let mut default = None;
let mut extensible = None;
let mut optional = None;
Expand Down Expand Up @@ -203,13 +202,13 @@ impl FieldAttrs {
if self.extensible || self.is_optional() {
quote! {
::der::asn1::ContextSpecific::<#type_params>::decode_explicit(
decoder,
reader,
#tag_number
)?
}
} else {
quote! {
match ::der::asn1::ContextSpecific::<#type_params>::decode(decoder)? {
match ::der::asn1::ContextSpecific::<#type_params>::decode(reader)? {
field if field.tag_number == #tag_number => Some(field),
_ => None
}
Expand All @@ -219,7 +218,7 @@ impl FieldAttrs {
TagMode::Implicit => {
quote! {
::der::asn1::ContextSpecific::<#type_params>::decode_implicit(
decoder,
reader,
#tag_number
)?
}
Expand All @@ -246,13 +245,15 @@ impl FieldAttrs {
}
} else if let Some(default) = &self.default {
let type_params = self.asn1_type.map(|ty| ty.type_path()).unwrap_or_default();
self.asn1_type.map(|ty| ty.decoder()).unwrap_or_else(
|| quote!(decoder.decode::<Option<#type_params>>()?.unwrap_or_else(#default)),
)
self.asn1_type.map(|ty| ty.decoder()).unwrap_or_else(|| {
quote! {
Option::<#type_params>::decode(reader)?.unwrap_or_else(#default),
}
})
} else {
self.asn1_type
.map(|ty| ty.decoder())
.unwrap_or_else(|| quote!(decoder.decode()?))
.unwrap_or_else(|| quote!(reader.decode()?))
}
}

Expand Down
13 changes: 6 additions & 7 deletions der/derive/src/choice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
mod variant;

use self::variant::ChoiceVariant;
use crate::TypeAttrs;
use crate::{default_lifetime, TypeAttrs};
use proc_macro2::TokenStream;
use proc_macro_error::abort;
use quote::quote;
Expand Down Expand Up @@ -59,10 +59,9 @@ impl DeriveChoice {
pub fn to_tokens(&self) -> TokenStream {
let ident = &self.ident;

// Explicit lifetime or `'_`
let lifetime = match self.lifetime {
Some(ref lifetime) => quote!(#lifetime),
None => quote!('_),
None => default_lifetime(),
};

// Lifetime parameters
Expand All @@ -88,16 +87,16 @@ impl DeriveChoice {
}

quote! {
impl<#lt_params> ::der::Choice<#lifetime> for #ident<#lt_params> {
impl<#lifetime> ::der::Choice<#lifetime> for #ident<#lt_params> {
fn can_decode(tag: ::der::Tag) -> bool {
matches!(tag, #(#can_decode_body)|*)
}
}

impl<#lt_params> ::der::Decode<#lifetime> for #ident<#lt_params> {
fn decode(decoder: &mut ::der::Decoder<#lifetime>) -> ::der::Result<Self> {
impl<#lifetime> ::der::Decode<#lifetime> for #ident<#lt_params> {
fn decode<R: ::der::Reader<#lifetime>>(reader: &mut R) -> ::der::Result<Self> {
use der::Reader as _;
match decoder.peek_tag()? {
match reader.peek_tag()? {
#(#decode_body)*
actual => Err(der::ErrorKind::TagUnexpected {
expected: None,
Expand Down
8 changes: 4 additions & 4 deletions der/derive/src/choice/variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ mod tests {
variant.to_decode_tokens().to_string(),
quote! {
::der::Tag::Utf8String => Ok(Self::ExampleVariant(
decoder.decode()?
reader.decode()?
)),
}
.to_string()
Expand Down Expand Up @@ -214,7 +214,7 @@ mod tests {
variant.to_decode_tokens().to_string(),
quote! {
::der::Tag::Utf8String => Ok(Self::ExampleVariant(
decoder.utf8_string()?
::der::asn1::Utf8String::decode(reader)?
.try_into()?
)),
}
Expand Down Expand Up @@ -273,7 +273,7 @@ mod tests {
constructed: #constructed,
number: #tag_number,
} => Ok(Self::ExplicitVariant(
match ::der::asn1::ContextSpecific::<>::decode(decoder)? {
match ::der::asn1::ContextSpecific::<>::decode(reader)? {
field if field.tag_number == #tag_number => Some(field),
_ => None
}
Expand Down Expand Up @@ -359,7 +359,7 @@ mod tests {
number: #tag_number,
} => Ok(Self::ImplicitVariant(
::der::asn1::ContextSpecific::<>::decode_implicit(
decoder,
reader,
#tag_number
)?
.ok_or_else(|| {
Expand Down
11 changes: 6 additions & 5 deletions der/derive/src/enumerated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//! the purposes of decoding/encoding ASN.1 `ENUMERATED` types as mapped to
//! enum variants.

use crate::ATTR_NAME;
use crate::{default_lifetime, ATTR_NAME};
use proc_macro2::TokenStream;
use proc_macro_error::abort;
use quote::quote;
Expand Down Expand Up @@ -102,6 +102,7 @@ impl DeriveEnumerated {

/// Lower the derived output into a [`TokenStream`].
pub fn to_tokens(&self) -> TokenStream {
let default_lifetime = default_lifetime();
let ident = &self.ident;
let repr = &self.repr;
let tag = match self.integer {
Expand All @@ -115,12 +116,12 @@ impl DeriveEnumerated {
}

quote! {
impl ::der::DecodeValue<'_> for #ident {
fn decode_value(
decoder: &mut ::der::Decoder<'_>,
impl<#default_lifetime> ::der::DecodeValue<#default_lifetime> for #ident {
fn decode_value<R: ::der::Reader<#default_lifetime>>(
reader: &mut R,
header: ::der::Header
) -> ::der::Result<Self> {
<#repr as ::der::DecodeValue>::decode_value(decoder, header)?.try_into()
<#repr as ::der::DecodeValue>::decode_value(reader, header)?.try_into()
}
}

Expand Down
10 changes: 9 additions & 1 deletion der/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,16 @@ use crate::{
value_ord::DeriveValueOrd,
};
use proc_macro::TokenStream;
use proc_macro2::Span;
use proc_macro_error::proc_macro_error;
use syn::{parse_macro_input, DeriveInput};
use quote::quote;
use syn::{parse_macro_input, DeriveInput, Lifetime};

/// Get the default lifetime.
fn default_lifetime() -> proc_macro2::TokenStream {
let lifetime = Lifetime::new("'__der_lifetime", Span::call_site());
quote!(#lifetime)
}

/// Derive the [`Choice`][1] trait on an `enum`.
///
Expand Down
18 changes: 9 additions & 9 deletions der/derive/src/sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

mod field;

use crate::TypeAttrs;
use crate::{default_lifetime, TypeAttrs};
use field::SequenceField;
use proc_macro2::TokenStream;
use proc_macro_error::abort;
Expand Down Expand Up @@ -59,10 +59,9 @@ impl DeriveSequence {
pub fn to_tokens(&self) -> TokenStream {
let ident = &self.ident;

// Explicit lifetime or `'_`
let lifetime = match self.lifetime {
Some(ref lifetime) => quote!(#lifetime),
None => quote!('_),
None => default_lifetime(),
};

// Lifetime parameters
Expand All @@ -84,13 +83,14 @@ impl DeriveSequence {
}

quote! {
impl<#lt_params> ::der::DecodeValue<#lifetime> for #ident<#lt_params> {
fn decode_value(
decoder: &mut ::der::Decoder<#lifetime>,
impl<#lifetime> ::der::DecodeValue<#lifetime> for #ident<#lt_params> {
fn decode_value<R: ::der::Reader<#lifetime>>(
reader: &mut R,
header: ::der::Header,
) -> ::der::Result<Self> {
use ::der::DecodeValue;
::der::asn1::SequenceRef::decode_value(decoder, header)?.decode_body(|decoder| {
use ::der::{Decode as _, DecodeValue as _, Reader as _};

reader.read_nested(header.length, |reader| {
#(#decode_body)*

Ok(Self {
Expand All @@ -100,7 +100,7 @@ impl DeriveSequence {
}
}

impl<#lt_params> ::der::Sequence<#lifetime> for #ident<#lt_params> {
impl<#lifetime> ::der::Sequence<#lifetime> for #ident<#lt_params> {
fn fields<F, T>(&self, f: F) -> ::der::Result<T>
where
F: FnOnce(&[&dyn der::Encode]) -> ::der::Result<T>,
Expand Down
8 changes: 4 additions & 4 deletions der/derive/src/sequence/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ impl LowerFieldDecoder {
/// Handle default value for a type.
fn apply_default(&mut self, default: &Path, field_type: &Type) {
self.decoder = quote! {
decoder.decode::<Option<#field_type>>()?.unwrap_or_else(#default);
}
Option::<#field_type>::decode(reader)?.unwrap_or_else(#default);
};
}
}

Expand Down Expand Up @@ -287,7 +287,7 @@ mod tests {
assert_eq!(
field.to_decode_tokens().to_string(),
quote! {
let example_field = decoder.decode()?;
let example_field = reader.decode()?;
}
.to_string()
);
Expand Down Expand Up @@ -328,7 +328,7 @@ mod tests {
field.to_decode_tokens().to_string(),
quote! {
let implicit_field = ::der::asn1::ContextSpecific::<>::decode_implicit(
decoder,
reader,
::der::TagNumber::N0
)?
.ok_or_else(|| {
Expand Down
19 changes: 3 additions & 16 deletions der/src/arrayvec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,9 @@ impl<T, const N: usize> ArrayVec<T, N> {
self.length.checked_sub(1).and_then(|n| self.get(n))
}

/// Try to convert this [`ArrayVec`] into a `[T; N]`.
///
/// Returns `None` if the [`ArrayVec`] does not contain `N` elements.
pub fn try_into_array(self) -> Result<[T; N]> {
if self.length != N {
return Err(ErrorKind::Incomplete {
expected_len: N.try_into()?,
actual_len: self.length.try_into()?,
}
.into());
}

Ok(self.elements.map(|elem| match elem {
Some(e) => e,
None => unreachable!(),
}))
/// Extract the inner array.
pub fn into_array(self) -> [Option<T>; N] {
self.elements
}
}

Expand Down
9 changes: 5 additions & 4 deletions der/src/asn1/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use crate::{
asn1::*, ByteSlice, Choice, Decode, DecodeValue, Decoder, DerOrd, EncodeValue, Error,
ErrorKind, FixedTag, Header, Length, Result, Tag, Tagged, ValueOrd, Writer,
ErrorKind, FixedTag, Header, Length, Reader, Result, Tag, Tagged, ValueOrd, Writer,
};
use core::cmp::Ordering;

Expand Down Expand Up @@ -153,11 +153,12 @@ impl<'a> Choice<'a> for Any<'a> {
}

impl<'a> Decode<'a> for Any<'a> {
fn decode(decoder: &mut Decoder<'a>) -> Result<Any<'a>> {
let header = Header::decode(decoder)?;
fn decode<R: Reader<'a>>(reader: &mut R) -> Result<Any<'a>> {
let header = Header::decode(reader)?;

Ok(Self {
tag: header.tag,
value: ByteSlice::decode_value(decoder, header)?,
value: ByteSlice::decode_value(reader, header)?,
})
}
}
Expand Down
18 changes: 9 additions & 9 deletions der/src/asn1/bit_string.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
//! ASN.1 `BIT STRING` support.

use crate::{
asn1::Any, ByteSlice, DecodeValue, Decoder, DerOrd, EncodeValue, Error, ErrorKind, FixedTag,
Header, Length, Reader, Result, Tag, ValueOrd, Writer,
asn1::Any, ByteSlice, DecodeValue, DerOrd, EncodeValue, Error, ErrorKind, FixedTag, Header,
Length, Reader, Result, Tag, ValueOrd, Writer,
};
use core::{cmp::Ordering, iter::FusedIterator};

Expand Down Expand Up @@ -116,14 +116,14 @@ impl<'a> BitString<'a> {
}

impl<'a> DecodeValue<'a> for BitString<'a> {
fn decode_value(decoder: &mut Decoder<'a>, header: Header) -> Result<Self> {
fn decode_value<R: Reader<'a>>(reader: &mut R, header: Header) -> Result<Self> {
let header = Header {
tag: header.tag,
length: (header.length - Length::ONE)?,
};

let unused_bits = decoder.read_byte()?;
let inner = ByteSlice::decode_value(decoder, header)?;
let unused_bits = reader.read_byte()?;
let inner = ByteSlice::decode_value(reader, header)?;
Self::new(unused_bits, inner.as_slice())
}
}
Expand Down Expand Up @@ -239,12 +239,12 @@ where
T::Type: From<bool>,
T::Type: core::ops::Shl<usize, Output = T::Type>,
{
fn decode_value(decoder: &mut Decoder<'a>, header: Header) -> Result<Self> {
let position = decoder.position();

let bits = BitString::decode_value(decoder, header)?;
fn decode_value<R: Reader<'a>>(reader: &mut R, header: Header) -> Result<Self> {
let position = reader.position();
let bits = BitString::decode_value(reader, header)?;

let mut flags = T::none().bits();

if bits.bit_len() > core::mem::size_of_val(&flags) * 8 {
return Err(Error::new(ErrorKind::Overlength, position));
}
Expand Down
Loading