From 30c38095bf82431acbef8736c384b0d0ea6283f9 Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Mon, 18 Jul 2022 11:02:50 +0200 Subject: [PATCH] Redesign error handling in the decode module (#65) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit changes the entire way that errors are handled while decoding. Sources now provide errors related to failing to get more data when more is requested via the associated Source::Error type. This has been renamed from Source::Err for consistency with other such associated types. A new type ContentError is introduced for errors where incorrectly encoded data is encountered, i.e., actual encoding errors or data that isn't following the ASN.1 definition or informal profiles. This type wraps an error message that currently can be either a static str or a boxed Display trait object but can be extended if necessary later. Since both source and content errors can happen during decoding, the compound type DecodeError wraps both these types. For content errors it also stores where in the source the error happened to facilitate debugging. Currently, this type only allows displaying the error. It can be extended if additional access to the internally stored error is necessary. The various content decoding methods now return this DecodeError. In order to implement all these changes, the Source trait had to be adjusted. First, it needs to be able to provide the current position. This meant that it couldn’t be implemented on Bytes and &[u8] directly anymore. Therefore, the new trait IntoSource allows to convert a type into its Source implementation. Finally, the trait's methods got cleaned up a bit. Specifically, Source::advance now only allows advancing as far as the length most recently returned by Source::request which also means it cannot fail anymore but needs to panic if the length is too large. This consistent to how Source::bytes already behaves. --- Cargo.toml | 8 +- src/captured.rs | 52 +++- src/debug.rs | 50 ---- src/decode/content.rs | 598 +++++++++++++++++++++++++-------------- src/decode/error.rs | 159 ++++++++++- src/decode/mod.rs | 35 ++- src/decode/source.rs | 546 ++++++++++++++++++++++++----------- src/encode/values.rs | 2 +- src/guide/decode.rs | 13 +- src/int.rs | 256 ++++++++++------- src/length.rs | 18 +- src/lib.rs | 2 - src/mode.rs | 11 +- src/oid.rs | 41 ++- src/string/bit.rs | 34 ++- src/string/octet.rs | 143 ++++++---- src/string/restricted.rs | 10 +- src/tag.rs | 113 +++++--- 18 files changed, 1369 insertions(+), 722 deletions(-) delete mode 100644 src/debug.rs diff --git a/Cargo.toml b/Cargo.toml index ee226ed..b8d8373 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "bcder" -version = "0.6.2-dev" +version = "0.7.0-dev" edition = "2018" authors = ["The NLnet Labs RPKI Team "] description = "Handling of data encoded in BER, CER, and DER." @@ -11,12 +11,6 @@ categories = ["encoding", "network-programming", "parsing"] license = "BSD-3-Clause" [dependencies] -backtrace = { version = "^0.3.15", optional = true } bytes = "^1.0" smallvec = "^1.1" -[features] -# Print a backtrace when a parsing error occurs. This feature is intended for -# development use exclusively and MUST NOT be enabled in release builds. -extra-debug = [ "backtrace" ] - diff --git a/src/captured.rs b/src/captured.rs index 579c252..80360f0 100644 --- a/src/captured.rs +++ b/src/captured.rs @@ -2,9 +2,10 @@ //! //! This is a private module. Its public items are re-exported by the parent. -use std::{fmt, io, ops}; +use std::{fmt, io, mem, ops}; use bytes::{Bytes, BytesMut}; use crate::{decode, encode}; +use crate::decode::{BytesSource, DecodeError, IntoSource, Pos}; use crate::mode::Mode; @@ -39,8 +40,14 @@ use crate::mode::Mode; /// [`Mode`]: ../enum.Mode.html #[derive(Clone)] pub struct Captured { + /// The captured data. bytes: Bytes, + + /// The encoding mode of the captured data. mode: Mode, + + /// The start position of the data in the original source. + start: Pos, } impl Captured { @@ -49,8 +56,8 @@ impl Captured { /// Because we can’t guarantee that the bytes are properly encoded, we /// keep this function crate public. The type, however, doesn’t rely on /// content being properly encoded so this method isn’t unsafe. - pub(crate) fn new(bytes: Bytes, mode: Mode) -> Self { - Captured { bytes, mode } + pub(crate) fn new(bytes: Bytes, mode: Mode, start: Pos) -> Self { + Captured { bytes, mode, start } } /// Creates a captured value by encoding data. @@ -68,7 +75,8 @@ impl Captured { pub fn empty(mode: Mode) -> Self { Captured { bytes: Bytes::new(), - mode + mode, + start: Pos::default(), } } @@ -90,11 +98,13 @@ impl Captured { /// The method consumes the value. If you want to keep it around, simply /// clone it first. Since bytes values are cheap to clone, this is /// relatively cheap. - pub fn decode(self, op: F) -> Result + pub fn decode( + self, op: F + ) -> Result::Error>> where F: FnOnce( - &mut decode::Constructed - ) -> Result + &mut decode::Constructed + ) -> Result::Error>> { self.mode.decode(self.bytes, op) } @@ -104,13 +114,20 @@ impl Captured { /// The method calls `op` to parse a number of values from the beginning /// of the value and then advances the content of the captured value until /// after the end of these decoded values. - pub fn decode_partial(&mut self, op: F) -> Result + pub fn decode_partial( + &mut self, op: F + ) -> Result::Error>> where F: FnOnce( - &mut decode::Constructed<&mut Bytes> - ) -> Result + &mut decode::Constructed<&mut BytesSource> + ) -> Result::Error>> { - self.mode.decode(&mut self.bytes, op) + let mut source = mem::replace( + &mut self.bytes, Bytes::new() + ).into_source(); + let res = self.mode.decode(&mut source, op); + self.bytes = source.into_bytes(); + res } /// Trades the value for a bytes value with the raw data. @@ -148,6 +165,17 @@ impl AsRef<[u8]> for Captured { } +//--- IntoSource + +impl IntoSource for Captured { + type Source = BytesSource; + + fn into_source(self) -> Self::Source { + BytesSource::with_offset(self.bytes, self.start) + } +} + + //--- encode::Values impl encode::Values for Captured { @@ -223,7 +251,7 @@ impl CapturedBuilder { } pub fn freeze(self) -> Captured { - Captured::new(self.bytes.freeze(), self.mode) + Captured::new(self.bytes.freeze(), self.mode, Pos::default()) } } diff --git a/src/debug.rs b/src/debug.rs deleted file mode 100644 index 5131ed8..0000000 --- a/src/debug.rs +++ /dev/null @@ -1,50 +0,0 @@ -//! Macros for last-resort debugging. -//! -//! _Note:_ This facility is going to be replaced by an error type that -//! includes a backtrace of the `extra-debug` feature is set. -//! -//! Since error reporting of the BER parser is limited on purpose, debugging -//! code using it may be difficult. To remedy this somewhat, this module -//! contains a macro `xerr!()` that will print out a backtrace if the -//! `extra-debug` feature is enable during build before resolving into -//! whatever the expression it encloses resolves to otherwise. Use it -//! whenever you initially produce an error, i.e.: -//! -//! ```rust,ignore -//! if foo { -//! xerr!(Err(Error::Malformed)) -//! } -//! ``` -//! -//! or, with an early return: -//! -//! ```rust,ignore -//! if foo { -//! xerr!(return Err(Error::Malformed)); -//! } -//! ``` - -#[cfg(feature = "extra-debug")] -extern crate backtrace; - -#[cfg(feature="extra-debug")] -pub use self::backtrace::Backtrace; - -#[cfg(feature = "extra-debug")] -#[macro_export] -macro_rules! xerr { - ($test:expr) => {{ - eprintln!( - "--- EXTRA DEBUG ---\n{:?}\n--- EXTRA DEBUG ---", - $crate::debug::Backtrace::new() - ); - $test - }} -} - -#[cfg(not(feature = "extra-debug"))] -#[macro_export] -macro_rules! xerr { - ($test:expr) => { $test }; -} - diff --git a/src/decode/content.rs b/src/decode/content.rs index bdec10d..860af73 100644 --- a/src/decode/content.rs +++ b/src/decode/content.rs @@ -3,6 +3,11 @@ //! This is an internal module. Its public types are re-exported by the //! parent. +#![allow(unused_imports)] +#![allow(dead_code)] + +use std::fmt; +use std::convert::Infallible; use bytes::Bytes; use smallvec::SmallVec; use crate::captured::Captured; @@ -10,8 +15,10 @@ use crate::int::{Integer, Unsigned}; use crate::length::Length; use crate::mode::Mode; use crate::tag::Tag; -use super::error::Error; -use super::source::{CaptureSource, LimitedSource, Source}; +use super::error::{ContentError, DecodeError}; +use super::source::{ + CaptureSource, IntoSource, LimitedSource, Pos, SliceSource, Source, +}; //------------ Content ------------------------------------------------------- @@ -34,10 +41,10 @@ pub enum Content<'a, S: 'a> { } impl<'a, S: Source + 'a> Content<'a, S> { - /// Checkes that the content has been parsed completely. + /// Checks that the content has been parsed completely. /// /// Returns a malformed error if not. - fn exhausted(self) -> Result<(), S::Err> { + fn exhausted(self) -> Result<(), DecodeError> { match self { Content::Primitive(inner) => inner.exhausted(), Content::Constructed(mut inner) => inner.exhausted() @@ -69,11 +76,13 @@ impl<'a, S: Source + 'a> Content<'a, S> { } /// Converts a reference into into one to a primitive value or errors out. - pub fn as_primitive(&mut self) -> Result<&mut Primitive<'a, S>, S::Err> { + pub fn as_primitive( + &mut self + ) -> Result<&mut Primitive<'a, S>, DecodeError> { match *self { Content::Primitive(ref mut inner) => Ok(inner), - Content::Constructed(_) => { - xerr!(Err(Error::Malformed.into())) + Content::Constructed(ref inner) => { + Err(inner.content_err("expected primitive value")) } } } @@ -81,14 +90,24 @@ impl<'a, S: Source + 'a> Content<'a, S> { /// Converts a reference into on to a constructed value or errors out. pub fn as_constructed( &mut self - ) -> Result<&mut Constructed<'a, S>, S::Err> { + ) -> Result<&mut Constructed<'a, S>, DecodeError> { match *self { - Content::Primitive(_) => { - xerr!(Err(Error::Malformed.into())) + Content::Primitive(ref inner) => { + Err(inner.content_err("expected constructed value")) } Content::Constructed(ref mut inner) => Ok(inner), } } + + /// Produces a content error at the current source position. + pub fn content_err( + &self, err: impl Into, + ) -> DecodeError { + match *self { + Content::Primitive(ref inner) => inner.content_err(err), + Content::Constructed(ref inner) => inner.content_err(err), + } + } } #[allow(clippy::wrong_self_convention)] @@ -97,12 +116,12 @@ impl<'a, S: Source + 'a> Content<'a, S> { /// /// If the content is not primitive or does not contain a single BER /// encoded INTEGER value between 0 and 256, returns a malformed error. - pub fn to_u8(&mut self) -> Result { + pub fn to_u8(&mut self) -> Result> { if let Content::Primitive(ref mut prim) = *self { prim.to_u8() } else { - xerr!(Err(Error::Malformed.into())) + Err(self.content_err("expected integer (0..255)")) } } @@ -111,13 +130,15 @@ impl<'a, S: Source + 'a> Content<'a, S> { /// The content needs to be primitive and contain a validly encoded /// integer of value `expected` or else a malformed error will be /// returned. - pub fn skip_u8_if(&mut self, expected: u8) -> Result<(), S::Err> { + pub fn skip_u8_if( + &mut self, expected: u8, + ) -> Result<(), DecodeError> { let res = self.to_u8()?; if res == expected { Ok(()) } else { - xerr!(Err(Error::Malformed.into())) + Err(self.content_err(ExpectedIntValue(expected))) } } @@ -125,12 +146,12 @@ impl<'a, S: Source + 'a> Content<'a, S> { /// /// If the content is not primitive or does not contain a single BER /// encoded INTEGER value between 0 and 2^16-1, returns a malformed error. - pub fn to_u16(&mut self) -> Result { + pub fn to_u16(&mut self) -> Result> { if let Content::Primitive(ref mut prim) = *self { prim.to_u16() } else { - xerr!(Err(Error::Malformed.into())) + Err(self.content_err("expected integer (0..65535)")) } } @@ -138,12 +159,12 @@ impl<'a, S: Source + 'a> Content<'a, S> { /// /// If the content is not primitive or does not contain a single BER /// encoded INTEGER value between 0 and 2^32-1, returns a malformed error. - pub fn to_u32(&mut self) -> Result { + pub fn to_u32(&mut self) -> Result> { if let Content::Primitive(ref mut prim) = *self { prim.to_u32() } else { - xerr!(Err(Error::Malformed.into())) + Err(self.content_err("expected integer (0..4294967295)")) } } @@ -151,12 +172,12 @@ impl<'a, S: Source + 'a> Content<'a, S> { /// /// If the content is not primitive or does not contain a single BER /// encoded INTEGER value between 0 and 2^64-1, returns a malformed error. - pub fn to_u64(&mut self) -> Result { + pub fn to_u64(&mut self) -> Result> { if let Content::Primitive(ref mut prim) = *self { prim.to_u64() } else { - xerr!(Err(Error::Malformed.into())) + Err(self.content_err("expected integer (0..2**64-1)")) } } @@ -164,12 +185,12 @@ impl<'a, S: Source + 'a> Content<'a, S> { /// /// If the content isn’t primitive and contains a single BER encoded /// NULL value (i.e., nothing), returns a malformed error. - pub fn to_null(&mut self) -> Result<(), S::Err> { + pub fn to_null(&mut self) -> Result<(), DecodeError> { if let Content::Primitive(ref mut prim) = *self { prim.to_null() } else { - xerr!(Err(Error::Malformed.into())) + Err(self.content_err("expected NULL")) } } } @@ -207,14 +228,18 @@ pub struct Primitive<'a, S: 'a> { /// The decoding mode to operate in. mode: Mode, + + /// The start position of the value in the source. + start: Pos, } /// # Value Management /// impl<'a, S: 'a> Primitive<'a, S> { /// Creates a new primitive from the given source and mode. - fn new(source: &'a mut LimitedSource, mode: Mode) -> Self { - Primitive { source, mode } + fn new(source: &'a mut LimitedSource, mode: Mode) -> Self + where S: Source { + Primitive { start: source.pos(), source, mode } } /// Returns the current decoding mode. @@ -231,20 +256,29 @@ impl<'a, S: 'a> Primitive<'a, S> { } } +impl<'a, S: Source + 'a> Primitive<'a, S> { + /// Produces a content error at the current source position. + pub fn content_err( + &self, err: impl Into, + ) -> DecodeError { + DecodeError::content(err, self.start) + } +} + /// # High-level Decoding /// #[allow(clippy::wrong_self_convention)] impl<'a, S: Source + 'a> Primitive<'a, S> { /// Parses the primitive value as a BOOLEAN value. - pub fn to_bool(&mut self) -> Result { + pub fn to_bool(&mut self) -> Result> { let res = self.take_u8()?; if self.mode != Mode::Ber { match res { 0 => Ok(false), 0xFF => Ok(true), _ => { - xerr!(Err(Error::Malformed.into())) + Err(self.content_err("invalid boolean")) } } } @@ -254,61 +288,65 @@ impl<'a, S: Source + 'a> Primitive<'a, S> { } /// Parses the primitive value as an INTEGER limited to a `i8`. - pub fn to_i8(&mut self) -> Result { + pub fn to_i8(&mut self) -> Result> { Integer::i8_from_primitive(self) } /// Parses the primitive value as an INTEGER limited to a `i8`. - pub fn to_i16(&mut self) -> Result { + pub fn to_i16(&mut self) -> Result> { Integer::i16_from_primitive(self) } /// Parses the primitive value as an INTEGER limited to a `i8`. - pub fn to_i32(&mut self) -> Result { + pub fn to_i32(&mut self) -> Result> { Integer::i32_from_primitive(self) } /// Parses the primitive value as an INTEGER limited to a `i8`. - pub fn to_i64(&mut self) -> Result { + pub fn to_i64(&mut self) -> Result> { Integer::i64_from_primitive(self) } /// Parses the primitive value as an INTEGER limited to a `i8`. - pub fn to_i128(&mut self) -> Result { + pub fn to_i128(&mut self) -> Result> { Integer::i128_from_primitive(self) } /// Parses the primitive value as an INTEGER limited to a `u8`. - pub fn to_u8(&mut self) -> Result { + pub fn to_u8(&mut self) -> Result> { Unsigned::u8_from_primitive(self) } /// Parses the primitive value as an INTEGER limited to a `u16`. - pub fn to_u16(&mut self) -> Result { + pub fn to_u16(&mut self) -> Result> { Unsigned::u16_from_primitive(self) } /// Parses the primitive value as an INTEGER limited to a `u32`. - pub fn to_u32(&mut self) -> Result { + pub fn to_u32(&mut self) -> Result> { Unsigned::u32_from_primitive(self) } /// Parses the primitive value as a INTEGER value limited to a `u64`. - pub fn to_u64(&mut self) -> Result { + pub fn to_u64(&mut self) -> Result> { Unsigned::u64_from_primitive(self) } /// Parses the primitive value as a INTEGER value limited to a `u128`. - pub fn to_u128(&mut self) -> Result { + pub fn to_u128(&mut self) -> Result> { Unsigned::u64_from_primitive(self) } /// Converts the content octets to a NULL value. /// /// Since such a value is empty, this doesn’t really do anything. - pub fn to_null(&mut self) -> Result<(), S::Err> { - // The rest is taken care of by the exhausted check later ... - Ok(()) + pub fn to_null(&mut self) -> Result<(), DecodeError> { + if self.remaining() > 0 { + Err(self.content_err("invalid NULL value")) + } + else { + Ok(()) + } } } @@ -322,31 +360,38 @@ impl<'a, S: Source + 'a> Primitive<'a, S> { impl<'a, S: Source + 'a> Primitive<'a, S> { /// Returns the number of remaining octets. /// - /// The returned value reflects what is left of the content and therefore - /// decreases when the primitive is advanced. + /// The returned value reflects what is left of the expected length of + /// content and therefore decreases when the primitive is advanced. pub fn remaining(&self) -> usize { self.source.limit().unwrap() } /// Skips the rest of the content. - pub fn skip_all(&mut self) -> Result<(), S::Err> { + /// + /// Returns a malformed error if the source ends before the expected + /// length of content. + pub fn skip_all(&mut self) -> Result<(), DecodeError> { self.source.skip_all() } /// Returns the remainder of the content as a `Bytes` value. - pub fn take_all(&mut self) -> Result { + pub fn take_all(&mut self) -> Result> { self.source.take_all() } /// Returns a bytes slice of the remainder of the content. - pub fn slice_all(&mut self) -> Result<&[u8], S::Err> { + pub fn slice_all(&mut self) -> Result<&[u8], DecodeError> { let remaining = self.remaining(); - self.source.request(remaining)?; - Ok(&self.source.slice()[..remaining]) + if self.source.request(remaining)? < remaining { + Err(self.source.content_err("unexpected end of data")) + } + else { + Ok(&self.source.slice()[..remaining]) + } } /// Checkes whether all content has been advanced over. - fn exhausted(self) -> Result<(), S::Err> { + fn exhausted(self) -> Result<(), DecodeError> { self.source.exhausted() } } @@ -354,7 +399,7 @@ impl<'a, S: Source + 'a> Primitive<'a, S> { /// # Support for Testing /// -impl<'a> Primitive<'a, &'a [u8]> { +impl Primitive<'static, ()> { /// Decode a bytes slice via a closure. /// /// This method can be used in testing code for decoding primitive @@ -374,13 +419,17 @@ impl<'a> Primitive<'a, &'a [u8]> { /// ) /// ``` pub fn decode_slice( - source: &'a [u8], + data: &[u8], mode: Mode, op: F - ) -> Result - where F: FnOnce(&mut Primitive<&[u8]>) -> Result { - let mut lim = LimitedSource::new(source); - lim.set_limit(Some(source.len())); + ) -> Result> + where + F: FnOnce( + &mut Primitive + ) -> Result> + { + let mut lim = LimitedSource::new(data.into_source()); + lim.set_limit(Some(data.len())); let mut prim = Primitive::new(&mut lim, mode); let res = op(&mut prim)?; prim.exhausted()?; @@ -392,14 +441,14 @@ impl<'a> Primitive<'a, &'a [u8]> { //--- Source impl<'a, S: Source + 'a> Source for Primitive<'a, S> { - type Err = S::Err; + type Error = S::Error; - fn request(&mut self, len: usize) -> Result { - self.source.request(len) + fn pos(&self) -> Pos { + self.source.pos() } - fn advance(&mut self, len: usize) -> Result<(), Self::Err> { - self.source.advance(len) + fn request(&mut self, len: usize) -> Result { + self.source.request(len) } fn slice(&self) -> &[u8] { @@ -409,6 +458,10 @@ impl<'a, S: Source + 'a> Source for Primitive<'a, S> { fn bytes(&self, start: usize, end: usize) -> Bytes { self.source.bytes(start, end) } + + fn advance(&mut self, len: usize) { + self.source.advance(len) + } } @@ -443,6 +496,9 @@ pub struct Constructed<'a, S: 'a> { /// The encoding mode to use. mode: Mode, + + /// The start position of the value in the source. + start: Pos, } /// # General Management @@ -454,10 +510,10 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { state: State, mode: Mode ) -> Self { - Constructed { source, state, mode } + Constructed { start: source.pos(), source, state, mode } } - /// Decode a source as a constructed content. + /// Decode a source as constructed content. /// /// The function will start decoding of `source` in the given mode. It /// will pass a constructed content value to the closure `op` which @@ -466,9 +522,14 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// This function is identical to calling [`Mode::decode`]. /// /// [`Mode::decode`]: ../enum.Mode.html#method.decode - pub fn decode(source: S, mode: Mode, op: F) -> Result - where F: FnOnce(&mut Constructed) -> Result { - let mut source = LimitedSource::new(source); + pub fn decode( + source: I, mode: Mode, op: F, + ) -> Result> + where + I: IntoSource, + F: FnOnce(&mut Constructed) -> Result> + { + let mut source = LimitedSource::new(source.into_source()); let mut cons = Constructed::new(&mut source, State::Unbounded, mode); let res = op(&mut cons)?; cons.exhausted()?; @@ -486,6 +547,15 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { } } +impl<'a, S: Source + 'a> Constructed<'a, S> { + /// Produces a content error at start of the value. + pub fn content_err( + &self, err: impl Into, + ) -> DecodeError { + DecodeError::content(err, self.start) + } +} + /// # Fundamental Reading /// impl<'a, S: Source + 'a> Constructed<'a, S> { @@ -494,7 +564,7 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// For a value of definite length, this is the case when the limit of the /// source has been reached. For indefinite values, we need to have either /// already read or can now read the end-of-value marker. - fn exhausted(&mut self) -> Result<(), S::Err> { + fn exhausted(&mut self) -> Result<(), DecodeError> { match self.state { State::Done => Ok(()), State::Definite => { @@ -505,7 +575,7 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { if tag != Tag::END_OF_VALUE || constructed || !Length::take_from(self.source, self.mode)?.is_zero() { - xerr!(Err(Error::Malformed.into())) + Err(self.content_err("unexpected trailing values")) } else { Ok(()) @@ -546,8 +616,10 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { &mut self, expected: Option, op: F - ) -> Result, S::Err> - where F: FnOnce(Tag, &mut Content) -> Result { + ) -> Result, DecodeError> + where + F: FnOnce(Tag, &mut Content) -> Result> + { if self.is_exhausted() { return Ok(None) } @@ -568,16 +640,22 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { if tag == Tag::END_OF_VALUE { if let State::Indefinite = self.state { if constructed { - xerr!(return Err(Error::Malformed.into())) + return Err(self.source.content_err( + "constructed end of value" + )) } if !length.is_zero() { - xerr!(return Err(Error::Malformed.into())) + return Err(self.source.content_err( + "non-empty end of value" + )) } self.state = State::Done; return Ok(None) } else { - xerr!(return Err(Error::Malformed.into())) + return Err(self.source.content_err( + "unexpected end of value" + )) } } @@ -589,7 +667,9 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { // Definite length constructed values are not allowed // in CER. if self.mode == Mode::Cer { - xerr!(return Err(Error::Malformed.into())) + return Err(self.source.content_err( + "definite length constructed in CER mode" + )) } Content::Constructed( Constructed::new( @@ -611,10 +691,14 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { } Length::Indefinite => { if !constructed || self.mode == Mode::Der { - xerr!(return Err(Error::Malformed.into())) + return Err(self.source.content_err( + "indefinite length constructed in DER mode" + )) } let mut content = Content::Constructed( - Constructed::new(self.source, State::Indefinite, self.mode) + Constructed::new( + self.source, State::Indefinite, self.mode + ) ); let res = op(tag, &mut content)?; content.exhausted()?; @@ -622,6 +706,21 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { } } } + + /// Makes sure the next value is present. + fn mandatory( + &mut self, op: F, + ) -> Result> + where + F: FnOnce( + &mut Constructed + ) -> Result, DecodeError>, + { + match op(self)? { + Some(res) => Ok(res), + None => Err(self.source.content_err("missing futher values")), + } + } } /// # Processing Contained Values @@ -638,13 +737,15 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// method returns a malformed error if there isn’t at least one more /// value available. It also returns an error if the closure returns one /// or if reading from the source fails. - pub fn take_value(&mut self, op: F) -> Result - where F: FnOnce(Tag, &mut Content) -> Result { + pub fn take_value( + &mut self, op: F, + ) -> Result> + where + F: FnOnce(Tag, &mut Content) -> Result>, + { match self.process_next_value(None, op)? { Some(res) => Ok(res), - None => { - xerr!(Err(Error::Malformed.into())) - } + None => Err(self.content_err("missing futher values")), } } @@ -658,8 +759,12 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// If there are no more values available, the method returns `Ok(None)`. /// It returns an error if the closure returns one or if reading from /// the source fails. - pub fn take_opt_value(&mut self, op: F) -> Result, S::Err> - where F: FnOnce(Tag, &mut Content) -> Result { + pub fn take_opt_value( + &mut self, op: F, + ) -> Result, DecodeError> + where + F: FnOnce(Tag, &mut Content) -> Result>, + { self.process_next_value(None, op) } @@ -677,16 +782,14 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { &mut self, expected: Tag, op: F - ) -> Result - where F: FnOnce(&mut Content) -> Result { + ) -> Result> + where F: FnOnce(&mut Content) -> Result> { let res = self.process_next_value(Some(expected), |_, content| { op(content) })?; match res { Some(res) => Ok(res), - None => { - xerr!(Err(Error::Malformed.into())) - } + None => Err(self.content_err(ExpectedTag(expected))), } } @@ -704,12 +807,12 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { &mut self, expected: Tag, op: F - ) -> Result, S::Err> - where F: FnOnce(&mut Content) -> Result { + ) -> Result, DecodeError> + where F: FnOnce(&mut Content) -> Result> { self.process_next_value(Some(expected), |_, content| op(content)) } - /// Process a constructed value. + /// Processes a constructed value. /// /// If the next value is a constructed value, its tag and content are /// being given to the closure `op` which has to process it completely. @@ -719,14 +822,15 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// value or if the closure doesn’t process the next value completely, /// a malformed error is returned. An error is also returned if the /// closure returns one or if accessing the underlying source fails. - pub fn take_constructed(&mut self, op: F) -> Result - where F: FnOnce(Tag, &mut Constructed) -> Result { - match self.take_opt_constructed(op)? { - Some(res) => Ok(res), - None => { - xerr!(Err(Error::Malformed.into())) - } - } + pub fn take_constructed( + &mut self, op: F + ) -> Result> + where + F: FnOnce( + Tag, &mut Constructed + ) -> Result>, + { + self.mandatory(|cons| cons.take_opt_constructed(op)) } /// Processes an optional constructed value. @@ -745,8 +849,12 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { pub fn take_opt_constructed( &mut self, op: F - ) -> Result, S::Err> - where F: FnOnce(Tag, &mut Constructed) -> Result { + ) -> Result, DecodeError> + where + F: FnOnce( + Tag, &mut Constructed, + ) -> Result> + { self.process_next_value(None, |tag, content| { op(tag, content.as_constructed()?) }) @@ -767,15 +875,12 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { pub fn take_constructed_if( &mut self, expected: Tag, - op: F - ) -> Result - where F: FnOnce(&mut Constructed) -> Result { - match self.take_opt_constructed_if(expected, op)? { - Some(res) => Ok(res), - None => { - xerr!(Err(Error::Malformed.into())) - } - } + op: F, + ) -> Result> + where + F: FnOnce(&mut Constructed) -> Result>, + { + self.mandatory(|cons| cons.take_opt_constructed_if(expected, op)) } /// Processes an optional constructed value if it has a given tag. @@ -795,9 +900,11 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { pub fn take_opt_constructed_if( &mut self, expected: Tag, - op: F - ) -> Result, S::Err> - where F: FnOnce(&mut Constructed) -> Result { + op: F, + ) -> Result, DecodeError> + where + F: FnOnce(&mut Constructed) -> Result>, + { self.process_next_value(Some(expected), |_, content| { op(content.as_constructed()?) }) @@ -813,14 +920,13 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// reached, or if the closure fails to process the next value’s content /// fully, a malformed error is returned. An error is also returned if /// the closure returns one or if accessing the underlying source fails. - pub fn take_primitive(&mut self, op: F) -> Result - where F: FnOnce(Tag, &mut Primitive) -> Result { - match self.take_opt_primitive(op)? { - Some(res) => Ok(res), - None => { - xerr!(Err(Error::Malformed.into())) - } - } + pub fn take_primitive( + &mut self, op: F, + ) -> Result> + where + F: FnOnce(Tag, &mut Primitive) -> Result>, + { + self.mandatory(|cons| cons.take_opt_primitive(op)) } /// Processes an optional primitive value. @@ -835,10 +941,11 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// malformed error is returned. An error is also returned if /// the closure returns one or if accessing the underlying source fails. pub fn take_opt_primitive( - &mut self, - op: F - ) -> Result, S::Err> - where F: FnOnce(Tag, &mut Primitive) -> Result { + &mut self, op: F, + ) -> Result, DecodeError> + where + F: FnOnce(Tag, &mut Primitive) -> Result>, + { self.process_next_value(None, |tag, content| { op(tag, content.as_primitive()?) }) @@ -855,17 +962,10 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// the closure doesn’t advance over the complete content. If access to /// the underlying source fails, an error is returned, too. pub fn take_primitive_if( - &mut self, - expected: Tag, - op: F - ) -> Result - where F: FnOnce(&mut Primitive) -> Result { - match self.take_opt_primitive_if(expected, op)? { - Some(res) => Ok(res), - None => { - xerr!(Err(Error::Malformed.into())) - } - } + &mut self, expected: Tag, op: F, + ) -> Result> + where F: FnOnce(&mut Primitive) -> Result> { + self.mandatory(|cons| cons.take_opt_primitive_if(expected, op)) } /// Processes an optional primitive value of a given tag. @@ -880,11 +980,9 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// fully the method returns a malformed error. If access to the /// underlying source fails, it returns an appropriate error. pub fn take_opt_primitive_if( - &mut self, - expected: Tag, - op: F - ) -> Result, S::Err> - where F: FnOnce(&mut Primitive) -> Result { + &mut self, expected: Tag, op: F, + ) -> Result, DecodeError> + where F: FnOnce(&mut Primitive) -> Result> { self.process_next_value(Some(expected), |_, content| { op(content.as_primitive()?) }) @@ -902,13 +1000,16 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// If the closure returns an error, this error is returned. /// /// [`Captured`]: ../captures/struct.Captured.html - pub fn capture(&mut self, op: F) -> Result + pub fn capture( + &mut self, op: F, + ) -> Result> where F: FnOnce( &mut Constructed>> - ) -> Result<(), S::Err> + ) -> Result<(), DecodeError> { let limit = self.source.limit(); + let start = self.source.pos(); let mut source = LimitedSource::new(CaptureSource::new(self.source)); source.set_limit(limit); { @@ -918,7 +1019,9 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { op(&mut constructed)?; self.state = constructed.state; } - Ok(Captured::new(source.unwrap().into_bytes(), self.mode)) + Ok(Captured::new( + source.unwrap().into_bytes(), self.mode, start, + )) } /// Captures one value for later processing @@ -930,28 +1033,25 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// to the underlying source fails, an appropriate error is returned. /// /// [`Captured`]: ../captures/struct.Captured.html - pub fn capture_one(&mut self) -> Result { - self.capture(|cons| { - match cons.skip_one()? { - Some(()) => Ok(()), - None => { - xerr!(Err(Error::Malformed.into())) - } - } - }) + pub fn capture_one(&mut self) -> Result> { + self.capture(|cons| cons.mandatory(|cons| cons.skip_one())) } /// Captures all remaining content for later processing. /// /// The method takes all remaining values from this value’s content and /// returns their encoded form in a `Bytes` value. - pub fn capture_all(&mut self) -> Result { + pub fn capture_all( + &mut self + ) -> Result> { self.capture(|cons| cons.skip_all()) } /// Skips over content. - pub fn skip_opt(&mut self, mut op: F) -> Result, S::Err> - where F: FnMut(Tag, bool, usize) -> Result<(), S::Err> { + pub fn skip_opt( + &mut self, mut op: F, + ) -> Result, DecodeError> + where F: FnMut(Tag, bool, usize) -> Result<(), ContentError> { // If we already know we are at the end of the value, we can return. if self.is_exhausted() { return Ok(None) @@ -970,7 +1070,7 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { if !constructed { if tag == Tag::END_OF_VALUE { if length != Length::Definite(0) { - xerr!(return Err(Error::Malformed.into())) + return Err(self.content_err("non-empty end of value")) } // End-of-value: The top of the stack needs to be an @@ -989,36 +1089,51 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { return Ok(None) } else { - xerr!(return Err(Error::Malformed.into())) + return Err(self.content_err( + "invalid nested values" + )) } } - _ => xerr!(return Err(Error::Malformed.into())) + _ => { + return Err(self.content_err( + "invalid nested values" + )) + } } } else { // Primitive value. Check for the length to be definite, - // check that the caller likes it, then try to read over it. + // check that the caller likes it, then try to read over + // it. if let Length::Definite(len) = length { - op(tag, constructed, stack.len())?; - self.source.advance(len)?; + if let Err(err) = op(tag, constructed, stack.len()) { + return Err(self.content_err(err)); + } + self.source.advance(len); } else { - xerr!(return Err(Error::Malformed.into())); + return Err(self.content_err( + "primitive value with indefinite length" + )) } } } else if let Length::Definite(len) = length { - // Definite constructed value. First check if the caller likes - // it. Check that there is enough limit left for the value. If - // so, push the limit at the end of the value to the stack, - // update the limit to our length, and continue. - op(tag, constructed, stack.len())?; + // Definite constructed value. First check if the caller + // likes it. Check that there is enough limit left for the + // value. If so, push the limit at the end of the value to + // the stack, update the limit to our length, and continue. + if let Err(err) = op(tag, constructed, stack.len()) { + return Err(self.content_err(err)); + } stack.push(Some(match self.source.limit() { Some(limit) => { match limit.checked_sub(len) { Some(len) => Some(len), None => { - xerr!(return Err(Error::Malformed.into())); + return Err(self.content_err( + "invalid nested values" + )); } } } @@ -1029,7 +1144,9 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { else { // Indefinite constructed value. Simply push a `None` to the // stack, if the caller likes it. - op(tag, constructed, stack.len())?; + if let Err(err) = op(tag, constructed, stack.len()) { + return Err(self.content_err(err)); + } stack.push(None); continue; } @@ -1051,7 +1168,9 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { Some(None) => { // We need a End-of-value, so running out of // data is an error. - xerr!(return Err(Error::Malformed.into())); + return Err(self.content_err(" + missing futher values" + )) } None => unreachable!(), } @@ -1064,18 +1183,13 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { } } - pub fn skip(&mut self, op: F) -> Result<(), S::Err> - where F: FnMut(Tag, bool, usize) -> Result<(), S::Err> { - if self.skip_opt(op)? == None { - xerr!(Err(Error::Malformed.into())) - } - else { - Ok(()) - } + pub fn skip(&mut self, op: F) -> Result<(), DecodeError> + where F: FnMut(Tag, bool, usize) -> Result<(), ContentError> { + self.mandatory(|cons| cons.skip_opt(op)) } /// Skips over all remaining content. - pub fn skip_all(&mut self) -> Result<(), S::Err> { + pub fn skip_all(&mut self) -> Result<(), DecodeError> { while let Some(()) = self.skip_one()? { } Ok(()) } @@ -1084,7 +1198,7 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// /// If there is a next value, returns `Ok(Some(()))`, if the end of value /// has already been reached, returns `Ok(None)`. - pub fn skip_one(&mut self) -> Result, S::Err> { + pub fn skip_one(&mut self) -> Result, DecodeError> { if self.is_exhausted() { Ok(None) } @@ -1103,22 +1217,24 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// encoding. impl<'a, S: Source + 'a> Constructed<'a, S> { /// Processes and returns a mandatory boolean value. - pub fn take_bool(&mut self) -> Result { + pub fn take_bool(&mut self) -> Result> { self.take_primitive_if(Tag::BOOLEAN, |prim| prim.to_bool()) } /// Processes and returns an optional boolean value. - pub fn take_opt_bool(&mut self) -> Result, S::Err> { + pub fn take_opt_bool( + &mut self, + ) -> Result, DecodeError> { self.take_opt_primitive_if(Tag::BOOLEAN, |prim| prim.to_bool()) } /// Processes a mandatory NULL value. - pub fn take_null(&mut self) -> Result<(), S::Err> { + pub fn take_null(&mut self) -> Result<(), DecodeError> { self.take_primitive_if(Tag::NULL, |_| Ok(())).map(|_| ()) } /// Processes an optional NULL value. - pub fn take_opt_null(&mut self) -> Result<(), S::Err> { + pub fn take_opt_null(&mut self) -> Result<(), DecodeError> { self.take_opt_primitive_if(Tag::NULL, |_| Ok(())).map(|_| ()) } @@ -1126,7 +1242,7 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// /// If the integer value is less than 0 or greater than 255, a malformed /// error is returned. - pub fn take_u8(&mut self) -> Result { + pub fn take_u8(&mut self) -> Result> { self.take_primitive_if(Tag::INTEGER, |prim| prim.to_u8()) } @@ -1134,7 +1250,9 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// /// If the integer value is less than 0 or greater than 255, a malformed /// error is returned. - pub fn take_opt_u8(&mut self) -> Result, S::Err> { + pub fn take_opt_u8( + &mut self, + ) -> Result, DecodeError> { self.take_opt_primitive_if(Tag::INTEGER, |prim| prim.to_u8()) } @@ -1142,11 +1260,13 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// /// If the next value is an integer but of a different value, returns /// a malformed error. - pub fn skip_u8_if(&mut self, expected: u8) -> Result<(), S::Err> { + pub fn skip_u8_if( + &mut self, expected: u8, + ) -> Result<(), DecodeError> { self.take_primitive_if(Tag::INTEGER, |prim| { let got = prim.take_u8()?; if got != expected { - xerr!(Err(Error::Malformed.into())) + Err(prim.content_err(ExpectedIntValue(expected))) } else { Ok(()) @@ -1158,11 +1278,13 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// /// If the next value is an integer but of a different value, returns /// a malformed error. - pub fn skip_opt_u8_if(&mut self, expected: u8) -> Result<(), S::Err> { + pub fn skip_opt_u8_if( + &mut self, expected: u8, + ) -> Result<(), DecodeError> { self.take_opt_primitive_if(Tag::INTEGER, |prim| { let got = prim.take_u8()?; if got != expected { - xerr!(Err(Error::Malformed.into())) + Err(prim.content_err(ExpectedIntValue(expected))) } else { Ok(()) @@ -1172,17 +1294,19 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// Processes a mandatory INTEGER value of the `u16` range. /// - /// If the integer value is less than 0 or greater than 65535, a malformed - /// error is returned. - pub fn take_u16(&mut self) -> Result { + /// If the integer value is less than 0 or greater than 65535, a + /// malformed error is returned. + pub fn take_u16(&mut self) -> Result> { self.take_primitive_if(Tag::INTEGER, |prim| prim.to_u16()) } /// Processes an optional INTEGER value of the `u16` range. /// - /// If the integer value is less than 0 or greater than 65535, a malformed - /// error is returned. - pub fn take_opt_u16(&mut self) -> Result, S::Err> { + /// If the integer value is less than 0 or greater than 65535, a + /// malformed error is returned. + pub fn take_opt_u16( + &mut self, + ) -> Result, DecodeError> { self.take_opt_primitive_if(Tag::INTEGER, |prim| prim.to_u16()) } @@ -1190,7 +1314,7 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// /// If the integer value is less than 0 or greater than 2^32-1, a /// malformed error is returned. - pub fn take_u32(&mut self) -> Result { + pub fn take_u32(&mut self) -> Result> { self.take_primitive_if(Tag::INTEGER, |prim| prim.to_u32()) } @@ -1198,7 +1322,9 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// /// If the integer value is less than 0 or greater than 2^32-1, a /// malformed error is returned. - pub fn take_opt_u32(&mut self) -> Result, S::Err> { + pub fn take_opt_u32( + &mut self, + ) -> Result, DecodeError> { self.take_opt_primitive_if(Tag::INTEGER, |prim| prim.to_u32()) } @@ -1206,7 +1332,7 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// /// If the integer value is less than 0 or greater than 2^64-1, a /// malformed error is returned. - pub fn take_u64(&mut self) -> Result { + pub fn take_u64(&mut self) -> Result> { self.take_primitive_if(Tag::INTEGER, |prim| prim.to_u64()) } @@ -1214,42 +1340,50 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// /// If the integer value is less than 0 or greater than 2^64-1, a /// malformed error is returned. - pub fn take_opt_u64(&mut self) -> Result, S::Err> { + pub fn take_opt_u64( + &mut self, + ) -> Result, DecodeError> { self.take_opt_primitive_if(Tag::INTEGER, |prim| prim.to_u64()) } /// Processes a mandatory SEQUENCE value. /// /// This is a shortcut for `self.take_constructed(Tag::SEQUENCE, op)`. - pub fn take_sequence(&mut self, op: F) -> Result - where F: FnOnce(&mut Constructed) -> Result { + pub fn take_sequence( + &mut self, op: F, + ) -> Result> + where F: FnOnce(&mut Constructed) -> Result> { self.take_constructed_if(Tag::SEQUENCE, op) } /// Processes an optional SEQUENCE value. /// - /// This is a shortcut for `self.take_opt_constructed(Tag::SEQUENCE, op)`. + /// This is a shortcut for + /// `self.take_opt_constructed(Tag::SEQUENCE, op)`. pub fn take_opt_sequence( - &mut self, - op: F - ) -> Result, S::Err> - where F: FnOnce(&mut Constructed) -> Result { + &mut self, op: F, + ) -> Result, DecodeError> + where F: FnOnce(&mut Constructed) -> Result> { self.take_opt_constructed_if(Tag::SEQUENCE, op) } /// Processes a mandatory SET value. /// /// This is a shortcut for `self.take_constructed(Tag::SET, op)`. - pub fn take_set(&mut self, op: F) -> Result - where F: FnOnce(&mut Constructed) -> Result { + pub fn take_set( + &mut self, op: F, + ) -> Result> + where F: FnOnce(&mut Constructed) -> Result> { self.take_constructed_if(Tag::SET, op) } /// Processes an optional SET value. /// /// This is a shortcut for `self.take_opt_constructed(Tag::SET, op)`. - pub fn take_opt_set(&mut self, op: F) -> Result, S::Err> - where F: FnOnce(&mut Constructed) -> Result { + pub fn take_opt_set( + &mut self, op: F + ) -> Result, DecodeError> + where F: FnOnce(&mut Constructed) -> Result> { self.take_opt_constructed_if(Tag::SET, op) } } @@ -1272,6 +1406,43 @@ enum State { /// Unbounded value: read as far as we get. Unbounded, } + + +//============ Error Types =================================================== + +/// A value with a certain tag was expected. +#[derive(Clone, Copy, Debug)] +struct ExpectedTag(Tag); + +impl fmt::Display for ExpectedTag { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "expected tag {}", self.0) + } +} + +impl From for ContentError { + fn from(err: ExpectedTag) -> Self { + ContentError::from_boxed(Box::new(err)) + } +} + + +/// An integer with a certain value was expected. +#[derive(Clone, Copy, Debug)] +struct ExpectedIntValue(T); + +impl fmt::Display for ExpectedIntValue { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "expected integer value {}", self.0) + } +} + +impl From> for ContentError +where T: fmt::Display + Send + Sync + 'static { + fn from(err: ExpectedIntValue) -> Self { + ContentError::from_boxed(Box::new(err)) + } +} //============ Tests ========================================================= @@ -1284,7 +1455,7 @@ mod test { fn constructed_skip() { // Two primitives. Constructed::decode( - b"\x02\x01\x00\x02\x01\x00".as_ref(), Mode::Ber, |cons| { + b"\x02\x01\x00\x02\x01\x00".into_source(), Mode::Ber, |cons| { cons.skip(|_, _, _| Ok(())).unwrap(); cons.skip(|_, _, _| Ok(())).unwrap(); Ok(()) @@ -1293,7 +1464,7 @@ mod test { // One definite constructed with two primitives, then one primitive Constructed::decode( - b"\x30\x06\x02\x01\x00\x02\x01\x00\x02\x01\x00".as_ref(), + b"\x30\x06\x02\x01\x00\x02\x01\x00\x02\x01\x00".into_source(), Mode::Ber, |cons| { cons.skip(|_, _, _| Ok(())).unwrap(); @@ -1308,7 +1479,7 @@ mod test { b"\x30\x08\ \x30\x06\ \x02\x01\x00\x02\x01\x00\ - \x02\x01\x00".as_ref(), + \x02\x01\x00".into_source(), Mode::Ber, |cons| { cons.skip(|_, _, _| Ok(())).unwrap(); @@ -1322,7 +1493,7 @@ mod test { b"\x30\x0A\ \x30\x80\ \x02\x01\x00\x02\x01\x00\ - \0\0".as_ref(), + \0\0".into_source(), Mode::Ber, |cons| { cons.skip(|_, _, _| Ok(())).unwrap(); @@ -1332,3 +1503,4 @@ mod test { } } + diff --git a/src/decode/error.rs b/src/decode/error.rs index 6a470ff..4cc1384 100644 --- a/src/decode/error.rs +++ b/src/decode/error.rs @@ -3,26 +3,161 @@ //! This is a private module. Its public content is being re-exported by the //! parent module. -use std::fmt; +use std::{error, fmt}; +use std::convert::Infallible; +use super::source::Pos; -//------------ Error --------------------------------------------------------- +//------------ ContentError -------------------------------------------------- + +/// An error happened while interpreting BER-encoded data. +pub struct ContentError { + /// The error message. + message: ErrorMessage, +} + +impl ContentError { + /// Creates a content error from a static str. + pub fn from_static(msg: &'static str) -> Self { + ContentError { + message: ErrorMessage::Static(msg) + } + } + + /// Creates a content error from a boxed trait object. + pub fn from_boxed( + msg: Box + ) -> Self { + ContentError { + message: ErrorMessage::Boxed(msg) + } + } +} + +impl From<&'static str> for ContentError { + fn from(msg: &'static str) -> Self { + Self::from_static(msg) + } +} + +impl From for ContentError { + fn from(msg: String) -> Self { + Self::from_boxed(Box::new(msg)) + } +} + +impl From> for ContentError { + fn from(err: DecodeError) -> Self { + match err.inner { + DecodeErrorKind::Source(_) => unreachable!(), + DecodeErrorKind::Content { error, .. } => error, + } + } +} + + +impl fmt::Display for ContentError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.message.fmt(f) + } +} + +impl fmt::Debug for ContentError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_tuple("ContentError") + .field(&self.message) + .finish() + } +} + + +//------------ ErrorMessage -------------------------------------------------- + +/// The actual error message as a hidden enum. +enum ErrorMessage { + /// The error message is a static str. + Static(&'static str), + + /// The error message is a boxed trait object. + Boxed(Box), +} + +impl fmt::Display for ErrorMessage { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ErrorMessage::Static(msg) => f.write_str(msg), + ErrorMessage::Boxed(ref msg) => msg.fmt(f), + } + } +} + +impl fmt::Debug for ErrorMessage { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ErrorMessage::Static(msg) => f.write_str(msg), + ErrorMessage::Boxed(ref msg) => msg.fmt(f), + } + } +} + + +//------------ DecodeError --------------------------------------------------- /// An error happened while decoding data. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub enum Error { - /// The data didn’t conform to the expected structure. - Malformed, +/// +/// This can either be a source error – which the type is generic over – or a +/// content error annotated with the position in the source where it happened. +#[derive(Debug)] +pub struct DecodeError { + inner: DecodeErrorKind, +} + +#[derive(Debug)] +enum DecodeErrorKind { + Source(S), + Content { + error: ContentError, + pos: Pos, + } +} - /// An encoding used by the data is not yet implemented by the crate. - Unimplemented, +impl DecodeError { + /// Creates a decode error from a content error and a position. + pub fn content(error: impl Into, pos: Pos) -> Self { + DecodeError { + inner: DecodeErrorKind::Content { error: error.into(), pos }, + } + } } -impl fmt::Display for Error { +impl DecodeError { + /// Converts a decode error from an infallible source into another error. + pub fn convert(self) -> DecodeError { + match self.inner { + DecodeErrorKind::Source(_) => unreachable!(), + DecodeErrorKind::Content { error, pos } => { + DecodeError::content(error, pos) + } + } + } +} + +impl From for DecodeError { + fn from(err: S) -> Self { + DecodeError { inner: DecodeErrorKind::Source(err) } + } +} + +impl fmt::Display for DecodeError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Error::Malformed => write!(f, "malformed data"), - Error::Unimplemented => write!(f, "format not implemented"), + match self.inner { + DecodeErrorKind::Source(ref err) => err.fmt(f), + DecodeErrorKind::Content { ref error, pos } => { + write!(f, "{} (at position {})", error, pos) + } } } } + +impl error::Error for DecodeError { } + diff --git a/src/decode/mod.rs b/src/decode/mod.rs index 671fab8..e6d87c8 100644 --- a/src/decode/mod.rs +++ b/src/decode/mod.rs @@ -4,9 +4,9 @@ //! //! The basic idea is that for each type a function exists that knows how //! to decode one value of that type. For constructed types, this function -//! in turn relies on similar functions provided for its consituent types. +//! in turn relies on similar functions provided for its constituent types. //! For a detailed introduction to how to write these functions, please -//! refer to the [decode section of the guide]. +//! refer to the [decode section of the guide][crate::guide::decode]. //! //! The two most important types of this module are [`Primitive`] and //! [`Constructed`], representing the content octets of a value in primitive @@ -20,19 +20,30 @@ //! The enum [`Content`] is used for cases where a value can be either //! primitive or constructed such as most string types. //! -//! Decoding is jumpstarted by providing a data source to parse data from. -//! This is any value that implements the [`Source`] trait. +//! The data for decoding is provided by any type that implements the +//! [`Source`] trait – or can be converted into such a type via the +//! [`IntoSource`] trait. Implementations for both `bytes::Bytes` and +//! `&[u8]` are available. //! -//! [decode section of the guide]: ../guide/decode/index.html -//! [`Primitive`]: struct.Primitive.html -//! [`Constructed`]: struct.Constructed.html -//! [`Content`]: enum.Content.html -//! [`Source`]: trait.Source.html +//! During decoding, errors can happen. There are two kinds of errors: for +//! one, the source can fail to gather more data, e.g., when reading from a +//! file fails. Such errors are called _source errors._ Their type is +//! provided by the source. +//! +//! Second, data that cannot be decoded according to the syntax is said to +//! result in a _content error._ The [`ContentError`] type is used for such +//! errors. +//! +//! When decoding data from a source, both errors can happen. The type +//! `DecodeError` provides a way to store either of them and is the error +//! type you will likely encounter the most. pub use self::content::{Content, Constructed, Primitive}; -pub use self::error::Error; -pub use self::error::Error::{Malformed, Unimplemented}; -pub use self::source::{CaptureSource, LimitedSource, Source}; +pub use self::error::{ContentError, DecodeError}; +pub use self::source::{ + BytesSource, CaptureSource, IntoSource, Pos, LimitedSource, SliceSource, + Source +}; mod content; mod error; diff --git a/src/decode/source.rs b/src/decode/source.rs index 4caefca..cb4c0c1 100644 --- a/src/decode/source.rs +++ b/src/decode/source.rs @@ -1,12 +1,13 @@ //! The source for decoding data. //! -//! This is an internal module. It’s public types are re-exported by the +//! This is an internal module. Its public types are re-exported by the //! parent. -use std::mem; +use std::{error, fmt, mem, ops}; use std::cmp::min; +use std::convert::Infallible; use bytes::Bytes; -use super::error::Error; +use super::error::{ContentError, DecodeError}; //------------ Source -------------------------------------------------------- @@ -17,46 +18,33 @@ use super::error::Error; /// decoders. /// /// A source can only progress forward over time. It provides the ability to -/// access the next few bytes as a slice, advance forward, or advance forward -/// returning a Bytes value of the data it advanced over. +/// access the next few bytes as a slice or a [`Bytes`] value, and advance +/// forward. /// /// _Please note:_ This trait may change as we gain more experience with /// decoding in different circumstances. If you implement it for your own /// types, we would appreciate feedback what worked well and what didn’t. pub trait Source { - /// The error produced by the source. - /// - /// The type used here needs to wrap [`ber::decode::Error`] and extends it - /// by whatever happens if acquiring additional data fails. If `Source` - /// is implemented for types where this acqusition cannot fail, - /// `ber::decode::Error` should be used here. - /// - /// [`ber::decode::Error`]: enum.Error.html - type Err: From; + /// The error produced when the source failed to read more data. + type Error: error::Error; + + /// Returns the current logical postion within the sequence of data. + fn pos(&self) -> Pos; /// Request at least `len` bytes to be available. /// /// The method returns the number of bytes that are actually available. /// This may only be smaller than `len` if the source ends with less - /// bytes available. + /// bytes available. It may be larger than `len` but less than the total + /// number of bytes still left in the source. + /// + /// The method can be called multiple times without advancing in between. + /// If in this case `len` is larger than when last called, the source + /// should try and make the additional data available. /// /// The method should only return an error if the source somehow fails /// to get more data such as an IO error or reset connection. - fn request(&mut self, len: usize) -> Result; - - /// Advance the source by `len` bytes. - /// - /// The method advances the start of the view provided by the source by - /// `len` bytes. Advancing beyond the end of a source is an error. - /// Implementations should return their equivalient of - /// [`Error::Malformed`]. - /// - /// The value of `len` may be larger than the last length previously - /// request via [`request`]. - /// - /// [`Error::Malformed`]: enum.Error.html#variant.Malformed - /// [`request`]: #tymethod.request - fn advance(&mut self, len: usize) -> Result<(), Self::Err>; + fn request(&mut self, len: usize) -> Result; /// Returns a bytes slice with the available data. /// @@ -72,30 +60,58 @@ pub trait Source { /// The method returns a [`Bytes`] value of the range `start..end` from /// the beginning of the current view of the source. Both indexes must /// not be greater than the value returned by the last successful call - /// to [`request`]. + /// to [`request`][Self::request]. /// /// # Panics /// - /// The method panics if `start` or `end` are larger than the last - /// successful call to [`request`]. - /// - /// [`Bytes`]: ../../bytes/struct.Bytes.html - /// [`request`]: #tymethod.request + /// The method panics if `start` or `end` are larger than the result of + /// the last successful call to [`request`][Self::request]. fn bytes(&self, start: usize, end: usize) -> Bytes; + /// Advance the source by `len` bytes. + /// + /// The method advances the start of the view provided by the source by + /// `len` bytes. This value must not be greater than the value returned + /// by the last successful call to [`request`][Self::request]. + /// + /// # Panics + /// + /// The method panics if `len` is larger than the result of the last + /// successful call to [`request`][Self::request]. + fn advance(&mut self, len: usize); + + /// Skip over the next `len` bytes. + /// + /// The method attempts to advance the source by `len` bytes or by + /// however many bytes are still available if this number is smaller, + /// without making these bytes available. + /// + /// Returns the number of bytes skipped over. This value may only differ + /// from len if the remainder of the source contains less than `len` + /// bytes. + /// + /// The default implementation uses `request` and `advance`. However, for + /// some sources it may be significantly cheaper to provide a specialised + /// implementation. + fn skip(&mut self, len: usize) -> Result { + let res = min(self.request(len)?, len); + self.advance(res); + Ok(res) + } + //--- Advanced access /// Takes a single octet from the source. /// /// If there aren’t any more octets available from the source, returns - /// a malformed error. - fn take_u8(&mut self) -> Result { + /// a content error. + fn take_u8(&mut self) -> Result> { if self.request(1)? < 1 { - xerr!(return Err(Error::Malformed.into())) + return Err(self.content_err("unexpected end of data")) } let res = self.slice()[0]; - self.advance(1)?; + self.advance(1); Ok(res) } @@ -103,85 +119,258 @@ pub trait Source { /// /// If there aren’t any more octets available from the source, returns /// `Ok(None)`. - fn take_opt_u8(&mut self) -> Result, Self::Err> { + fn take_opt_u8(&mut self) -> Result, Self::Error> { if self.request(1)? < 1 { return Ok(None) } let res = self.slice()[0]; - self.advance(1)?; + self.advance(1); Ok(Some(res)) } + + /// Returns a content error at the current position of the source. + fn content_err( + &self, err: impl Into + ) -> DecodeError { + DecodeError::content(err.into(), self.pos()) + } } -impl Source for Bytes { - type Err = Error; +impl<'a, T: Source> Source for &'a mut T { + type Error = T::Error; - fn request(&mut self, _len: usize) -> Result { - Ok(self.len()) + fn request(&mut self, len: usize) -> Result { + Source::request(*self, len) } - - fn advance(&mut self, len: usize) -> Result<(), Self::Err> { - if len > self.len() { - Err(Error::Malformed) - } - else { - bytes::Buf::advance(self, len); - Ok(()) - } + + fn advance(&mut self, len: usize) { + Source::advance(*self, len) } fn slice(&self) -> &[u8] { - self.as_ref() + Source::slice(*self) } fn bytes(&self, start: usize, end: usize) -> Bytes { - self.slice(start..end) + Source::bytes(*self, start, end) + } + + fn pos(&self) -> Pos { + Source::pos(*self) } } -impl<'a> Source for &'a [u8] { - type Err = Error; - fn request(&mut self, _len: usize) -> Result { - Ok(self.len()) +//------------ IntoSource ---------------------------------------------------- + +/// A type that can be converted into a source. +pub trait IntoSource { + type Source: Source; + + fn into_source(self) -> Self::Source; +} + +impl IntoSource for T { + type Source = Self; + + fn into_source(self) -> Self::Source { + self } +} - fn advance(&mut self, len: usize) -> Result<(), Self::Err> { - if len > self.len() { - Err(Error::Malformed) - } - else { - *self = &self[len..]; - Ok(()) - } + +//------------ Pos ----------------------------------------------------------- + +/// The logical position within a source. +/// +/// Values of this type can only be used for diagnostics. They can not be used +/// to determine how far a source has been advanced since it was created. This +/// is why we used a newtype. +#[derive(Clone, Copy, Debug, Default)] +pub struct Pos(usize); + +impl From for Pos { + fn from(pos: usize) -> Pos { + Pos(pos) + } +} + +impl ops::Add for Pos { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + Pos(self.0 + rhs.0) + } +} + +impl fmt::Display for Pos { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + + +//------------ BytesSource --------------------------------------------------- + +/// A source for a bytes value. +#[derive(Clone, Debug)] +pub struct BytesSource { + /// The bytes. + data: Bytes, + + /// The current read position in the data. + pos: usize, + + /// The offset for the reported position. + /// + /// This is the value reported by `Source::pos` when `self.pos` is zero. + offset: Pos, +} + +impl BytesSource { + /// Creates a new bytes source from a bytes values. + pub fn new(data: Bytes) -> Self { + BytesSource { data, pos: 0, offset: 0.into() } + } + + /// Creates a new bytes source with an explicit offset. + /// + /// When this function is used to create a bytes source, `Source::pos` + /// will report a value increates by `offset`. + pub fn with_offset(data: Bytes, offset: Pos) -> Self { + BytesSource { data, pos: 0, offset } + } + + /// Returns the remaining length of data. + pub fn len(&self) -> usize { + self.data.len() + } + + /// Returns whether there is any data remaining. + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// Splits the first `len` bytes off the source and returns them. + /// + /// # Panics + /// + /// This method panics of `len` is larger than `self.len()`. + pub fn split_to(&mut self, len: usize) -> Bytes { + let res = self.data.split_to(len); + self.pos += len; + res + } + + /// Converts the source into the remaining bytes. + pub fn into_bytes(self) -> Bytes { + self.data + } +} + +impl Source for BytesSource { + type Error = Infallible; + + fn pos(&self) -> Pos { + self.offset + self.pos.into() + } + + fn request(&mut self, _len: usize) -> Result { + Ok(self.data.len()) } fn slice(&self) -> &[u8] { - self + self.data.as_ref() } fn bytes(&self, start: usize, end: usize) -> Bytes { - Bytes::copy_from_slice(&self[start..end]) + self.data.slice(start..end) + } + + fn advance(&mut self, len: usize) { + assert!(len <= self.data.len()); + bytes::Buf::advance(&mut self.data, len); + self.pos += len; } } -impl<'a, T: Source> Source for &'a mut T { - type Err = T::Err; +impl IntoSource for Bytes { + type Source = BytesSource; - fn request(&mut self, len: usize) -> Result { - Source::request(*self, len) + fn into_source(self) -> Self::Source { + BytesSource::new(self) } - - fn advance(&mut self, len: usize) -> Result<(), Self::Err> { - Source::advance(*self, len) +} + + +//------------ SliceSource --------------------------------------------------- + +#[derive(Clone, Copy, Debug)] +pub struct SliceSource<'a> { + data: &'a [u8], + pos: usize +} + +impl<'a> SliceSource<'a> { + /// Creates a new bytes source from a slice. + pub fn new(data: &'a [u8]) -> Self { + SliceSource { data, pos: 0 } + } + + /// Returns the remaining length of data. + pub fn len(&self) -> usize { + self.data.len() + } + + /// Returns whether there is any data remaining. + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// Splits the first `len` bytes off the source and returns them. + /// + /// # Panics + /// + /// This method panics of `len` is larger than `self.len()`. + pub fn split_to(&mut self, len: usize) -> &'a [u8] { + let (left, right) = self.data.split_at(len); + self.data = right; + self.pos += len; + left + } +} + +impl<'a> Source for SliceSource<'a> { + type Error = Infallible; + + fn pos(&self) -> Pos { + self.pos.into() + } + + fn request(&mut self, _len: usize) -> Result { + Ok(self.data.len()) + } + + fn advance(&mut self, len: usize) { + assert!(len <= self.data.len()); + self.data = &self.data[len..]; + self.pos += len; } fn slice(&self) -> &[u8] { - Source::slice(*self) + self.data } fn bytes(&self, start: usize, end: usize) -> Bytes { - Source::bytes(*self, start, end) + Bytes::copy_from_slice(&self.data[start..end]) + } +} + +impl<'a> IntoSource for &'a [u8] { + type Source = SliceSource<'a>; + + fn into_source(self) -> Self::Source { + SliceSource::new(self) } } @@ -273,9 +462,13 @@ impl LimitedSource { /// If there currently is no limit, the method will panic. Otherwise it /// will simply advance to the end of the limit which may be something /// the underlying source doesn’t like and thus produce an error. - pub fn skip_all(&mut self) -> Result<(), S::Err> { + pub fn skip_all(&mut self) -> Result<(), DecodeError> { let limit = self.limit.unwrap(); - self.advance(limit) + if self.request(limit)? < limit { + return Err(self.content_err("unexpected end of data")) + } + self.advance(limit); + Ok(()) } /// Returns a bytes value containing all octets until the current limit. @@ -283,14 +476,17 @@ impl LimitedSource { /// If there currently is no limit, the method will panic. Otherwise it /// tries to acquire a bytes value for the octets from the current /// position to the end of the limit and advance to the end of the limit. - /// This may result in an error by the underlying source. - pub fn take_all(&mut self) -> Result { + /// + /// This will result in a source error if the underlying source returns + /// an error. It will result in a content error if the underlying source + /// ends before the limit is reached. + pub fn take_all(&mut self) -> Result> { let limit = self.limit.unwrap(); if self.request(limit)? < limit { - return Err(Error::Malformed.into()) + return Err(self.content_err("unexpected end of data")) } let res = self.bytes(0, limit); - self.advance(limit)?; + self.advance(limit); Ok(res) } @@ -303,18 +499,19 @@ impl LimitedSource { /// If there is no limit set, the method will try to access one single /// octet and return a malformed error if that is actually possible, i.e., /// if there are octets left in the underlying source. - pub fn exhausted(&mut self) -> Result<(), S::Err> { + /// + /// Any source errors are passed through. If there the data is not + /// exhausted as described above, a content error is created. + pub fn exhausted(&mut self) -> Result<(), DecodeError> { match self.limit { Some(0) => Ok(()), - Some(_limit) => { - xerr!(Err(Error::Malformed.into())) - } + Some(_limit) => Err(self.content_err("trailing data")), None => { if self.source.request(1)? == 0 { Ok(()) } else { - xerr!(Err(Error::Malformed.into())) + Err(self.content_err("trailing data")) } } } @@ -322,9 +519,13 @@ impl LimitedSource { } impl Source for LimitedSource { - type Err = S::Err; + type Error = S::Error; - fn request(&mut self, len: usize) -> Result { + fn pos(&self) -> Pos { + self.source.pos() + } + + fn request(&mut self, len: usize) -> Result { if let Some(limit) = self.limit { Ok(min(limit, self.source.request(min(limit, len))?)) } @@ -333,11 +534,12 @@ impl Source for LimitedSource { } } - fn advance(&mut self, len: usize) -> Result<(), Self::Err> { + fn advance(&mut self, len: usize) { if let Some(limit) = self.limit { - if len > limit { - xerr!(return Err(Error::Malformed.into())) - } + assert!( + len <= limit, + "advanced past end of limit" + ); self.limit = Some(limit - len); } self.source.advance(len) @@ -376,7 +578,13 @@ impl Source for LimitedSource { /// /// [`Constructed::capture`]: struct.Constructed.html#method.capture pub struct CaptureSource<'a, S: 'a> { + /// The wrapped real source. source: &'a mut S, + + /// The number of bytes the source has promised to have for us. + len: usize, + + /// The position in the source our view starts at. pos: usize, } @@ -385,7 +593,8 @@ impl<'a, S: Source> CaptureSource<'a, S> { pub fn new(source: &'a mut S) -> Self { CaptureSource { source, - pos: 0 + len: 0, + pos: 0, } } @@ -400,26 +609,20 @@ impl<'a, S: Source> CaptureSource<'a, S> { /// /// Advances the underlying source to the end of the captured bytes. pub fn skip(self) { - assert!( - !self.source.advance(self.pos).is_err(), - "failed to advance capture source" - ); + self.source.advance(self.pos) } } impl<'a, S: Source + 'a> Source for CaptureSource<'a, S> { - type Err = S::Err; + type Error = S::Error; - fn request(&mut self, len: usize) -> Result { - self.source.request(self.pos + len).map(|res| res - self.pos) + fn pos(&self) -> Pos { + self.source.pos() + self.pos.into() } - fn advance(&mut self, len: usize) -> Result<(), Self::Err> { - if self.request(len)? < len { - return Err(Error::Malformed.into()) - } - self.pos += len; - Ok(()) + fn request(&mut self, len: usize) -> Result { + self.len = self.source.request(self.pos + len)?; + Ok(self.len - self.pos) } fn slice(&self) -> &[u8] { @@ -427,7 +630,25 @@ impl<'a, S: Source + 'a> Source for CaptureSource<'a, S> { } fn bytes(&self, start: usize, end: usize) -> Bytes { - self.source.bytes(start + self.pos, end + self.pos) + let start = start + self.pos; + let end = end + self.pos; + assert!( + self.len >= start, + "start past the end of data" + ); + assert!( + self.len >= end, + "end past the end of data" + ); + self.source.bytes(start, end) + } + + fn advance(&mut self, len: usize) { + assert!( + self.len >= self.pos + len, + "advanced past the end of data" + ); + self.pos += len; } } @@ -440,77 +661,70 @@ mod test { #[test] fn take_u8() { - let mut source = &b"123"[..]; - assert_eq!(source.take_u8(), Ok(b'1')); - assert_eq!(source.take_u8(), Ok(b'2')); - assert_eq!(source.take_u8(), Ok(b'3')); - assert_eq!(source.take_u8(), Err(Error::Malformed)); + let mut source = b"123".into_source(); + assert_eq!(source.take_u8().unwrap(), b'1'); + assert_eq!(source.take_u8().unwrap(), b'2'); + assert_eq!(source.take_u8().unwrap(), b'3'); + assert!(source.take_u8().is_err()) } #[test] fn take_opt_u8() { - let mut source = &b"123"[..]; - assert_eq!(source.take_opt_u8(), Ok(Some(b'1'))); - assert_eq!(source.take_opt_u8(), Ok(Some(b'2'))); - assert_eq!(source.take_opt_u8(), Ok(Some(b'3'))); - assert_eq!(source.take_opt_u8(), Ok(None)); + let mut source = b"123".into_source(); + assert_eq!(source.take_opt_u8().unwrap(), Some(b'1')); + assert_eq!(source.take_opt_u8().unwrap(), Some(b'2')); + assert_eq!(source.take_opt_u8().unwrap(), Some(b'3')); + assert_eq!(source.take_opt_u8().unwrap(), None); } #[test] fn bytes_impl() { - let mut bytes = Bytes::from_static(b"1234567890"); + let mut bytes = Bytes::from_static(b"1234567890").into_source(); assert!(bytes.request(4).unwrap() >= 4); assert!(&Source::slice(&bytes)[..4] == b"1234"); assert_eq!(bytes.bytes(2, 4), Bytes::from_static(b"34")); - Source::advance(&mut bytes, 4).unwrap(); + Source::advance(&mut bytes, 4); assert!(bytes.request(4).unwrap() >= 4); assert!(&Source::slice(&bytes)[..4] == b"5678"); - Source::advance(&mut bytes, 4).unwrap(); + Source::advance(&mut bytes, 4); assert_eq!(bytes.request(4).unwrap(), 2); - assert!(&Source::slice(&bytes)[..] == b"90"); - assert_eq!( - Source::advance(&mut bytes, 4).unwrap_err(), - Error::Malformed - ); + assert!(&Source::slice(&bytes) == b"90"); + bytes.advance(2); + assert_eq!(bytes.request(4).unwrap(), 0); } #[test] fn slice_impl() { - let mut bytes = &b"1234567890"[..]; + let mut bytes = b"1234567890".into_source(); assert!(bytes.request(4).unwrap() >= 4); - assert!(&Source::slice(&bytes)[..4] == b"1234"); + assert!(&bytes.slice()[..4] == b"1234"); assert_eq!(bytes.bytes(2, 4), Bytes::from_static(b"34")); - Source::advance(&mut bytes, 4).unwrap(); + bytes.advance(4); assert!(bytes.request(4).unwrap() >= 4); - assert!(&Source::slice(&bytes)[..4] == b"5678"); - Source::advance(&mut bytes, 4).unwrap(); + assert!(&bytes.slice()[..4] == b"5678"); + bytes.advance(4); assert_eq!(bytes.request(4).unwrap(), 2); - assert!(&Source::slice(&bytes)[..] == b"90"); - assert_eq!( - Source::advance(&mut bytes, 4).unwrap_err(), - Error::Malformed - ); + assert!(&bytes.slice() == b"90"); + bytes.advance(2); + assert_eq!(bytes.request(4).unwrap(), 0); } #[test] fn limited_source() { - let mut the_source = LimitedSource::new(&b"12345678"[..]); + let mut the_source = LimitedSource::new( + b"12345678".into_source() + ); the_source.set_limit(Some(4)); let mut source = the_source.clone(); - assert_eq!(source.exhausted(), Err(Error::Malformed)); + assert!(source.exhausted().is_err()); assert_eq!(source.request(6).unwrap(), 4); - source.advance(2).unwrap(); - assert_eq!(source.exhausted(), Err(Error::Malformed)); + source.advance(2); + assert!(source.exhausted().is_err()); assert_eq!(source.request(6).unwrap(), 2); - source.advance(2).unwrap(); + source.advance(2); source.exhausted().unwrap(); assert_eq!(source.request(6).unwrap(), 0); - source.advance(0).unwrap(); - assert_eq!(source.advance(2).unwrap_err(), Error::Malformed); - - let mut source = the_source.clone(); - assert_eq!(source.advance(5).unwrap_err(), Error::Malformed); let mut source = the_source.clone(); source.skip_all().unwrap(); @@ -524,30 +738,46 @@ mod test { assert_eq!(source.slice(), b"5678"); } + #[test] + #[should_panic] + fn limited_source_far_advance() { + let mut source = LimitedSource::new( + b"12345678".into_source() + ); + source.set_limit(Some(4)); + assert_eq!(source.request(6).unwrap(), 4); + source.advance(4); + assert_eq!(source.request(6).unwrap(), 0); + source.advance(6); // panics + } + #[test] #[should_panic] fn limit_further() { - let mut source = LimitedSource::new(&b"12345"); + let mut source = LimitedSource::new(b"12345".into_source()); source.set_limit(Some(4)); - source.limit_further(Some(5)); + source.limit_further(Some(5)); // panics } #[test] fn capture_source() { - let mut source = &b"1234567890"[..]; + let mut source = b"1234567890".into_source(); { let mut capture = CaptureSource::new(&mut source); - capture.advance(4).unwrap(); + assert_eq!(capture.request(4).unwrap(), 10); + capture.advance(4); assert_eq!(capture.into_bytes(), Bytes::from_static(b"1234")); } - assert_eq!(source, b"567890"); + assert_eq!(source.data, b"567890"); - let mut source = &b"1234567890"[..]; + let mut source = b"1234567890".into_source(); { let mut capture = CaptureSource::new(&mut source); - capture.advance(4).unwrap(); + assert_eq!(capture.request(4).unwrap(), 10); + capture.advance(4); capture.skip(); } - assert_eq!(source, b"567890"); + assert_eq!(source.data, b"567890"); } } + diff --git a/src/encode/values.rs b/src/encode/values.rs index 3218764..e541525 100644 --- a/src/encode/values.rs +++ b/src/encode/values.rs @@ -48,7 +48,7 @@ pub trait Values { fn to_captured(&self, mode: Mode) -> Captured { let mut target = Vec::new(); self.write_encoded(mode, &mut target).unwrap(); - Captured::new(target.into(), mode) + Captured::new(target.into(), mode, Default::default()) } } diff --git a/src/guide/decode.rs b/src/guide/decode.rs index bd4ad5d..77134f7 100644 --- a/src/guide/decode.rs +++ b/src/guide/decode.rs @@ -19,7 +19,7 @@ //! take them as function arguments such as closures. //! //! An example will make this concept more clear. Let’s say we have the -//! following ASN.1 specifiction: +//! following ASN.1 specification: //! //! ```text //! EncapsulatedContentInfo ::= SEQUENCE { @@ -49,16 +49,17 @@ //! # use bcder::{Oid, OctetString}; //! use bcder::Tag; //! use bcder::decode; +//! use bcder::decode::DecodeError; //! //! # pub struct EncapsulatedContentInfo { //! # content_type: Oid, //! # content: Option, //! # } -//! # +//! # //! impl EncapsulatedContentInfo { //! pub fn take_from( //! cons: &mut decode::Constructed -//! ) -> Result { +//! ) -> Result> { //! cons.take_sequence(|cons| { //! Ok(EncapsulatedContentInfo { //! content_type: Oid::take_from(cons)?, @@ -89,16 +90,17 @@ //! # use bcder::{Oid, OctetString}; //! use bcder::Tag; //! use bcder::decode; +//! use bcder::decode::DecodeError; //! //! # pub struct EncapsulatedContentInfo { //! # content_type: Oid, //! # content: Option, //! # } -//! # +//! # //! impl EncapsulatedContentInfo { //! pub fn from_constructed( //! cons: &mut decode::Constructed -//! ) -> Result { +//! ) -> Result> { //! Ok(EncapsulatedContentInfo { //! content_type: Oid::take_from(cons)?, //! content: cons.take_opt_constructed_if(Tag::ctx(0), |cons| { @@ -110,4 +112,3 @@ //! ``` //! //! _TODO: Elaborate._ - diff --git a/src/int.rs b/src/int.rs index 160f12e..c695e02 100644 --- a/src/int.rs +++ b/src/int.rs @@ -13,11 +13,11 @@ //! [`Integer`]: struct.Integer.html //! [`Unsigned`]: struct.Unsigned.html -use std::{cmp, fmt, hash, io, mem}; +use std::{cmp, error, fmt, hash, io, mem}; use std::convert::TryFrom; use bytes::Bytes; use crate::decode; -use crate::decode::Source; +use crate::decode::{DecodeError, Source}; use crate::encode::PrimitiveContent; use crate::mode::Mode; use crate::tag::Tag; @@ -35,7 +35,7 @@ macro_rules! slice_to_builtin { ( signed, $slice:expr, $type:ident, $err:expr) => {{ const LEN: usize = mem::size_of::<$type>(); if $slice.len() > LEN { - Err($err) + $err } else { // Start with all zeros if positive or all 0xFF if negative @@ -55,7 +55,7 @@ macro_rules! slice_to_builtin { // out if the sign bit is set. const LEN: usize = mem::size_of::<$type>(); if $slice[0] & 0x80 != 0 { - Err($err) + $err } else { let val = if $slice[0] == 0 { &$slice[1..] } @@ -64,7 +64,7 @@ macro_rules! slice_to_builtin { Ok(0) } else if val.len() > LEN { - Err($err) + $err } else { let mut res = [0; LEN]; @@ -81,7 +81,8 @@ macro_rules! decode_builtin { let res = { let slice = $prim.slice_all()?; slice_to_builtin!( - $flavor, slice, $type, decode::Malformed.into() + $flavor, slice, $type, + Err($prim.content_err("invalid integer")) )? }; $prim.skip_all()?; @@ -108,7 +109,7 @@ macro_rules! builtin_from { fn try_from(val: &'a $from) -> Result<$to, Self::Error> { let val = val.as_slice(); - slice_to_builtin!($flavor, val, $to, OverflowError(())) + slice_to_builtin!($flavor, val, $to, Err(OverflowError(()))) } } @@ -161,24 +162,24 @@ impl Integer { /// a correctly encoded signed integer. pub fn take_from( cons: &mut decode::Constructed - ) -> Result { + ) -> Result> { cons.take_primitive_if(Tag::INTEGER, Self::from_primitive) } /// Constructs a signed integer from the content of a primitive value. pub fn from_primitive( prim: &mut decode::Primitive - ) -> Result { + ) -> Result> { let res = prim.take_all()?; - match (res.get(0), res.get(1).map(|x| x & 0x80 != 0)) { + match (res.first(), res.get(1).map(|x| x & 0x80 != 0)) { (Some(0), Some(false)) => { - xerr!(return Err(decode::Error::Malformed.into())) + return Err(prim.content_err("invalid integer")) } (Some(0xFF), Some(true)) => { - xerr!(return Err(decode::Error::Malformed.into())) + return Err(prim.content_err("invalid integer")) } (None, _) => { - xerr!(return Err(decode::Error::Malformed.into())) + return Err(prim.content_err("invalid integer")) } _ => { } } @@ -188,7 +189,7 @@ impl Integer { /// Constructs an `i8` from the content of a primitive value. pub fn i8_from_primitive( prim: &mut decode::Primitive - ) -> Result { + ) -> Result> { Self::check_head(prim)?; prim.take_u8().map(|x| x as i8) } @@ -196,28 +197,28 @@ impl Integer { /// Constructs an `i16` from the content of a primitive value. pub fn i16_from_primitive( prim: &mut decode::Primitive - ) -> Result { + ) -> Result> { decode_builtin!(signed, prim, i16) } /// Constructs an `i32` from the content of a primitive value. pub fn i32_from_primitive( prim: &mut decode::Primitive - ) -> Result { + ) -> Result> { decode_builtin!(signed, prim, i32) } /// Constructs an `i64` from the content of a primitive value. pub fn i64_from_primitive( prim: &mut decode::Primitive - ) -> Result { + ) -> Result> { decode_builtin!(signed, prim, i64) } /// Constructs an `i128` from the content of a primitive value. pub fn i128_from_primitive( prim: &mut decode::Primitive - ) -> Result { + ) -> Result> { decode_builtin!(signed, prim, i128) } @@ -232,17 +233,17 @@ impl Integer { /// for equality comparision. fn check_head( prim: &mut decode::Primitive - ) -> Result<(), S::Err> { + ) -> Result<(), DecodeError> { if prim.request(2)? == 0 { - xerr!(return Err(decode::Error::Malformed.into())) + return Err(prim.content_err("invalid integer")) } let slice = prim.slice(); - match (slice.get(0), slice.get(1).map(|x| x & 0x80 != 0)) { + match (slice.first(), slice.get(1).map(|x| x & 0x80 != 0)) { (Some(0), Some(false)) => { - xerr!(Err(decode::Error::Malformed.into())) + Err(prim.content_err("invalid integer")) } (Some(0xFF), Some(true)) => { - xerr!(Err(decode::Error::Malformed.into())) + Err(prim.content_err("invalid integer")) } _ => Ok(()) } @@ -450,8 +451,8 @@ impl Unsigned { /// /// # Errors /// - /// Will return `Error::Malformed` if the given slice is empty. - pub fn from_slice(slice: &[u8]) -> Result { + /// Will return a malformed error if the given slice is empty. + pub fn from_slice(slice: &[u8]) -> Result { Self::from_bytes(Bytes::copy_from_slice(slice)) } @@ -459,14 +460,16 @@ impl Unsigned { /// /// # Errors /// - /// Will return `Error::Malformed` if the given Bytes value is empty. - pub fn from_bytes(bytes: Bytes) -> Result { + /// Will return a malformed error if the given slice is empty. + pub fn from_bytes(bytes: Bytes) -> Result { if bytes.is_empty() { - return Err(crate::decode::Error::Malformed); + return Err(InvalidInteger(())) } // Skip any leading zero bytes. - let num_leading_zero_bytes = bytes.as_ref().iter().take_while(|&&b| b == 0x00).count(); + let num_leading_zero_bytes = bytes.as_ref().iter().take_while(|&&b| { + b == 0x00 + }).count(); let value = bytes.slice(num_leading_zero_bytes..); // Create a new Unsigned integer from the given value bytes, ensuring @@ -490,39 +493,39 @@ impl Unsigned { pub fn take_from( cons: &mut decode::Constructed - ) -> Result { + ) -> Result> { cons.take_primitive_if(Tag::INTEGER, Self::from_primitive) } pub fn from_primitive( prim: &mut decode::Primitive - ) -> Result { + ) -> Result> { Self::check_head(prim)?; Integer::from_primitive(prim).map(Unsigned) } pub fn u8_from_primitive( prim: &mut decode::Primitive - ) -> Result { + ) -> Result> { Self::check_head(prim)?; match prim.remaining() { 1 => prim.take_u8(), // sign bit has been checked above. 2 => { // First byte must be 0x00, second is the result. if prim.take_u8()? != 0 { - xerr!(Err(decode::Malformed.into())) + Err(prim.content_err("invalid integer")) } else { prim.take_u8() } } - _ => xerr!(Err(decode::Malformed.into())) + _ => Err(prim.content_err("invalid integer")) } } pub fn u16_from_primitive( prim: &mut decode::Primitive - ) -> Result { + ) -> Result> { Self::check_head(prim)?; match prim.remaining() { 1 => Ok(prim.take_u8()?.into()), @@ -534,7 +537,7 @@ impl Unsigned { } 3 => { if prim.take_u8()? != 0 { - xerr!(return Err(decode::Malformed.into())); + return Err(prim.content_err("invalid integer")) } let res = { u16::from(prim.take_u8()?) << 8 | @@ -542,34 +545,31 @@ impl Unsigned { }; if res < 0x8000 { // This could have been in fewer bytes. - Err(decode::Malformed.into()) + Err(prim.content_err("invalid integer")) } else { Ok(res) } } - _ => xerr!(Err(decode::Malformed.into())) + _ => Err(prim.content_err("invalid integer")) } } pub fn u32_from_primitive( prim: &mut decode::Primitive - ) -> Result { - Self::check_head(prim)?; + ) -> Result> { decode_builtin!(unsigned, prim, u32) } pub fn u64_from_primitive( prim: &mut decode::Primitive - ) -> Result { - Self::check_head(prim)?; + ) -> Result> { decode_builtin!(unsigned, prim, u64) } pub fn u128_from_primitive( prim: &mut decode::Primitive - ) -> Result { - Self::check_head(prim)?; + ) -> Result> { decode_builtin!(unsigned, prim, u128) } @@ -579,10 +579,10 @@ impl Unsigned { /// sign bit is not set. fn check_head( prim: &mut decode::Primitive - ) -> Result<(), S::Err> { + ) -> Result<(), DecodeError> { Integer::check_head(prim)?; - if prim.slice().get(0).unwrap() & 0x80 != 0 { - xerr!(Err(decode::Error::Malformed.into())) + if prim.slice().first().unwrap() & 0x80 != 0 { + Err(prim.content_err("invalid integer")) } else { Ok(()) @@ -626,8 +626,8 @@ builtin_from!(unsigned, Unsigned, u64); builtin_from!(unsigned, Unsigned, u128); -impl<'a> TryFrom for Unsigned { - type Error = crate::decode::Error; +impl TryFrom for Unsigned { + type Error = InvalidInteger; fn try_from(value: Bytes) -> Result { Unsigned::from_bytes(value) @@ -674,6 +674,21 @@ impl<'a> PrimitiveContent for &'a Unsigned { } +//------------ InvalidInteger ------------------------------------------------ + +/// A octets slice does not contain a validly encoded integer. +#[derive(Clone, Copy, Debug)] +pub struct InvalidInteger(()); + +impl fmt::Display for InvalidInteger { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "invalid integer") + } +} + +impl error::Error for InvalidInteger { } + + //------------ OverflowError ------------------------------------------------- #[derive(Clone, Copy, Debug)] @@ -685,6 +700,8 @@ impl fmt::Display for OverflowError { } } +impl error::Error for OverflowError { } + //============ Tests ========================================================= @@ -709,15 +726,15 @@ mod test { let pos = [0xF7, 0xF744, 0xF74402]; for &i in &neg { - assert_eq!(Integer::from(i).is_positive(), false, "{}", i); - assert_eq!(Integer::from(i).is_negative(), true, "{}", i); + assert!(!Integer::from(i).is_positive(), "{}", i); + assert!(Integer::from(i).is_negative(), "{}", i); } for &i in &pos { - assert_eq!(Integer::from(i).is_positive(), true, "{}", i); - assert_eq!(Integer::from(i).is_negative(), false, "{}", i); + assert!(Integer::from(i).is_positive(), "{}", i); + assert!(!Integer::from(i).is_negative(), "{}", i); } - assert_eq!(Integer::from(0).is_positive(), false); - assert_eq!(Integer::from(0).is_negative(), false); + assert!(!Integer::from(0).is_positive()); + assert!(!Integer::from(0).is_negative()); } #[test] @@ -750,12 +767,11 @@ mod test { ).unwrap(), 0x7f ); - assert_eq!( + assert!( Primitive::decode_slice( b"\x80".as_ref(), Mode::Der, |prim| Unsigned::u8_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); assert_eq!( Primitive::decode_slice( @@ -786,19 +802,17 @@ mod test { ).unwrap(), 0xA234 ); - assert_eq!( + assert!( Primitive::decode_slice( b"\xA2\x34".as_ref(), Mode::Der, |prim| Unsigned::u16_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); - assert_eq!( + assert!( Primitive::decode_slice( b"\x00\x12\x34".as_ref(), Mode::Der, |prim| Unsigned::u16_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); assert_eq!( @@ -836,19 +850,17 @@ mod test { ).unwrap(), 0xA2345678 ); - assert_eq!( + assert!( Primitive::decode_slice( b"\x00\x12\x34\x56\x78".as_ref(), Mode::Der, |prim| Unsigned::u32_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); - assert_eq!( + assert!( Primitive::decode_slice( b"\xa2\x34\x56\x78".as_ref(), Mode::Der, |prim| Unsigned::u32_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); assert_eq!( @@ -872,19 +884,17 @@ mod test { ).unwrap(), 0xa234567812345678 ); - assert_eq!( + assert!( Primitive::decode_slice( b"\0\x12\x34\x56\x78\x12\x34\x56\x78".as_ref(), Mode::Der, |prim| Unsigned::u64_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); - assert_eq!( + assert!( Primitive::decode_slice( b"\x30\x12\x34\x56\x78\x12\x34\x56\x78".as_ref(), Mode::Der, |prim| Unsigned::u64_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); assert_eq!( @@ -910,21 +920,19 @@ mod test { ).unwrap(), 0xa2345678123456781234567812345678 ); - assert_eq!( + assert!( Primitive::decode_slice( b"\0\x12\x34\x56\x78\x12\x34\x56\x78 \x12\x34\x56\x78\x12\x34\x56\x78".as_ref(), Mode::Der, |prim| Unsigned::u128_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); - assert_eq!( + assert!( Primitive::decode_slice( b"\x30\x12\x34\x56\x78\x12\x34\x56\x78 \x12\x34\x56\x78\x12\x34\x56\x78".as_ref(), Mode::Der, |prim| Unsigned::u128_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); } @@ -944,19 +952,17 @@ mod test { ).unwrap(), -1 ); - assert_eq!( + assert!( Primitive::decode_slice( b"\x00\xFF".as_ref(), Mode::Der, |prim| Integer::i8_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); - assert_eq!( + assert!( Primitive::decode_slice( b"\x40\xFF".as_ref(), Mode::Der, |prim| Integer::i8_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); assert_eq!( @@ -980,19 +986,17 @@ mod test { ).unwrap(), -32513 ); - assert_eq!( + assert!( Primitive::decode_slice( b"\x80\xFF\x32".as_ref(), Mode::Der, |prim| Integer::i16_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); - assert_eq!( + assert!( Primitive::decode_slice( b"\x00\xFF\x32".as_ref(), Mode::Der, |prim| Integer::i16_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); assert_eq!( @@ -1140,28 +1144,66 @@ mod test { #[test] fn encode_variable_length_unsigned_from_slice() { - assert_eq!(Unsigned::from_slice(&[]), Err(crate::decode::Error::Malformed)); + assert!(Unsigned::from_slice(&[]).is_err()); test_der(&Unsigned::from_slice(&[0xFF]).unwrap(), b"\x00\xFF"); test_der(&Unsigned::from_slice(&[0x00, 0xFF]).unwrap(), b"\x00\xFF"); - test_der(&Unsigned::from_slice(&[0x00, 0x00, 0xFF]).unwrap(), b"\x00\xFF"); - test_der(&Unsigned::from_slice(&[0x00, 0x00, 0xDE, 0xAD, 0xBE, 0xEF]).unwrap(), b"\x00\xDE\xAD\xBE\xEF"); + test_der( + &Unsigned::from_slice(&[0x00, 0x00, 0xFF]).unwrap(), + b"\x00\xFF" + ); + test_der( + &Unsigned::from_slice( + &[0x00, 0x00, 0xDE, 0xAD, 0xBE, 0xEF] + ).unwrap(), + b"\x00\xDE\xAD\xBE\xEF" + ); } #[test] fn encode_variable_length_unsigned_from_bytes() { - assert_eq!(Unsigned::from_bytes(Bytes::new()), Err(crate::decode::Error::Malformed)); - test_der(&Unsigned::from_bytes(Bytes::from(vec![0xFF])).unwrap(), b"\x00\xFF"); - test_der(&Unsigned::from_bytes(Bytes::from(vec![0x00, 0xFF])).unwrap(), b"\x00\xFF"); - test_der(&Unsigned::from_bytes(Bytes::from(vec![0x00, 0x00, 0xFF])).unwrap(), b"\x00\xFF"); - test_der(&Unsigned::from_bytes(Bytes::from(vec![0x00, 0x00, 0xDE, 0xAD, 0xBE, 0xEF])).unwrap(), b"\x00\xDE\xAD\xBE\xEF"); + assert!(Unsigned::from_bytes(Bytes::new()).is_err()); + test_der( + &Unsigned::from_bytes(Bytes::from(vec![0xFF])).unwrap(), + b"\x00\xFF" + ); + test_der( + &Unsigned::from_bytes(Bytes::from(vec![0x00, 0xFF])).unwrap(), + b"\x00\xFF" + ); + test_der( + &Unsigned::from_bytes(Bytes::from( + vec![0x00, 0x00, 0xFF]) + ).unwrap(), + b"\x00\xFF" + ); + test_der( + &Unsigned::from_bytes(Bytes::from( + vec![0x00, 0x00, 0xDE, 0xAD, 0xBE, 0xEF] + )).unwrap(), + b"\x00\xDE\xAD\xBE\xEF" + ); } #[test] fn encode_variable_length_unsigned_try_from_bytes() { - assert_eq!(Unsigned::try_from(Bytes::new()), Err(crate::decode::Error::Malformed)); - test_der(&Unsigned::try_from(Bytes::from(vec![0xFF])).unwrap(), b"\x00\xFF"); - test_der(&Unsigned::try_from(Bytes::from(vec![0x00, 0xFF])).unwrap(), b"\x00\xFF"); - test_der(&Unsigned::try_from(Bytes::from(vec![0x00, 0x00, 0xFF])).unwrap(), b"\x00\xFF"); - test_der(&Unsigned::try_from(Bytes::from(vec![0x00, 0x00, 0xDE, 0xAD, 0xBE, 0xEF])).unwrap(), b"\x00\xDE\xAD\xBE\xEF"); + assert!(Unsigned::try_from(Bytes::new()).is_err()); + test_der( + &Unsigned::try_from(Bytes::from(vec![0xFF])).unwrap(), + b"\x00\xFF" + ); + test_der( + &Unsigned::try_from(Bytes::from(vec![0x00, 0xFF])).unwrap(), + b"\x00\xFF" + ); + test_der( + &Unsigned::try_from(Bytes::from(vec![0x00, 0x00, 0xFF])).unwrap(), + b"\x00\xFF" + ); + test_der( + &Unsigned::try_from(Bytes::from( + vec![0x00, 0x00, 0xDE, 0xAD, 0xBE, 0xEF] + )).unwrap(), + b"\x00\xDE\xAD\xBE\xEF" + ); } } diff --git a/src/length.rs b/src/length.rs index e2ecba4..5c36ff6 100644 --- a/src/length.rs +++ b/src/length.rs @@ -3,7 +3,7 @@ //! This is a private module. Its public tiems are re-exported by the parent. use std::io; -use crate::decode; +use crate::decode::{DecodeError, Source}; use crate::mode::Mode; @@ -52,10 +52,10 @@ pub enum Length { impl Length { /// Takes a length value from the beginning of a source. - pub fn take_from( + pub fn take_from( source: &mut S, mode: Mode - ) -> Result { + ) -> Result> { match source.take_u8()? { // Bit 7 clear: other bits are the length n if (n & 0x80) == 0 => Ok(Length::Definite(n as usize)), @@ -70,7 +70,7 @@ impl Length { Ok(Length::Definite(len)) } else { - Err(decode::Error::Malformed.into()) + Err(source.content_err("invalid length")) } } 0x82 => { @@ -81,7 +81,7 @@ impl Length { Ok(Length::Definite(len)) } else { - Err(decode::Error::Malformed.into()) + Err(source.content_err("invalid length")) } } 0x83 => { @@ -93,7 +93,7 @@ impl Length { Ok(Length::Definite(len)) } else { - Err(decode::Error::Malformed.into()) + Err(source.content_err("invalid length")) } } 0x84 => { @@ -106,12 +106,14 @@ impl Length { Ok(Length::Definite(len)) } else { - Err(decode::Error::Malformed.into()) + Err(source.content_err("invalid length")) } } _ => { // We only implement up to two length bytes for now. - Err(decode::Error::Unimplemented.into()) + Err(source.content_err( + "lengths over 4 bytes not implemented" + )) } } } diff --git a/src/lib.rs b/src/lib.rs index 79974f8..de5084b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,8 +51,6 @@ pub use self::tag::Tag; //--- Public modules -#[macro_use] pub mod debug; - pub mod decode; pub mod encode; diff --git a/src/mode.rs b/src/mode.rs index 9ce8eeb..7fd2bed 100644 --- a/src/mode.rs +++ b/src/mode.rs @@ -3,6 +3,7 @@ //! This is a private module. It’s public items are re-exported by the parent. use crate::decode; +use crate::decode::DecodeError; //------------ Mode ---------------------------------------------------------- @@ -50,10 +51,14 @@ impl Mode { /// by this value. The closure `op` will be given the content of the /// source as a sequence of values. The closure does not need to process /// all values in the source. - pub fn decode(self, source: S, op: F) -> Result + pub fn decode( + self, source: S, op: F, + ) -> Result::Error>> where - S: decode::Source, - F: FnOnce(&mut decode::Constructed) -> Result + S: decode::IntoSource, + F: FnOnce( + &mut decode::Constructed + ) -> Result::Error>>, { decode::Constructed::decode(source, self, op) } diff --git a/src/oid.rs b/src/oid.rs index 7bcda9f..a783bae 100644 --- a/src/oid.rs +++ b/src/oid.rs @@ -8,8 +8,8 @@ use std::{fmt, hash, io}; use bytes::Bytes; -use crate::{decode, encode}; -use crate::decode::Source; +use crate::encode; +use crate::decode::{Constructed, DecodeError, Source}; use crate::mode::Mode; use crate::tag::Tag; @@ -44,7 +44,7 @@ use crate::tag::Tag; /// and produces the `u8` array for their encoded value. You can install /// this binary via `cargo install ber`. #[derive(Clone, Debug)] -pub struct Oid=Bytes>(pub T); +pub struct Oid = Bytes>(pub T); /// A type alias for `Oid<&'static [u8]>. /// @@ -60,9 +60,9 @@ impl Oid { /// If the source has reached its end, if the next value does not have /// the `Tag::OID`, or if it is not a primitive value, returns a malformed /// error. - pub fn skip_in( - cons: &mut decode::Constructed - ) -> Result<(), S::Err> { + pub fn skip_in( + cons: &mut Constructed + ) -> Result<(), DecodeError> { cons.take_primitive_if(Tag::OID, |prim| prim.skip_all()) } @@ -71,9 +71,9 @@ impl Oid { /// If the source has reached its end of if the next value does not have /// the `Tag::OID`, returns `Ok(None)`. If the next value has the right /// tag but is not a primitive value, returns a malformed error. - pub fn skip_opt_in( - cons: &mut decode::Constructed - ) -> Result, S::Err> { + pub fn skip_opt_in( + cons: &mut Constructed + ) -> Result, DecodeError> { cons.take_opt_primitive_if(Tag::OID, |prim| prim.skip_all()) } @@ -82,9 +82,9 @@ impl Oid { /// If the source has reached its end, if the next value does not have /// the `Tag::OID`, or if it is not a primitive value, returns a malformed /// error. - pub fn take_from( - constructed: &mut decode::Constructed - ) -> Result { + pub fn take_from( + constructed: &mut Constructed + ) -> Result> { constructed.take_primitive_if(Tag::OID, |content| { content.take_all().map(Oid) }) @@ -95,9 +95,9 @@ impl Oid { /// If the source has reached its end of if the next value does not have /// the `Tag::OID`, returns `Ok(None)`. If the next value has the right /// tag but is not a primitive value, returns a malformed error. - pub fn take_opt_from( - constructed: &mut decode::Constructed - ) -> Result, S::Err> { + pub fn take_opt_from( + constructed: &mut Constructed + ) -> Result, DecodeError> { constructed.take_opt_primitive_if(Tag::OID, |content| { content.take_all().map(Oid) }) @@ -106,10 +106,9 @@ impl Oid { impl> Oid { /// Skip over an object identifier if it matches `self`. - pub fn skip_if( - &self, - constructed: &mut decode::Constructed - ) -> Result<(), S::Err> { + pub fn skip_if( + &self, constructed: &mut Constructed, + ) -> Result<(), DecodeError> { constructed.take_primitive_if(Tag::OID, |content| { let len = content.remaining(); content.request(len)?; @@ -118,7 +117,7 @@ impl> Oid { Ok(()) } else { - xerr!(Err(decode::Error::Malformed.into())) + Err(content.content_err("object identifier mismatch")) } }) } @@ -177,7 +176,7 @@ impl> fmt::Display for Oid { // I can’t be bothered to figure out how to convert a seven // bit integer into decimal. let mut components = self.iter(); - // There’s at least one and it is always an valid u32. + // There’s at least one and it is always a valid u32. write!(f, "{}", components.next().unwrap().to_u32().unwrap())?; for component in components { if let Some(val) = component.to_u32() { diff --git a/src/string/bit.rs b/src/string/bit.rs index b998551..0c5624e 100644 --- a/src/string/bit.rs +++ b/src/string/bit.rs @@ -5,7 +5,7 @@ use std::io; use bytes::Bytes; use crate::{decode, encode}; -use crate::decode::Source; +use crate::decode::{DecodeError, Source}; use crate::length::Length; use crate::mode::Mode; use crate::tag::Tag; @@ -129,25 +129,27 @@ impl BitString { /// Takes a single bit string value from constructed content. pub fn take_from( constructed: &mut decode::Constructed - ) -> Result { + ) -> Result> { constructed.take_value_if(Tag::BIT_STRING, Self::from_content) } /// Skip over a single bit string value inside constructed content. pub fn skip_in( cons: &mut decode::Constructed - ) -> Result<(), S::Err> { + ) -> Result<(), DecodeError> { cons.take_value_if(Tag::BIT_STRING, Self::skip_content) } /// Parses the content octets of a bit string value. pub fn from_content( content: &mut decode::Content - ) -> Result { + ) -> Result> { match *content { decode::Content::Primitive(ref mut inner) => { if inner.mode() == Mode::Cer && inner.remaining() > 1000 { - xerr!(return Err(decode::Error::Malformed.into())) + return Err(content.content_err( + "long bit string component in CER mode" + )) } Ok(BitString { unused: inner.take_u8()?, @@ -156,10 +158,14 @@ impl BitString { } decode::Content::Constructed(ref inner) => { if inner.mode() == Mode::Der { - xerr!(Err(decode::Error::Malformed.into())) + Err(content.content_err( + "constructed bit string in DER mode" + )) } else { - xerr!(Err(decode::Error::Unimplemented.into())) + Err(content.content_err( + "constructed bit string not implemented" + )) } } } @@ -168,20 +174,26 @@ impl BitString { /// Skips over the content octets of a bit string value. pub fn skip_content( content: &mut decode::Content - ) -> Result<(), S::Err> { + ) -> Result<(), DecodeError> { match *content { decode::Content::Primitive(ref mut inner) => { if inner.mode() == Mode::Cer && inner.remaining() > 1000 { - xerr!(return Err(decode::Error::Malformed.into())) + return Err(content.content_err( + "long bit string component in CER mode" + )) } inner.skip_all() } decode::Content::Constructed(ref inner) => { if inner.mode() == Mode::Der { - xerr!(Err(decode::Error::Malformed.into())) + Err(content.content_err( + "constructed bit string in DER mode" + )) } else { - xerr!(Err(decode::Error::Unimplemented.into())) + Err(content.content_err( + "constructed bit string not implemented" + )) } } } diff --git a/src/string/octet.rs b/src/string/octet.rs index 68a5b66..a39a000 100644 --- a/src/string/octet.rs +++ b/src/string/octet.rs @@ -7,9 +7,13 @@ //! CER mode. An implementation of that is TODO. use std::{cmp, hash, io, mem}; +use std::convert::Infallible; use bytes::{BytesMut, Bytes}; use crate::captured::Captured; use crate::{decode, encode}; +use crate::decode::{ + BytesSource, DecodeError, IntoSource, Pos, SliceSource, Source +}; use crate::mode::Mode; use crate::length::Length; use crate::tag::Tag; @@ -82,7 +86,9 @@ impl OctetString { OctetStringIter(Inner::Primitive(inner.as_ref())) } Inner::Constructed(ref inner) => { - OctetStringIter(Inner::Constructed(inner.as_ref())) + OctetStringIter( + Inner::Constructed(inner.as_slice().into_source()) + ) } } } @@ -154,15 +160,6 @@ impl OctetString { } !self.iter().any(|s| !s.is_empty()) } - - /// Creates a source that can be used to decode the string’s content. - /// - /// The returned value contains a clone of the string (which, because of - /// the use of `Bytes` is rather cheap) that implements the `Source` - /// trait and thus can be used to decode the string’s content. - pub fn to_source(&self) -> OctetStringSource { - OctetStringSource::new(self) - } } @@ -176,7 +173,7 @@ impl OctetString { /// octet string, a malformed error is returned. pub fn take_from( cons: &mut decode::Constructed - ) -> Result { + ) -> Result> { cons.take_value_if(Tag::OCTET_STRING, Self::from_content) } @@ -189,18 +186,20 @@ impl OctetString { /// malformed error is returned. pub fn take_opt_from( cons: &mut decode::Constructed - ) -> Result, S::Err> { + ) -> Result, DecodeError> { cons.take_opt_value_if(Tag::OCTET_STRING, Self::from_content) } /// Takes an octet string value from content. pub fn from_content( content: &mut decode::Content - ) -> Result { + ) -> Result> { match *content { decode::Content::Primitive(ref mut inner) => { if inner.mode() == Mode::Cer && inner.remaining() > 1000 { - xerr!(return Err(decode::Error::Malformed.into())) + return Err(content.content_err( + "long string component in CER mode" + )) } Ok(OctetString(Inner::Primitive(inner.take_all()?))) } @@ -209,7 +208,9 @@ impl OctetString { Mode::Ber => Self::take_constructed_ber(inner), Mode::Cer => Self::take_constructed_cer(inner), Mode::Der => { - xerr!(Err(decode::Error::Malformed.into())) + Err(content.content_err( + "constructed string in DER mode" + )) } } } @@ -221,14 +222,14 @@ impl OctetString { /// It consists octet string values either primitive or constructed. fn take_constructed_ber( cons: &mut decode::Constructed - ) -> Result { + ) -> Result> { cons.capture(|cons| { while cons.skip_opt(|tag, _, _| if tag == Tag::OCTET_STRING { Ok(()) } else { - xerr!(Err(decode::Malformed.into())) + Err("expected octet string".into()) } )?.is_some() { } Ok(()) @@ -241,17 +242,21 @@ impl OctetString { /// values each except for the last one exactly 1000 octets long. fn take_constructed_cer( constructed: &mut decode::Constructed - ) -> Result { + ) -> Result> { let mut short = false; constructed.capture(|con| { while let Some(()) = con.take_opt_primitive_if(Tag::OCTET_STRING, |primitive| { if primitive.remaining() > 1000 { - xerr!(return Err(decode::Error::Malformed.into())); + return Err(primitive.content_err( + "long string component in CER mode" + )); } if primitive.remaining() < 1000 { if short { - xerr!(return Err(decode::Error::Malformed.into())); + return Err(primitive.content_err( + "short non-terminal string component in CER mode" + )); } short = true } @@ -467,11 +472,20 @@ impl<'a> IntoIterator for &'a OctetString { } +//--- IntoSource + +impl IntoSource for OctetString { + type Source = OctetStringSource; + + fn into_source(self) -> Self::Source { + OctetStringSource::new(self) + } +} + + //------------ OctetStringSource --------------------------------------------- -/// A source atop an octet string. -/// -/// You can get a value of this type by calling `OctetString::source`. +/// A decode source atop an octet string. // // Assuming we have a correctly encoded octet string, its content is a // sequence of value headers (i.e., tag and length octets) and actual string @@ -487,34 +501,52 @@ pub struct OctetStringSource { current: Bytes, /// The remainder of the value after the value in `current`. - remainder: Bytes, + remainder: BytesSource, + + /// The current position in the string. + pos: Pos, } impl OctetStringSource { /// Creates a new source atop an existing octet string. - fn new(from: &OctetString) -> Self { + fn new(from: OctetString) -> Self { + Self::with_offset(from, Pos::default()) + } + + /// Creates a new source with a given start position. + fn with_offset(from: OctetString, offset: Pos) -> Self { match from.0 { - Inner::Primitive(ref inner) => { + Inner::Primitive(inner) => { OctetStringSource { - current: inner.clone(), - remainder: Bytes::new(), + current: inner, + remainder: Bytes::new().into_source(), + pos: offset, } } - Inner::Constructed(ref inner) => { + Inner::Constructed(inner) => { OctetStringSource { current: Bytes::new(), - remainder: inner.clone().into_bytes() + remainder: inner.into_bytes().into_source(), + pos: offset, } } } } - /// Returns the bytes of the next primitive value in the string. + /// Returns the next value for `self.current`. + /// + /// This is the content of the first primitive value found in the + /// remainder. /// /// Returns `None` if we are done. - fn next_primitive(&mut self) -> Option { - while !self.remainder.is_empty() { - let (tag, cons) = Tag::take_from(&mut self.remainder).unwrap(); + fn next_current(&mut self) -> Option { + // Unwrapping here is okay. The only error that can happen is that + // the tag is longer that we support. However, we already checked that + // there’s only OctetString or End of Value tags which we _do_ + // support. + while let Some((tag, cons)) = Tag::take_opt_from( + &mut self.remainder + ).unwrap() { let length = Length::take_from( &mut self.remainder, Mode::Ber ).unwrap(); @@ -538,15 +570,19 @@ impl OctetStringSource { } impl decode::Source for OctetStringSource { - type Err = decode::Error; + type Error = Infallible; - fn request(&mut self, len: usize) -> Result { + fn pos(&self) -> Pos { + self.pos + } + + fn request(&mut self, len: usize) -> Result { if self.current.len() < len && !self.remainder.is_empty() { // Make a new current that is at least `len` long. let mut current = BytesMut::with_capacity(self.current.len()); current.extend_from_slice(&self.current.clone()); while current.len() < len { - if let Some(bytes) = self.next_primitive() { + if let Some(bytes) = self.next_current() { current.extend_from_slice(bytes.as_ref()) } else { @@ -558,18 +594,10 @@ impl decode::Source for OctetStringSource { Ok(self.current.len()) } - fn advance(&mut self, mut len: usize) -> Result<(), decode::Error> { - while len > self.current.len() { - len -= self.current.len(); - self.current = match self.next_primitive() { - Some(value) => value, - None => { - xerr!(return Err(decode::Error::Malformed)) - } - } - } - self.current.advance(len)?; - Ok(()) + fn advance(&mut self, len: usize) { + assert!(len <= self.current.len()); + self.pos = self.pos + len.into(); + bytes::Buf::advance(&mut self.current, len) } fn slice(&self) -> &[u8] { @@ -589,7 +617,7 @@ impl decode::Source for OctetStringSource { /// You can get a value of this type by calling `OctetString::iter` or relying /// on the `IntoIterator` impl for a `&OctetString`. #[derive(Clone, Debug)] -pub struct OctetStringIter<'a>(Inner<&'a [u8], &'a [u8]>); +pub struct OctetStringIter<'a>(Inner<&'a [u8], SliceSource<'a>>); impl<'a> Iterator for OctetStringIter<'a> { type Item = &'a [u8]; @@ -617,9 +645,7 @@ impl<'a> Iterator for OctetStringIter<'a> { Length::Definite(len) => len, _ => unreachable!() }; - let res = &inner[..length]; - *inner = &inner[length..]; - return Some(res) + return Some(inner.split_to(length)) } Tag::END_OF_VALUE => continue, _ => unreachable!() @@ -888,7 +914,7 @@ mod tests { assert_eq!( decode::Constructed::decode( b"\x24\x04\ - \x04\x02ab".as_ref(), + \x04\x02ab".into_source(), Mode::Ber, |cons| { OctetString::take_from(cons) @@ -902,7 +928,7 @@ mod tests { decode::Constructed::decode( b"\x24\x06\ \x04\x01a\ - \x04\x01b".as_ref(), + \x04\x01b".into_source(), Mode::Ber, |cons| { OctetString::take_from(cons) @@ -917,7 +943,7 @@ mod tests { b"\x24\x08\ \x24\x80\ \x04\x02ab\ - \0\0".as_ref(), + \0\0".into_source(), Mode::Ber, |cons| { OctetString::take_from(cons) @@ -926,13 +952,12 @@ mod tests { "ab" ); - println!("lllllll"); // I(p) assert_eq!( decode::Constructed::decode( b"\x24\x80\ \x04\x02ab\ - \0\0".as_ref(), + \0\0".into_source(), Mode::Ber, |cons| { OctetString::take_from(cons) @@ -948,7 +973,7 @@ mod tests { \x04\x01a\ \x24\x80\ \x04\x01b\ - \0\0".as_ref(), + \0\0".into_source(), Mode::Ber, |cons| { OctetString::take_from(cons) diff --git a/src/string/restricted.rs b/src/string/restricted.rs index d7ce1a9..0c3e28e 100644 --- a/src/string/restricted.rs +++ b/src/string/restricted.rs @@ -8,6 +8,7 @@ use std::borrow::Cow; use std::marker::PhantomData; use bytes::Bytes; use crate::{decode, encode}; +use crate::decode::DecodeError; use crate::tag::Tag; use super::octet::{OctetString, OctetStringIter, OctetStringOctets}; @@ -158,16 +159,16 @@ impl RestrictedString { /// returned. pub fn take_from( cons: &mut decode::Constructed - ) -> Result { + ) -> Result> { cons.take_value_if(L::TAG, Self::from_content) } /// Takes a character set from content. pub fn from_content( content: &mut decode::Content - ) -> Result { + ) -> Result> { let os = OctetString::from_content(content)?; - Self::new(os).map_err(|_| decode::Error::Malformed.into()) + Self::new(os).map_err(|_| content.content_err("invalid character")) } /// Returns a value encoder for the character string with the natural tag. @@ -542,6 +543,7 @@ mod test { use super::*; use bytes::Bytes; + use crate::decode::IntoSource; use crate::mode::Mode; use crate::encode::Values; @@ -554,7 +556,7 @@ mod test { ps.encode_ref().write_encoded(Mode::Der, &mut v).unwrap(); let decoded = Mode::Der.decode( - v.as_slice(), + v.as_slice().into_source(), PrintableString::take_from ).unwrap(); diff --git a/src/tag.rs b/src/tag.rs index 06f074f..36aefa4 100644 --- a/src/tag.rs +++ b/src/tag.rs @@ -3,7 +3,7 @@ //! This is a private module. Its public items are re-exported by the parent. use std::{fmt, io}; -use crate::decode; +use crate::decode::{DecodeError, Source}; //------------ Tag ----------------------------------------------------------- @@ -22,8 +22,8 @@ use crate::decode; /// /// # Limitations /// -/// We can only decode up to four identifier octets. That is, we only support tag -/// numbers between 0 and 1fffff. +/// We can only decode up to four identifier octets. That is, we only support +/// tag numbers between 0 and 1fffff. /// /// [`Primitive`]: decode/struct.Primitive.html /// [`Constructed`]: decode/struct.Constructed.html @@ -316,7 +316,9 @@ impl Tag { /// Returns the number of the tag. pub fn number(self) -> u32 { - if (Tag::SINGLEBYTE_DATA_MASK & self.0[0]) != Tag::SINGLEBYTE_DATA_MASK { + if (Tag::SINGLEBYTE_DATA_MASK & self.0[0]) + != Tag::SINGLEBYTE_DATA_MASK + { // It's a single byte identifier u32::from(Tag::SINGLEBYTE_DATA_MASK & self.0[0]) } else if Tag::LAST_OCTET_MASK & self.0[1] == 0 { @@ -335,15 +337,17 @@ impl Tag { } } - /// Takes a tag from the beginning of a source. + /// Takes an optional tag from the beginning of a source. /// /// Upon success, returns both the tag and whether the value is - /// constructed. If there are no more octets available in the source, - /// an error is returned. - pub fn take_from( + /// constructed. + pub fn take_opt_from( source: &mut S, - ) -> Result<(Self, bool), S::Err> { - let byte = source.take_u8()?; + ) -> Result, DecodeError> { + let byte = match source.take_opt_u8()? { + Some(byte) => byte, + None => return Ok(None) + }; // clear constructed bit let mut data = [byte & !Tag::CONSTRUCTED_MASK, 0, 0, 0]; let constructed = byte & Tag::CONSTRUCTED_MASK != 0; @@ -351,13 +355,29 @@ impl Tag { for i in 1..=3 { data[i] = source.take_u8()?; if data[i] & Tag::LAST_OCTET_MASK == 0 { - return Ok((Tag(data), constructed)); + return Ok(Some((Tag(data), constructed))); } } } else { - return Ok((Tag(data), constructed)); + return Ok(Some((Tag(data), constructed))); + } + Err(source.content_err( + "tag values longer than 4 bytes not implemented" + )) + } + + /// Takes a tag from the beginning of a source. + /// + /// Upon success, returns both the tag and whether the value is + /// constructed. If there are no more octets available in the source, + /// an error is returned. + pub fn take_from( + source: &mut S, + ) -> Result<(Self, bool), DecodeError> { + match Self::take_opt_from(source)? { + Some(res) => Ok(res), + None => Err(source.content_err("additional values expected")) } - xerr!(Err(decode::Error::Unimplemented.into())) } /// Takes a tag from the beginning of a resource if it matches this tag. @@ -365,10 +385,10 @@ impl Tag { /// If there is no more data available in the source or if the tag is /// something else, returns `Ok(None)`. If the tag matches `self`, returns /// whether the value is constructed. - pub fn take_from_if( + pub fn take_from_if( self, source: &mut S, - ) -> Result, S::Err> { + ) -> Result, DecodeError> { if source.request(1)? == 0 { return Ok(None) } @@ -380,7 +400,7 @@ impl Tag { loop { if source.request(i + 1)? == 0 { // Not enough data for a complete tag. - xerr!(return Err(decode::Error::Malformed.into())) + return Err(source.content_err("short tag value")) } data[i] = source.slice()[i]; if data[i] & Tag::LAST_OCTET_MASK == 0 { @@ -388,14 +408,16 @@ impl Tag { } // We don’t support tags larger than 4 bytes. if i == 3 { - xerr!(return Err(decode::Error::Unimplemented.into())) + return Err(source.content_err( + "tag values longer than 4 bytes not implemented" + )) } i += 1; } } let (tag, compressed) = (Tag(data), byte & Tag::CONSTRUCTED_MASK != 0); if tag == self { - source.advance(tag.encoded_len())?; + source.advance(tag.encoded_len()); Ok(Some(compressed)) } else { @@ -490,6 +512,7 @@ impl fmt::Debug for Tag { #[cfg(test)] mod test { use super::*; + use crate::decode::IntoSource; const TYPES: &[u8] = &[Tag::UNIVERSAL, Tag::APPLICATION, Tag::CONTEXT_SPECIFIC, Tag::PRIVATE]; @@ -503,10 +526,15 @@ mod test { for i in range.clone() { let tag = Tag::new(typ, i); let expected = Tag([typ | i as u8, 0, 0, 0]); - let decoded = Tag::take_from(&mut &tag.0[..]).unwrap(); - assert_eq!(tag.take_from_if(&mut &tag.0[..]), Ok(Some(false))); + let decoded = Tag::take_from( + &mut tag.0.into_source() + ).unwrap(); + assert_eq!( + tag.take_from_if(&mut tag.0.into_source()).unwrap(), + Some(false) + ); // The value is not constructed. - assert_eq!(decoded.1, false); + assert!(!decoded.1); // The tag is the same assert_eq!(decoded.0, expected); // We get the same number back. @@ -532,10 +560,15 @@ mod test { let expected = Tag([ Tag::SINGLEBYTE_DATA_MASK | typ, i as u8, 0, 0 ]); - let decoded = Tag::take_from(&mut &tag.0[..]).unwrap(); - assert_eq!(tag.take_from_if(&mut &tag.0[..]), Ok(Some(false))); + let decoded = Tag::take_from( + &mut tag.0.into_source() + ).unwrap(); + assert_eq!( + tag.take_from_if(&mut tag.0.into_source()).unwrap(), + Some(false) + ); // The value is not constructed. - assert_eq!(decoded.1, false); + assert!(!decoded.1); // The tag is the same assert_eq!(decoded.0, expected); assert_eq!(tag.number(), i); @@ -561,10 +594,15 @@ mod test { i as u8 & !Tag::LAST_OCTET_MASK, 0 ]); - let decoded = Tag::take_from(&mut &tag.0[..]).unwrap(); - assert_eq!(tag.take_from_if(&mut &tag.0[..]), Ok(Some(false))); + let decoded = Tag::take_from( + &mut tag.0.into_source() + ).unwrap(); + assert_eq!( + tag.take_from_if(&mut tag.0.into_source()).unwrap(), + Some(false) + ); // The value is not constructed. - assert_eq!(decoded.1, false); + assert!(!decoded.1); // The tag is the same assert_eq!(decoded.0, expected); assert_eq!(tag.number(), i); @@ -590,10 +628,15 @@ mod test { (i >> 7) as u8 | Tag::LAST_OCTET_MASK, i as u8 & !Tag::LAST_OCTET_MASK ]); - let decoded = Tag::take_from(&mut &tag.0[..]).unwrap(); - assert_eq!(tag.take_from_if(&mut &tag.0[..]), Ok(Some(false))); + let decoded = Tag::take_from( + &mut tag.0.into_source() + ).unwrap(); + assert_eq!( + tag.take_from_if(&mut tag.0.into_source()).unwrap(), + Some(false) + ); // The value is not constructed. - assert_eq!(decoded.1, false); + assert!(!decoded.1); // The tag is the same assert_eq!(decoded.0, expected); assert_eq!(tag.number(), i); @@ -607,14 +650,12 @@ mod test { let large_tag = [ 0b1111_1111, 0b1000_0000, 0b1000_0000, 0b1000_0000, 0b1000_0000 ]; - assert_eq!( - Tag::take_from(&mut &large_tag[..]), - Err(decode::Error::Unimplemented) + assert!( + Tag::take_from(&mut large_tag.into_source()).is_err() ); let short_tag = [0b1111_1111, 0b1000_0000]; - assert_eq!( - Tag::take_from(&mut &short_tag[..]), - Err(decode::Error::Malformed) + assert!( + Tag::take_from(&mut short_tag.into_source()).is_err() ); } }