diff --git a/Cargo.toml b/Cargo.toml index ee226ed..b8d8373 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "bcder" -version = "0.6.2-dev" +version = "0.7.0-dev" edition = "2018" authors = ["The NLnet Labs RPKI Team "] description = "Handling of data encoded in BER, CER, and DER." @@ -11,12 +11,6 @@ categories = ["encoding", "network-programming", "parsing"] license = "BSD-3-Clause" [dependencies] -backtrace = { version = "^0.3.15", optional = true } bytes = "^1.0" smallvec = "^1.1" -[features] -# Print a backtrace when a parsing error occurs. This feature is intended for -# development use exclusively and MUST NOT be enabled in release builds. -extra-debug = [ "backtrace" ] - diff --git a/src/captured.rs b/src/captured.rs index 579c252..80360f0 100644 --- a/src/captured.rs +++ b/src/captured.rs @@ -2,9 +2,10 @@ //! //! This is a private module. Its public items are re-exported by the parent. -use std::{fmt, io, ops}; +use std::{fmt, io, mem, ops}; use bytes::{Bytes, BytesMut}; use crate::{decode, encode}; +use crate::decode::{BytesSource, DecodeError, IntoSource, Pos}; use crate::mode::Mode; @@ -39,8 +40,14 @@ use crate::mode::Mode; /// [`Mode`]: ../enum.Mode.html #[derive(Clone)] pub struct Captured { + /// The captured data. bytes: Bytes, + + /// The encoding mode of the captured data. mode: Mode, + + /// The start position of the data in the original source. + start: Pos, } impl Captured { @@ -49,8 +56,8 @@ impl Captured { /// Because we can’t guarantee that the bytes are properly encoded, we /// keep this function crate public. The type, however, doesn’t rely on /// content being properly encoded so this method isn’t unsafe. - pub(crate) fn new(bytes: Bytes, mode: Mode) -> Self { - Captured { bytes, mode } + pub(crate) fn new(bytes: Bytes, mode: Mode, start: Pos) -> Self { + Captured { bytes, mode, start } } /// Creates a captured value by encoding data. @@ -68,7 +75,8 @@ impl Captured { pub fn empty(mode: Mode) -> Self { Captured { bytes: Bytes::new(), - mode + mode, + start: Pos::default(), } } @@ -90,11 +98,13 @@ impl Captured { /// The method consumes the value. If you want to keep it around, simply /// clone it first. Since bytes values are cheap to clone, this is /// relatively cheap. - pub fn decode(self, op: F) -> Result + pub fn decode( + self, op: F + ) -> Result::Error>> where F: FnOnce( - &mut decode::Constructed - ) -> Result + &mut decode::Constructed + ) -> Result::Error>> { self.mode.decode(self.bytes, op) } @@ -104,13 +114,20 @@ impl Captured { /// The method calls `op` to parse a number of values from the beginning /// of the value and then advances the content of the captured value until /// after the end of these decoded values. - pub fn decode_partial(&mut self, op: F) -> Result + pub fn decode_partial( + &mut self, op: F + ) -> Result::Error>> where F: FnOnce( - &mut decode::Constructed<&mut Bytes> - ) -> Result + &mut decode::Constructed<&mut BytesSource> + ) -> Result::Error>> { - self.mode.decode(&mut self.bytes, op) + let mut source = mem::replace( + &mut self.bytes, Bytes::new() + ).into_source(); + let res = self.mode.decode(&mut source, op); + self.bytes = source.into_bytes(); + res } /// Trades the value for a bytes value with the raw data. @@ -148,6 +165,17 @@ impl AsRef<[u8]> for Captured { } +//--- IntoSource + +impl IntoSource for Captured { + type Source = BytesSource; + + fn into_source(self) -> Self::Source { + BytesSource::with_offset(self.bytes, self.start) + } +} + + //--- encode::Values impl encode::Values for Captured { @@ -223,7 +251,7 @@ impl CapturedBuilder { } pub fn freeze(self) -> Captured { - Captured::new(self.bytes.freeze(), self.mode) + Captured::new(self.bytes.freeze(), self.mode, Pos::default()) } } diff --git a/src/debug.rs b/src/debug.rs deleted file mode 100644 index 5131ed8..0000000 --- a/src/debug.rs +++ /dev/null @@ -1,50 +0,0 @@ -//! Macros for last-resort debugging. -//! -//! _Note:_ This facility is going to be replaced by an error type that -//! includes a backtrace of the `extra-debug` feature is set. -//! -//! Since error reporting of the BER parser is limited on purpose, debugging -//! code using it may be difficult. To remedy this somewhat, this module -//! contains a macro `xerr!()` that will print out a backtrace if the -//! `extra-debug` feature is enable during build before resolving into -//! whatever the expression it encloses resolves to otherwise. Use it -//! whenever you initially produce an error, i.e.: -//! -//! ```rust,ignore -//! if foo { -//! xerr!(Err(Error::Malformed)) -//! } -//! ``` -//! -//! or, with an early return: -//! -//! ```rust,ignore -//! if foo { -//! xerr!(return Err(Error::Malformed)); -//! } -//! ``` - -#[cfg(feature = "extra-debug")] -extern crate backtrace; - -#[cfg(feature="extra-debug")] -pub use self::backtrace::Backtrace; - -#[cfg(feature = "extra-debug")] -#[macro_export] -macro_rules! xerr { - ($test:expr) => {{ - eprintln!( - "--- EXTRA DEBUG ---\n{:?}\n--- EXTRA DEBUG ---", - $crate::debug::Backtrace::new() - ); - $test - }} -} - -#[cfg(not(feature = "extra-debug"))] -#[macro_export] -macro_rules! xerr { - ($test:expr) => { $test }; -} - diff --git a/src/decode/content.rs b/src/decode/content.rs index bdec10d..860af73 100644 --- a/src/decode/content.rs +++ b/src/decode/content.rs @@ -3,6 +3,11 @@ //! This is an internal module. Its public types are re-exported by the //! parent. +#![allow(unused_imports)] +#![allow(dead_code)] + +use std::fmt; +use std::convert::Infallible; use bytes::Bytes; use smallvec::SmallVec; use crate::captured::Captured; @@ -10,8 +15,10 @@ use crate::int::{Integer, Unsigned}; use crate::length::Length; use crate::mode::Mode; use crate::tag::Tag; -use super::error::Error; -use super::source::{CaptureSource, LimitedSource, Source}; +use super::error::{ContentError, DecodeError}; +use super::source::{ + CaptureSource, IntoSource, LimitedSource, Pos, SliceSource, Source, +}; //------------ Content ------------------------------------------------------- @@ -34,10 +41,10 @@ pub enum Content<'a, S: 'a> { } impl<'a, S: Source + 'a> Content<'a, S> { - /// Checkes that the content has been parsed completely. + /// Checks that the content has been parsed completely. /// /// Returns a malformed error if not. - fn exhausted(self) -> Result<(), S::Err> { + fn exhausted(self) -> Result<(), DecodeError> { match self { Content::Primitive(inner) => inner.exhausted(), Content::Constructed(mut inner) => inner.exhausted() @@ -69,11 +76,13 @@ impl<'a, S: Source + 'a> Content<'a, S> { } /// Converts a reference into into one to a primitive value or errors out. - pub fn as_primitive(&mut self) -> Result<&mut Primitive<'a, S>, S::Err> { + pub fn as_primitive( + &mut self + ) -> Result<&mut Primitive<'a, S>, DecodeError> { match *self { Content::Primitive(ref mut inner) => Ok(inner), - Content::Constructed(_) => { - xerr!(Err(Error::Malformed.into())) + Content::Constructed(ref inner) => { + Err(inner.content_err("expected primitive value")) } } } @@ -81,14 +90,24 @@ impl<'a, S: Source + 'a> Content<'a, S> { /// Converts a reference into on to a constructed value or errors out. pub fn as_constructed( &mut self - ) -> Result<&mut Constructed<'a, S>, S::Err> { + ) -> Result<&mut Constructed<'a, S>, DecodeError> { match *self { - Content::Primitive(_) => { - xerr!(Err(Error::Malformed.into())) + Content::Primitive(ref inner) => { + Err(inner.content_err("expected constructed value")) } Content::Constructed(ref mut inner) => Ok(inner), } } + + /// Produces a content error at the current source position. + pub fn content_err( + &self, err: impl Into, + ) -> DecodeError { + match *self { + Content::Primitive(ref inner) => inner.content_err(err), + Content::Constructed(ref inner) => inner.content_err(err), + } + } } #[allow(clippy::wrong_self_convention)] @@ -97,12 +116,12 @@ impl<'a, S: Source + 'a> Content<'a, S> { /// /// If the content is not primitive or does not contain a single BER /// encoded INTEGER value between 0 and 256, returns a malformed error. - pub fn to_u8(&mut self) -> Result { + pub fn to_u8(&mut self) -> Result> { if let Content::Primitive(ref mut prim) = *self { prim.to_u8() } else { - xerr!(Err(Error::Malformed.into())) + Err(self.content_err("expected integer (0..255)")) } } @@ -111,13 +130,15 @@ impl<'a, S: Source + 'a> Content<'a, S> { /// The content needs to be primitive and contain a validly encoded /// integer of value `expected` or else a malformed error will be /// returned. - pub fn skip_u8_if(&mut self, expected: u8) -> Result<(), S::Err> { + pub fn skip_u8_if( + &mut self, expected: u8, + ) -> Result<(), DecodeError> { let res = self.to_u8()?; if res == expected { Ok(()) } else { - xerr!(Err(Error::Malformed.into())) + Err(self.content_err(ExpectedIntValue(expected))) } } @@ -125,12 +146,12 @@ impl<'a, S: Source + 'a> Content<'a, S> { /// /// If the content is not primitive or does not contain a single BER /// encoded INTEGER value between 0 and 2^16-1, returns a malformed error. - pub fn to_u16(&mut self) -> Result { + pub fn to_u16(&mut self) -> Result> { if let Content::Primitive(ref mut prim) = *self { prim.to_u16() } else { - xerr!(Err(Error::Malformed.into())) + Err(self.content_err("expected integer (0..65535)")) } } @@ -138,12 +159,12 @@ impl<'a, S: Source + 'a> Content<'a, S> { /// /// If the content is not primitive or does not contain a single BER /// encoded INTEGER value between 0 and 2^32-1, returns a malformed error. - pub fn to_u32(&mut self) -> Result { + pub fn to_u32(&mut self) -> Result> { if let Content::Primitive(ref mut prim) = *self { prim.to_u32() } else { - xerr!(Err(Error::Malformed.into())) + Err(self.content_err("expected integer (0..4294967295)")) } } @@ -151,12 +172,12 @@ impl<'a, S: Source + 'a> Content<'a, S> { /// /// If the content is not primitive or does not contain a single BER /// encoded INTEGER value between 0 and 2^64-1, returns a malformed error. - pub fn to_u64(&mut self) -> Result { + pub fn to_u64(&mut self) -> Result> { if let Content::Primitive(ref mut prim) = *self { prim.to_u64() } else { - xerr!(Err(Error::Malformed.into())) + Err(self.content_err("expected integer (0..2**64-1)")) } } @@ -164,12 +185,12 @@ impl<'a, S: Source + 'a> Content<'a, S> { /// /// If the content isn’t primitive and contains a single BER encoded /// NULL value (i.e., nothing), returns a malformed error. - pub fn to_null(&mut self) -> Result<(), S::Err> { + pub fn to_null(&mut self) -> Result<(), DecodeError> { if let Content::Primitive(ref mut prim) = *self { prim.to_null() } else { - xerr!(Err(Error::Malformed.into())) + Err(self.content_err("expected NULL")) } } } @@ -207,14 +228,18 @@ pub struct Primitive<'a, S: 'a> { /// The decoding mode to operate in. mode: Mode, + + /// The start position of the value in the source. + start: Pos, } /// # Value Management /// impl<'a, S: 'a> Primitive<'a, S> { /// Creates a new primitive from the given source and mode. - fn new(source: &'a mut LimitedSource, mode: Mode) -> Self { - Primitive { source, mode } + fn new(source: &'a mut LimitedSource, mode: Mode) -> Self + where S: Source { + Primitive { start: source.pos(), source, mode } } /// Returns the current decoding mode. @@ -231,20 +256,29 @@ impl<'a, S: 'a> Primitive<'a, S> { } } +impl<'a, S: Source + 'a> Primitive<'a, S> { + /// Produces a content error at the current source position. + pub fn content_err( + &self, err: impl Into, + ) -> DecodeError { + DecodeError::content(err, self.start) + } +} + /// # High-level Decoding /// #[allow(clippy::wrong_self_convention)] impl<'a, S: Source + 'a> Primitive<'a, S> { /// Parses the primitive value as a BOOLEAN value. - pub fn to_bool(&mut self) -> Result { + pub fn to_bool(&mut self) -> Result> { let res = self.take_u8()?; if self.mode != Mode::Ber { match res { 0 => Ok(false), 0xFF => Ok(true), _ => { - xerr!(Err(Error::Malformed.into())) + Err(self.content_err("invalid boolean")) } } } @@ -254,61 +288,65 @@ impl<'a, S: Source + 'a> Primitive<'a, S> { } /// Parses the primitive value as an INTEGER limited to a `i8`. - pub fn to_i8(&mut self) -> Result { + pub fn to_i8(&mut self) -> Result> { Integer::i8_from_primitive(self) } /// Parses the primitive value as an INTEGER limited to a `i8`. - pub fn to_i16(&mut self) -> Result { + pub fn to_i16(&mut self) -> Result> { Integer::i16_from_primitive(self) } /// Parses the primitive value as an INTEGER limited to a `i8`. - pub fn to_i32(&mut self) -> Result { + pub fn to_i32(&mut self) -> Result> { Integer::i32_from_primitive(self) } /// Parses the primitive value as an INTEGER limited to a `i8`. - pub fn to_i64(&mut self) -> Result { + pub fn to_i64(&mut self) -> Result> { Integer::i64_from_primitive(self) } /// Parses the primitive value as an INTEGER limited to a `i8`. - pub fn to_i128(&mut self) -> Result { + pub fn to_i128(&mut self) -> Result> { Integer::i128_from_primitive(self) } /// Parses the primitive value as an INTEGER limited to a `u8`. - pub fn to_u8(&mut self) -> Result { + pub fn to_u8(&mut self) -> Result> { Unsigned::u8_from_primitive(self) } /// Parses the primitive value as an INTEGER limited to a `u16`. - pub fn to_u16(&mut self) -> Result { + pub fn to_u16(&mut self) -> Result> { Unsigned::u16_from_primitive(self) } /// Parses the primitive value as an INTEGER limited to a `u32`. - pub fn to_u32(&mut self) -> Result { + pub fn to_u32(&mut self) -> Result> { Unsigned::u32_from_primitive(self) } /// Parses the primitive value as a INTEGER value limited to a `u64`. - pub fn to_u64(&mut self) -> Result { + pub fn to_u64(&mut self) -> Result> { Unsigned::u64_from_primitive(self) } /// Parses the primitive value as a INTEGER value limited to a `u128`. - pub fn to_u128(&mut self) -> Result { + pub fn to_u128(&mut self) -> Result> { Unsigned::u64_from_primitive(self) } /// Converts the content octets to a NULL value. /// /// Since such a value is empty, this doesn’t really do anything. - pub fn to_null(&mut self) -> Result<(), S::Err> { - // The rest is taken care of by the exhausted check later ... - Ok(()) + pub fn to_null(&mut self) -> Result<(), DecodeError> { + if self.remaining() > 0 { + Err(self.content_err("invalid NULL value")) + } + else { + Ok(()) + } } } @@ -322,31 +360,38 @@ impl<'a, S: Source + 'a> Primitive<'a, S> { impl<'a, S: Source + 'a> Primitive<'a, S> { /// Returns the number of remaining octets. /// - /// The returned value reflects what is left of the content and therefore - /// decreases when the primitive is advanced. + /// The returned value reflects what is left of the expected length of + /// content and therefore decreases when the primitive is advanced. pub fn remaining(&self) -> usize { self.source.limit().unwrap() } /// Skips the rest of the content. - pub fn skip_all(&mut self) -> Result<(), S::Err> { + /// + /// Returns a malformed error if the source ends before the expected + /// length of content. + pub fn skip_all(&mut self) -> Result<(), DecodeError> { self.source.skip_all() } /// Returns the remainder of the content as a `Bytes` value. - pub fn take_all(&mut self) -> Result { + pub fn take_all(&mut self) -> Result> { self.source.take_all() } /// Returns a bytes slice of the remainder of the content. - pub fn slice_all(&mut self) -> Result<&[u8], S::Err> { + pub fn slice_all(&mut self) -> Result<&[u8], DecodeError> { let remaining = self.remaining(); - self.source.request(remaining)?; - Ok(&self.source.slice()[..remaining]) + if self.source.request(remaining)? < remaining { + Err(self.source.content_err("unexpected end of data")) + } + else { + Ok(&self.source.slice()[..remaining]) + } } /// Checkes whether all content has been advanced over. - fn exhausted(self) -> Result<(), S::Err> { + fn exhausted(self) -> Result<(), DecodeError> { self.source.exhausted() } } @@ -354,7 +399,7 @@ impl<'a, S: Source + 'a> Primitive<'a, S> { /// # Support for Testing /// -impl<'a> Primitive<'a, &'a [u8]> { +impl Primitive<'static, ()> { /// Decode a bytes slice via a closure. /// /// This method can be used in testing code for decoding primitive @@ -374,13 +419,17 @@ impl<'a> Primitive<'a, &'a [u8]> { /// ) /// ``` pub fn decode_slice( - source: &'a [u8], + data: &[u8], mode: Mode, op: F - ) -> Result - where F: FnOnce(&mut Primitive<&[u8]>) -> Result { - let mut lim = LimitedSource::new(source); - lim.set_limit(Some(source.len())); + ) -> Result> + where + F: FnOnce( + &mut Primitive + ) -> Result> + { + let mut lim = LimitedSource::new(data.into_source()); + lim.set_limit(Some(data.len())); let mut prim = Primitive::new(&mut lim, mode); let res = op(&mut prim)?; prim.exhausted()?; @@ -392,14 +441,14 @@ impl<'a> Primitive<'a, &'a [u8]> { //--- Source impl<'a, S: Source + 'a> Source for Primitive<'a, S> { - type Err = S::Err; + type Error = S::Error; - fn request(&mut self, len: usize) -> Result { - self.source.request(len) + fn pos(&self) -> Pos { + self.source.pos() } - fn advance(&mut self, len: usize) -> Result<(), Self::Err> { - self.source.advance(len) + fn request(&mut self, len: usize) -> Result { + self.source.request(len) } fn slice(&self) -> &[u8] { @@ -409,6 +458,10 @@ impl<'a, S: Source + 'a> Source for Primitive<'a, S> { fn bytes(&self, start: usize, end: usize) -> Bytes { self.source.bytes(start, end) } + + fn advance(&mut self, len: usize) { + self.source.advance(len) + } } @@ -443,6 +496,9 @@ pub struct Constructed<'a, S: 'a> { /// The encoding mode to use. mode: Mode, + + /// The start position of the value in the source. + start: Pos, } /// # General Management @@ -454,10 +510,10 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { state: State, mode: Mode ) -> Self { - Constructed { source, state, mode } + Constructed { start: source.pos(), source, state, mode } } - /// Decode a source as a constructed content. + /// Decode a source as constructed content. /// /// The function will start decoding of `source` in the given mode. It /// will pass a constructed content value to the closure `op` which @@ -466,9 +522,14 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// This function is identical to calling [`Mode::decode`]. /// /// [`Mode::decode`]: ../enum.Mode.html#method.decode - pub fn decode(source: S, mode: Mode, op: F) -> Result - where F: FnOnce(&mut Constructed) -> Result { - let mut source = LimitedSource::new(source); + pub fn decode( + source: I, mode: Mode, op: F, + ) -> Result> + where + I: IntoSource, + F: FnOnce(&mut Constructed) -> Result> + { + let mut source = LimitedSource::new(source.into_source()); let mut cons = Constructed::new(&mut source, State::Unbounded, mode); let res = op(&mut cons)?; cons.exhausted()?; @@ -486,6 +547,15 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { } } +impl<'a, S: Source + 'a> Constructed<'a, S> { + /// Produces a content error at start of the value. + pub fn content_err( + &self, err: impl Into, + ) -> DecodeError { + DecodeError::content(err, self.start) + } +} + /// # Fundamental Reading /// impl<'a, S: Source + 'a> Constructed<'a, S> { @@ -494,7 +564,7 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// For a value of definite length, this is the case when the limit of the /// source has been reached. For indefinite values, we need to have either /// already read or can now read the end-of-value marker. - fn exhausted(&mut self) -> Result<(), S::Err> { + fn exhausted(&mut self) -> Result<(), DecodeError> { match self.state { State::Done => Ok(()), State::Definite => { @@ -505,7 +575,7 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { if tag != Tag::END_OF_VALUE || constructed || !Length::take_from(self.source, self.mode)?.is_zero() { - xerr!(Err(Error::Malformed.into())) + Err(self.content_err("unexpected trailing values")) } else { Ok(()) @@ -546,8 +616,10 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { &mut self, expected: Option, op: F - ) -> Result, S::Err> - where F: FnOnce(Tag, &mut Content) -> Result { + ) -> Result, DecodeError> + where + F: FnOnce(Tag, &mut Content) -> Result> + { if self.is_exhausted() { return Ok(None) } @@ -568,16 +640,22 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { if tag == Tag::END_OF_VALUE { if let State::Indefinite = self.state { if constructed { - xerr!(return Err(Error::Malformed.into())) + return Err(self.source.content_err( + "constructed end of value" + )) } if !length.is_zero() { - xerr!(return Err(Error::Malformed.into())) + return Err(self.source.content_err( + "non-empty end of value" + )) } self.state = State::Done; return Ok(None) } else { - xerr!(return Err(Error::Malformed.into())) + return Err(self.source.content_err( + "unexpected end of value" + )) } } @@ -589,7 +667,9 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { // Definite length constructed values are not allowed // in CER. if self.mode == Mode::Cer { - xerr!(return Err(Error::Malformed.into())) + return Err(self.source.content_err( + "definite length constructed in CER mode" + )) } Content::Constructed( Constructed::new( @@ -611,10 +691,14 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { } Length::Indefinite => { if !constructed || self.mode == Mode::Der { - xerr!(return Err(Error::Malformed.into())) + return Err(self.source.content_err( + "indefinite length constructed in DER mode" + )) } let mut content = Content::Constructed( - Constructed::new(self.source, State::Indefinite, self.mode) + Constructed::new( + self.source, State::Indefinite, self.mode + ) ); let res = op(tag, &mut content)?; content.exhausted()?; @@ -622,6 +706,21 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { } } } + + /// Makes sure the next value is present. + fn mandatory( + &mut self, op: F, + ) -> Result> + where + F: FnOnce( + &mut Constructed + ) -> Result, DecodeError>, + { + match op(self)? { + Some(res) => Ok(res), + None => Err(self.source.content_err("missing futher values")), + } + } } /// # Processing Contained Values @@ -638,13 +737,15 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// method returns a malformed error if there isn’t at least one more /// value available. It also returns an error if the closure returns one /// or if reading from the source fails. - pub fn take_value(&mut self, op: F) -> Result - where F: FnOnce(Tag, &mut Content) -> Result { + pub fn take_value( + &mut self, op: F, + ) -> Result> + where + F: FnOnce(Tag, &mut Content) -> Result>, + { match self.process_next_value(None, op)? { Some(res) => Ok(res), - None => { - xerr!(Err(Error::Malformed.into())) - } + None => Err(self.content_err("missing futher values")), } } @@ -658,8 +759,12 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// If there are no more values available, the method returns `Ok(None)`. /// It returns an error if the closure returns one or if reading from /// the source fails. - pub fn take_opt_value(&mut self, op: F) -> Result, S::Err> - where F: FnOnce(Tag, &mut Content) -> Result { + pub fn take_opt_value( + &mut self, op: F, + ) -> Result, DecodeError> + where + F: FnOnce(Tag, &mut Content) -> Result>, + { self.process_next_value(None, op) } @@ -677,16 +782,14 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { &mut self, expected: Tag, op: F - ) -> Result - where F: FnOnce(&mut Content) -> Result { + ) -> Result> + where F: FnOnce(&mut Content) -> Result> { let res = self.process_next_value(Some(expected), |_, content| { op(content) })?; match res { Some(res) => Ok(res), - None => { - xerr!(Err(Error::Malformed.into())) - } + None => Err(self.content_err(ExpectedTag(expected))), } } @@ -704,12 +807,12 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { &mut self, expected: Tag, op: F - ) -> Result, S::Err> - where F: FnOnce(&mut Content) -> Result { + ) -> Result, DecodeError> + where F: FnOnce(&mut Content) -> Result> { self.process_next_value(Some(expected), |_, content| op(content)) } - /// Process a constructed value. + /// Processes a constructed value. /// /// If the next value is a constructed value, its tag and content are /// being given to the closure `op` which has to process it completely. @@ -719,14 +822,15 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// value or if the closure doesn’t process the next value completely, /// a malformed error is returned. An error is also returned if the /// closure returns one or if accessing the underlying source fails. - pub fn take_constructed(&mut self, op: F) -> Result - where F: FnOnce(Tag, &mut Constructed) -> Result { - match self.take_opt_constructed(op)? { - Some(res) => Ok(res), - None => { - xerr!(Err(Error::Malformed.into())) - } - } + pub fn take_constructed( + &mut self, op: F + ) -> Result> + where + F: FnOnce( + Tag, &mut Constructed + ) -> Result>, + { + self.mandatory(|cons| cons.take_opt_constructed(op)) } /// Processes an optional constructed value. @@ -745,8 +849,12 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { pub fn take_opt_constructed( &mut self, op: F - ) -> Result, S::Err> - where F: FnOnce(Tag, &mut Constructed) -> Result { + ) -> Result, DecodeError> + where + F: FnOnce( + Tag, &mut Constructed, + ) -> Result> + { self.process_next_value(None, |tag, content| { op(tag, content.as_constructed()?) }) @@ -767,15 +875,12 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { pub fn take_constructed_if( &mut self, expected: Tag, - op: F - ) -> Result - where F: FnOnce(&mut Constructed) -> Result { - match self.take_opt_constructed_if(expected, op)? { - Some(res) => Ok(res), - None => { - xerr!(Err(Error::Malformed.into())) - } - } + op: F, + ) -> Result> + where + F: FnOnce(&mut Constructed) -> Result>, + { + self.mandatory(|cons| cons.take_opt_constructed_if(expected, op)) } /// Processes an optional constructed value if it has a given tag. @@ -795,9 +900,11 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { pub fn take_opt_constructed_if( &mut self, expected: Tag, - op: F - ) -> Result, S::Err> - where F: FnOnce(&mut Constructed) -> Result { + op: F, + ) -> Result, DecodeError> + where + F: FnOnce(&mut Constructed) -> Result>, + { self.process_next_value(Some(expected), |_, content| { op(content.as_constructed()?) }) @@ -813,14 +920,13 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// reached, or if the closure fails to process the next value’s content /// fully, a malformed error is returned. An error is also returned if /// the closure returns one or if accessing the underlying source fails. - pub fn take_primitive(&mut self, op: F) -> Result - where F: FnOnce(Tag, &mut Primitive) -> Result { - match self.take_opt_primitive(op)? { - Some(res) => Ok(res), - None => { - xerr!(Err(Error::Malformed.into())) - } - } + pub fn take_primitive( + &mut self, op: F, + ) -> Result> + where + F: FnOnce(Tag, &mut Primitive) -> Result>, + { + self.mandatory(|cons| cons.take_opt_primitive(op)) } /// Processes an optional primitive value. @@ -835,10 +941,11 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// malformed error is returned. An error is also returned if /// the closure returns one or if accessing the underlying source fails. pub fn take_opt_primitive( - &mut self, - op: F - ) -> Result, S::Err> - where F: FnOnce(Tag, &mut Primitive) -> Result { + &mut self, op: F, + ) -> Result, DecodeError> + where + F: FnOnce(Tag, &mut Primitive) -> Result>, + { self.process_next_value(None, |tag, content| { op(tag, content.as_primitive()?) }) @@ -855,17 +962,10 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// the closure doesn’t advance over the complete content. If access to /// the underlying source fails, an error is returned, too. pub fn take_primitive_if( - &mut self, - expected: Tag, - op: F - ) -> Result - where F: FnOnce(&mut Primitive) -> Result { - match self.take_opt_primitive_if(expected, op)? { - Some(res) => Ok(res), - None => { - xerr!(Err(Error::Malformed.into())) - } - } + &mut self, expected: Tag, op: F, + ) -> Result> + where F: FnOnce(&mut Primitive) -> Result> { + self.mandatory(|cons| cons.take_opt_primitive_if(expected, op)) } /// Processes an optional primitive value of a given tag. @@ -880,11 +980,9 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// fully the method returns a malformed error. If access to the /// underlying source fails, it returns an appropriate error. pub fn take_opt_primitive_if( - &mut self, - expected: Tag, - op: F - ) -> Result, S::Err> - where F: FnOnce(&mut Primitive) -> Result { + &mut self, expected: Tag, op: F, + ) -> Result, DecodeError> + where F: FnOnce(&mut Primitive) -> Result> { self.process_next_value(Some(expected), |_, content| { op(content.as_primitive()?) }) @@ -902,13 +1000,16 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// If the closure returns an error, this error is returned. /// /// [`Captured`]: ../captures/struct.Captured.html - pub fn capture(&mut self, op: F) -> Result + pub fn capture( + &mut self, op: F, + ) -> Result> where F: FnOnce( &mut Constructed>> - ) -> Result<(), S::Err> + ) -> Result<(), DecodeError> { let limit = self.source.limit(); + let start = self.source.pos(); let mut source = LimitedSource::new(CaptureSource::new(self.source)); source.set_limit(limit); { @@ -918,7 +1019,9 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { op(&mut constructed)?; self.state = constructed.state; } - Ok(Captured::new(source.unwrap().into_bytes(), self.mode)) + Ok(Captured::new( + source.unwrap().into_bytes(), self.mode, start, + )) } /// Captures one value for later processing @@ -930,28 +1033,25 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// to the underlying source fails, an appropriate error is returned. /// /// [`Captured`]: ../captures/struct.Captured.html - pub fn capture_one(&mut self) -> Result { - self.capture(|cons| { - match cons.skip_one()? { - Some(()) => Ok(()), - None => { - xerr!(Err(Error::Malformed.into())) - } - } - }) + pub fn capture_one(&mut self) -> Result> { + self.capture(|cons| cons.mandatory(|cons| cons.skip_one())) } /// Captures all remaining content for later processing. /// /// The method takes all remaining values from this value’s content and /// returns their encoded form in a `Bytes` value. - pub fn capture_all(&mut self) -> Result { + pub fn capture_all( + &mut self + ) -> Result> { self.capture(|cons| cons.skip_all()) } /// Skips over content. - pub fn skip_opt(&mut self, mut op: F) -> Result, S::Err> - where F: FnMut(Tag, bool, usize) -> Result<(), S::Err> { + pub fn skip_opt( + &mut self, mut op: F, + ) -> Result, DecodeError> + where F: FnMut(Tag, bool, usize) -> Result<(), ContentError> { // If we already know we are at the end of the value, we can return. if self.is_exhausted() { return Ok(None) @@ -970,7 +1070,7 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { if !constructed { if tag == Tag::END_OF_VALUE { if length != Length::Definite(0) { - xerr!(return Err(Error::Malformed.into())) + return Err(self.content_err("non-empty end of value")) } // End-of-value: The top of the stack needs to be an @@ -989,36 +1089,51 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { return Ok(None) } else { - xerr!(return Err(Error::Malformed.into())) + return Err(self.content_err( + "invalid nested values" + )) } } - _ => xerr!(return Err(Error::Malformed.into())) + _ => { + return Err(self.content_err( + "invalid nested values" + )) + } } } else { // Primitive value. Check for the length to be definite, - // check that the caller likes it, then try to read over it. + // check that the caller likes it, then try to read over + // it. if let Length::Definite(len) = length { - op(tag, constructed, stack.len())?; - self.source.advance(len)?; + if let Err(err) = op(tag, constructed, stack.len()) { + return Err(self.content_err(err)); + } + self.source.advance(len); } else { - xerr!(return Err(Error::Malformed.into())); + return Err(self.content_err( + "primitive value with indefinite length" + )) } } } else if let Length::Definite(len) = length { - // Definite constructed value. First check if the caller likes - // it. Check that there is enough limit left for the value. If - // so, push the limit at the end of the value to the stack, - // update the limit to our length, and continue. - op(tag, constructed, stack.len())?; + // Definite constructed value. First check if the caller + // likes it. Check that there is enough limit left for the + // value. If so, push the limit at the end of the value to + // the stack, update the limit to our length, and continue. + if let Err(err) = op(tag, constructed, stack.len()) { + return Err(self.content_err(err)); + } stack.push(Some(match self.source.limit() { Some(limit) => { match limit.checked_sub(len) { Some(len) => Some(len), None => { - xerr!(return Err(Error::Malformed.into())); + return Err(self.content_err( + "invalid nested values" + )); } } } @@ -1029,7 +1144,9 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { else { // Indefinite constructed value. Simply push a `None` to the // stack, if the caller likes it. - op(tag, constructed, stack.len())?; + if let Err(err) = op(tag, constructed, stack.len()) { + return Err(self.content_err(err)); + } stack.push(None); continue; } @@ -1051,7 +1168,9 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { Some(None) => { // We need a End-of-value, so running out of // data is an error. - xerr!(return Err(Error::Malformed.into())); + return Err(self.content_err(" + missing futher values" + )) } None => unreachable!(), } @@ -1064,18 +1183,13 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { } } - pub fn skip(&mut self, op: F) -> Result<(), S::Err> - where F: FnMut(Tag, bool, usize) -> Result<(), S::Err> { - if self.skip_opt(op)? == None { - xerr!(Err(Error::Malformed.into())) - } - else { - Ok(()) - } + pub fn skip(&mut self, op: F) -> Result<(), DecodeError> + where F: FnMut(Tag, bool, usize) -> Result<(), ContentError> { + self.mandatory(|cons| cons.skip_opt(op)) } /// Skips over all remaining content. - pub fn skip_all(&mut self) -> Result<(), S::Err> { + pub fn skip_all(&mut self) -> Result<(), DecodeError> { while let Some(()) = self.skip_one()? { } Ok(()) } @@ -1084,7 +1198,7 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// /// If there is a next value, returns `Ok(Some(()))`, if the end of value /// has already been reached, returns `Ok(None)`. - pub fn skip_one(&mut self) -> Result, S::Err> { + pub fn skip_one(&mut self) -> Result, DecodeError> { if self.is_exhausted() { Ok(None) } @@ -1103,22 +1217,24 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// encoding. impl<'a, S: Source + 'a> Constructed<'a, S> { /// Processes and returns a mandatory boolean value. - pub fn take_bool(&mut self) -> Result { + pub fn take_bool(&mut self) -> Result> { self.take_primitive_if(Tag::BOOLEAN, |prim| prim.to_bool()) } /// Processes and returns an optional boolean value. - pub fn take_opt_bool(&mut self) -> Result, S::Err> { + pub fn take_opt_bool( + &mut self, + ) -> Result, DecodeError> { self.take_opt_primitive_if(Tag::BOOLEAN, |prim| prim.to_bool()) } /// Processes a mandatory NULL value. - pub fn take_null(&mut self) -> Result<(), S::Err> { + pub fn take_null(&mut self) -> Result<(), DecodeError> { self.take_primitive_if(Tag::NULL, |_| Ok(())).map(|_| ()) } /// Processes an optional NULL value. - pub fn take_opt_null(&mut self) -> Result<(), S::Err> { + pub fn take_opt_null(&mut self) -> Result<(), DecodeError> { self.take_opt_primitive_if(Tag::NULL, |_| Ok(())).map(|_| ()) } @@ -1126,7 +1242,7 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// /// If the integer value is less than 0 or greater than 255, a malformed /// error is returned. - pub fn take_u8(&mut self) -> Result { + pub fn take_u8(&mut self) -> Result> { self.take_primitive_if(Tag::INTEGER, |prim| prim.to_u8()) } @@ -1134,7 +1250,9 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// /// If the integer value is less than 0 or greater than 255, a malformed /// error is returned. - pub fn take_opt_u8(&mut self) -> Result, S::Err> { + pub fn take_opt_u8( + &mut self, + ) -> Result, DecodeError> { self.take_opt_primitive_if(Tag::INTEGER, |prim| prim.to_u8()) } @@ -1142,11 +1260,13 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// /// If the next value is an integer but of a different value, returns /// a malformed error. - pub fn skip_u8_if(&mut self, expected: u8) -> Result<(), S::Err> { + pub fn skip_u8_if( + &mut self, expected: u8, + ) -> Result<(), DecodeError> { self.take_primitive_if(Tag::INTEGER, |prim| { let got = prim.take_u8()?; if got != expected { - xerr!(Err(Error::Malformed.into())) + Err(prim.content_err(ExpectedIntValue(expected))) } else { Ok(()) @@ -1158,11 +1278,13 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// /// If the next value is an integer but of a different value, returns /// a malformed error. - pub fn skip_opt_u8_if(&mut self, expected: u8) -> Result<(), S::Err> { + pub fn skip_opt_u8_if( + &mut self, expected: u8, + ) -> Result<(), DecodeError> { self.take_opt_primitive_if(Tag::INTEGER, |prim| { let got = prim.take_u8()?; if got != expected { - xerr!(Err(Error::Malformed.into())) + Err(prim.content_err(ExpectedIntValue(expected))) } else { Ok(()) @@ -1172,17 +1294,19 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// Processes a mandatory INTEGER value of the `u16` range. /// - /// If the integer value is less than 0 or greater than 65535, a malformed - /// error is returned. - pub fn take_u16(&mut self) -> Result { + /// If the integer value is less than 0 or greater than 65535, a + /// malformed error is returned. + pub fn take_u16(&mut self) -> Result> { self.take_primitive_if(Tag::INTEGER, |prim| prim.to_u16()) } /// Processes an optional INTEGER value of the `u16` range. /// - /// If the integer value is less than 0 or greater than 65535, a malformed - /// error is returned. - pub fn take_opt_u16(&mut self) -> Result, S::Err> { + /// If the integer value is less than 0 or greater than 65535, a + /// malformed error is returned. + pub fn take_opt_u16( + &mut self, + ) -> Result, DecodeError> { self.take_opt_primitive_if(Tag::INTEGER, |prim| prim.to_u16()) } @@ -1190,7 +1314,7 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// /// If the integer value is less than 0 or greater than 2^32-1, a /// malformed error is returned. - pub fn take_u32(&mut self) -> Result { + pub fn take_u32(&mut self) -> Result> { self.take_primitive_if(Tag::INTEGER, |prim| prim.to_u32()) } @@ -1198,7 +1322,9 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// /// If the integer value is less than 0 or greater than 2^32-1, a /// malformed error is returned. - pub fn take_opt_u32(&mut self) -> Result, S::Err> { + pub fn take_opt_u32( + &mut self, + ) -> Result, DecodeError> { self.take_opt_primitive_if(Tag::INTEGER, |prim| prim.to_u32()) } @@ -1206,7 +1332,7 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// /// If the integer value is less than 0 or greater than 2^64-1, a /// malformed error is returned. - pub fn take_u64(&mut self) -> Result { + pub fn take_u64(&mut self) -> Result> { self.take_primitive_if(Tag::INTEGER, |prim| prim.to_u64()) } @@ -1214,42 +1340,50 @@ impl<'a, S: Source + 'a> Constructed<'a, S> { /// /// If the integer value is less than 0 or greater than 2^64-1, a /// malformed error is returned. - pub fn take_opt_u64(&mut self) -> Result, S::Err> { + pub fn take_opt_u64( + &mut self, + ) -> Result, DecodeError> { self.take_opt_primitive_if(Tag::INTEGER, |prim| prim.to_u64()) } /// Processes a mandatory SEQUENCE value. /// /// This is a shortcut for `self.take_constructed(Tag::SEQUENCE, op)`. - pub fn take_sequence(&mut self, op: F) -> Result - where F: FnOnce(&mut Constructed) -> Result { + pub fn take_sequence( + &mut self, op: F, + ) -> Result> + where F: FnOnce(&mut Constructed) -> Result> { self.take_constructed_if(Tag::SEQUENCE, op) } /// Processes an optional SEQUENCE value. /// - /// This is a shortcut for `self.take_opt_constructed(Tag::SEQUENCE, op)`. + /// This is a shortcut for + /// `self.take_opt_constructed(Tag::SEQUENCE, op)`. pub fn take_opt_sequence( - &mut self, - op: F - ) -> Result, S::Err> - where F: FnOnce(&mut Constructed) -> Result { + &mut self, op: F, + ) -> Result, DecodeError> + where F: FnOnce(&mut Constructed) -> Result> { self.take_opt_constructed_if(Tag::SEQUENCE, op) } /// Processes a mandatory SET value. /// /// This is a shortcut for `self.take_constructed(Tag::SET, op)`. - pub fn take_set(&mut self, op: F) -> Result - where F: FnOnce(&mut Constructed) -> Result { + pub fn take_set( + &mut self, op: F, + ) -> Result> + where F: FnOnce(&mut Constructed) -> Result> { self.take_constructed_if(Tag::SET, op) } /// Processes an optional SET value. /// /// This is a shortcut for `self.take_opt_constructed(Tag::SET, op)`. - pub fn take_opt_set(&mut self, op: F) -> Result, S::Err> - where F: FnOnce(&mut Constructed) -> Result { + pub fn take_opt_set( + &mut self, op: F + ) -> Result, DecodeError> + where F: FnOnce(&mut Constructed) -> Result> { self.take_opt_constructed_if(Tag::SET, op) } } @@ -1272,6 +1406,43 @@ enum State { /// Unbounded value: read as far as we get. Unbounded, } + + +//============ Error Types =================================================== + +/// A value with a certain tag was expected. +#[derive(Clone, Copy, Debug)] +struct ExpectedTag(Tag); + +impl fmt::Display for ExpectedTag { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "expected tag {}", self.0) + } +} + +impl From for ContentError { + fn from(err: ExpectedTag) -> Self { + ContentError::from_boxed(Box::new(err)) + } +} + + +/// An integer with a certain value was expected. +#[derive(Clone, Copy, Debug)] +struct ExpectedIntValue(T); + +impl fmt::Display for ExpectedIntValue { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "expected integer value {}", self.0) + } +} + +impl From> for ContentError +where T: fmt::Display + Send + Sync + 'static { + fn from(err: ExpectedIntValue) -> Self { + ContentError::from_boxed(Box::new(err)) + } +} //============ Tests ========================================================= @@ -1284,7 +1455,7 @@ mod test { fn constructed_skip() { // Two primitives. Constructed::decode( - b"\x02\x01\x00\x02\x01\x00".as_ref(), Mode::Ber, |cons| { + b"\x02\x01\x00\x02\x01\x00".into_source(), Mode::Ber, |cons| { cons.skip(|_, _, _| Ok(())).unwrap(); cons.skip(|_, _, _| Ok(())).unwrap(); Ok(()) @@ -1293,7 +1464,7 @@ mod test { // One definite constructed with two primitives, then one primitive Constructed::decode( - b"\x30\x06\x02\x01\x00\x02\x01\x00\x02\x01\x00".as_ref(), + b"\x30\x06\x02\x01\x00\x02\x01\x00\x02\x01\x00".into_source(), Mode::Ber, |cons| { cons.skip(|_, _, _| Ok(())).unwrap(); @@ -1308,7 +1479,7 @@ mod test { b"\x30\x08\ \x30\x06\ \x02\x01\x00\x02\x01\x00\ - \x02\x01\x00".as_ref(), + \x02\x01\x00".into_source(), Mode::Ber, |cons| { cons.skip(|_, _, _| Ok(())).unwrap(); @@ -1322,7 +1493,7 @@ mod test { b"\x30\x0A\ \x30\x80\ \x02\x01\x00\x02\x01\x00\ - \0\0".as_ref(), + \0\0".into_source(), Mode::Ber, |cons| { cons.skip(|_, _, _| Ok(())).unwrap(); @@ -1332,3 +1503,4 @@ mod test { } } + diff --git a/src/decode/error.rs b/src/decode/error.rs index 6a470ff..4cc1384 100644 --- a/src/decode/error.rs +++ b/src/decode/error.rs @@ -3,26 +3,161 @@ //! This is a private module. Its public content is being re-exported by the //! parent module. -use std::fmt; +use std::{error, fmt}; +use std::convert::Infallible; +use super::source::Pos; -//------------ Error --------------------------------------------------------- +//------------ ContentError -------------------------------------------------- + +/// An error happened while interpreting BER-encoded data. +pub struct ContentError { + /// The error message. + message: ErrorMessage, +} + +impl ContentError { + /// Creates a content error from a static str. + pub fn from_static(msg: &'static str) -> Self { + ContentError { + message: ErrorMessage::Static(msg) + } + } + + /// Creates a content error from a boxed trait object. + pub fn from_boxed( + msg: Box + ) -> Self { + ContentError { + message: ErrorMessage::Boxed(msg) + } + } +} + +impl From<&'static str> for ContentError { + fn from(msg: &'static str) -> Self { + Self::from_static(msg) + } +} + +impl From for ContentError { + fn from(msg: String) -> Self { + Self::from_boxed(Box::new(msg)) + } +} + +impl From> for ContentError { + fn from(err: DecodeError) -> Self { + match err.inner { + DecodeErrorKind::Source(_) => unreachable!(), + DecodeErrorKind::Content { error, .. } => error, + } + } +} + + +impl fmt::Display for ContentError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.message.fmt(f) + } +} + +impl fmt::Debug for ContentError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_tuple("ContentError") + .field(&self.message) + .finish() + } +} + + +//------------ ErrorMessage -------------------------------------------------- + +/// The actual error message as a hidden enum. +enum ErrorMessage { + /// The error message is a static str. + Static(&'static str), + + /// The error message is a boxed trait object. + Boxed(Box), +} + +impl fmt::Display for ErrorMessage { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ErrorMessage::Static(msg) => f.write_str(msg), + ErrorMessage::Boxed(ref msg) => msg.fmt(f), + } + } +} + +impl fmt::Debug for ErrorMessage { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ErrorMessage::Static(msg) => f.write_str(msg), + ErrorMessage::Boxed(ref msg) => msg.fmt(f), + } + } +} + + +//------------ DecodeError --------------------------------------------------- /// An error happened while decoding data. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub enum Error { - /// The data didn’t conform to the expected structure. - Malformed, +/// +/// This can either be a source error – which the type is generic over – or a +/// content error annotated with the position in the source where it happened. +#[derive(Debug)] +pub struct DecodeError { + inner: DecodeErrorKind, +} + +#[derive(Debug)] +enum DecodeErrorKind { + Source(S), + Content { + error: ContentError, + pos: Pos, + } +} - /// An encoding used by the data is not yet implemented by the crate. - Unimplemented, +impl DecodeError { + /// Creates a decode error from a content error and a position. + pub fn content(error: impl Into, pos: Pos) -> Self { + DecodeError { + inner: DecodeErrorKind::Content { error: error.into(), pos }, + } + } } -impl fmt::Display for Error { +impl DecodeError { + /// Converts a decode error from an infallible source into another error. + pub fn convert(self) -> DecodeError { + match self.inner { + DecodeErrorKind::Source(_) => unreachable!(), + DecodeErrorKind::Content { error, pos } => { + DecodeError::content(error, pos) + } + } + } +} + +impl From for DecodeError { + fn from(err: S) -> Self { + DecodeError { inner: DecodeErrorKind::Source(err) } + } +} + +impl fmt::Display for DecodeError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Error::Malformed => write!(f, "malformed data"), - Error::Unimplemented => write!(f, "format not implemented"), + match self.inner { + DecodeErrorKind::Source(ref err) => err.fmt(f), + DecodeErrorKind::Content { ref error, pos } => { + write!(f, "{} (at position {})", error, pos) + } } } } + +impl error::Error for DecodeError { } + diff --git a/src/decode/mod.rs b/src/decode/mod.rs index 671fab8..e6d87c8 100644 --- a/src/decode/mod.rs +++ b/src/decode/mod.rs @@ -4,9 +4,9 @@ //! //! The basic idea is that for each type a function exists that knows how //! to decode one value of that type. For constructed types, this function -//! in turn relies on similar functions provided for its consituent types. +//! in turn relies on similar functions provided for its constituent types. //! For a detailed introduction to how to write these functions, please -//! refer to the [decode section of the guide]. +//! refer to the [decode section of the guide][crate::guide::decode]. //! //! The two most important types of this module are [`Primitive`] and //! [`Constructed`], representing the content octets of a value in primitive @@ -20,19 +20,30 @@ //! The enum [`Content`] is used for cases where a value can be either //! primitive or constructed such as most string types. //! -//! Decoding is jumpstarted by providing a data source to parse data from. -//! This is any value that implements the [`Source`] trait. +//! The data for decoding is provided by any type that implements the +//! [`Source`] trait – or can be converted into such a type via the +//! [`IntoSource`] trait. Implementations for both `bytes::Bytes` and +//! `&[u8]` are available. //! -//! [decode section of the guide]: ../guide/decode/index.html -//! [`Primitive`]: struct.Primitive.html -//! [`Constructed`]: struct.Constructed.html -//! [`Content`]: enum.Content.html -//! [`Source`]: trait.Source.html +//! During decoding, errors can happen. There are two kinds of errors: for +//! one, the source can fail to gather more data, e.g., when reading from a +//! file fails. Such errors are called _source errors._ Their type is +//! provided by the source. +//! +//! Second, data that cannot be decoded according to the syntax is said to +//! result in a _content error._ The [`ContentError`] type is used for such +//! errors. +//! +//! When decoding data from a source, both errors can happen. The type +//! `DecodeError` provides a way to store either of them and is the error +//! type you will likely encounter the most. pub use self::content::{Content, Constructed, Primitive}; -pub use self::error::Error; -pub use self::error::Error::{Malformed, Unimplemented}; -pub use self::source::{CaptureSource, LimitedSource, Source}; +pub use self::error::{ContentError, DecodeError}; +pub use self::source::{ + BytesSource, CaptureSource, IntoSource, Pos, LimitedSource, SliceSource, + Source +}; mod content; mod error; diff --git a/src/decode/source.rs b/src/decode/source.rs index 4caefca..cb4c0c1 100644 --- a/src/decode/source.rs +++ b/src/decode/source.rs @@ -1,12 +1,13 @@ //! The source for decoding data. //! -//! This is an internal module. It’s public types are re-exported by the +//! This is an internal module. Its public types are re-exported by the //! parent. -use std::mem; +use std::{error, fmt, mem, ops}; use std::cmp::min; +use std::convert::Infallible; use bytes::Bytes; -use super::error::Error; +use super::error::{ContentError, DecodeError}; //------------ Source -------------------------------------------------------- @@ -17,46 +18,33 @@ use super::error::Error; /// decoders. /// /// A source can only progress forward over time. It provides the ability to -/// access the next few bytes as a slice, advance forward, or advance forward -/// returning a Bytes value of the data it advanced over. +/// access the next few bytes as a slice or a [`Bytes`] value, and advance +/// forward. /// /// _Please note:_ This trait may change as we gain more experience with /// decoding in different circumstances. If you implement it for your own /// types, we would appreciate feedback what worked well and what didn’t. pub trait Source { - /// The error produced by the source. - /// - /// The type used here needs to wrap [`ber::decode::Error`] and extends it - /// by whatever happens if acquiring additional data fails. If `Source` - /// is implemented for types where this acqusition cannot fail, - /// `ber::decode::Error` should be used here. - /// - /// [`ber::decode::Error`]: enum.Error.html - type Err: From; + /// The error produced when the source failed to read more data. + type Error: error::Error; + + /// Returns the current logical postion within the sequence of data. + fn pos(&self) -> Pos; /// Request at least `len` bytes to be available. /// /// The method returns the number of bytes that are actually available. /// This may only be smaller than `len` if the source ends with less - /// bytes available. + /// bytes available. It may be larger than `len` but less than the total + /// number of bytes still left in the source. + /// + /// The method can be called multiple times without advancing in between. + /// If in this case `len` is larger than when last called, the source + /// should try and make the additional data available. /// /// The method should only return an error if the source somehow fails /// to get more data such as an IO error or reset connection. - fn request(&mut self, len: usize) -> Result; - - /// Advance the source by `len` bytes. - /// - /// The method advances the start of the view provided by the source by - /// `len` bytes. Advancing beyond the end of a source is an error. - /// Implementations should return their equivalient of - /// [`Error::Malformed`]. - /// - /// The value of `len` may be larger than the last length previously - /// request via [`request`]. - /// - /// [`Error::Malformed`]: enum.Error.html#variant.Malformed - /// [`request`]: #tymethod.request - fn advance(&mut self, len: usize) -> Result<(), Self::Err>; + fn request(&mut self, len: usize) -> Result; /// Returns a bytes slice with the available data. /// @@ -72,30 +60,58 @@ pub trait Source { /// The method returns a [`Bytes`] value of the range `start..end` from /// the beginning of the current view of the source. Both indexes must /// not be greater than the value returned by the last successful call - /// to [`request`]. + /// to [`request`][Self::request]. /// /// # Panics /// - /// The method panics if `start` or `end` are larger than the last - /// successful call to [`request`]. - /// - /// [`Bytes`]: ../../bytes/struct.Bytes.html - /// [`request`]: #tymethod.request + /// The method panics if `start` or `end` are larger than the result of + /// the last successful call to [`request`][Self::request]. fn bytes(&self, start: usize, end: usize) -> Bytes; + /// Advance the source by `len` bytes. + /// + /// The method advances the start of the view provided by the source by + /// `len` bytes. This value must not be greater than the value returned + /// by the last successful call to [`request`][Self::request]. + /// + /// # Panics + /// + /// The method panics if `len` is larger than the result of the last + /// successful call to [`request`][Self::request]. + fn advance(&mut self, len: usize); + + /// Skip over the next `len` bytes. + /// + /// The method attempts to advance the source by `len` bytes or by + /// however many bytes are still available if this number is smaller, + /// without making these bytes available. + /// + /// Returns the number of bytes skipped over. This value may only differ + /// from len if the remainder of the source contains less than `len` + /// bytes. + /// + /// The default implementation uses `request` and `advance`. However, for + /// some sources it may be significantly cheaper to provide a specialised + /// implementation. + fn skip(&mut self, len: usize) -> Result { + let res = min(self.request(len)?, len); + self.advance(res); + Ok(res) + } + //--- Advanced access /// Takes a single octet from the source. /// /// If there aren’t any more octets available from the source, returns - /// a malformed error. - fn take_u8(&mut self) -> Result { + /// a content error. + fn take_u8(&mut self) -> Result> { if self.request(1)? < 1 { - xerr!(return Err(Error::Malformed.into())) + return Err(self.content_err("unexpected end of data")) } let res = self.slice()[0]; - self.advance(1)?; + self.advance(1); Ok(res) } @@ -103,85 +119,258 @@ pub trait Source { /// /// If there aren’t any more octets available from the source, returns /// `Ok(None)`. - fn take_opt_u8(&mut self) -> Result, Self::Err> { + fn take_opt_u8(&mut self) -> Result, Self::Error> { if self.request(1)? < 1 { return Ok(None) } let res = self.slice()[0]; - self.advance(1)?; + self.advance(1); Ok(Some(res)) } + + /// Returns a content error at the current position of the source. + fn content_err( + &self, err: impl Into + ) -> DecodeError { + DecodeError::content(err.into(), self.pos()) + } } -impl Source for Bytes { - type Err = Error; +impl<'a, T: Source> Source for &'a mut T { + type Error = T::Error; - fn request(&mut self, _len: usize) -> Result { - Ok(self.len()) + fn request(&mut self, len: usize) -> Result { + Source::request(*self, len) } - - fn advance(&mut self, len: usize) -> Result<(), Self::Err> { - if len > self.len() { - Err(Error::Malformed) - } - else { - bytes::Buf::advance(self, len); - Ok(()) - } + + fn advance(&mut self, len: usize) { + Source::advance(*self, len) } fn slice(&self) -> &[u8] { - self.as_ref() + Source::slice(*self) } fn bytes(&self, start: usize, end: usize) -> Bytes { - self.slice(start..end) + Source::bytes(*self, start, end) + } + + fn pos(&self) -> Pos { + Source::pos(*self) } } -impl<'a> Source for &'a [u8] { - type Err = Error; - fn request(&mut self, _len: usize) -> Result { - Ok(self.len()) +//------------ IntoSource ---------------------------------------------------- + +/// A type that can be converted into a source. +pub trait IntoSource { + type Source: Source; + + fn into_source(self) -> Self::Source; +} + +impl IntoSource for T { + type Source = Self; + + fn into_source(self) -> Self::Source { + self } +} - fn advance(&mut self, len: usize) -> Result<(), Self::Err> { - if len > self.len() { - Err(Error::Malformed) - } - else { - *self = &self[len..]; - Ok(()) - } + +//------------ Pos ----------------------------------------------------------- + +/// The logical position within a source. +/// +/// Values of this type can only be used for diagnostics. They can not be used +/// to determine how far a source has been advanced since it was created. This +/// is why we used a newtype. +#[derive(Clone, Copy, Debug, Default)] +pub struct Pos(usize); + +impl From for Pos { + fn from(pos: usize) -> Pos { + Pos(pos) + } +} + +impl ops::Add for Pos { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + Pos(self.0 + rhs.0) + } +} + +impl fmt::Display for Pos { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + + +//------------ BytesSource --------------------------------------------------- + +/// A source for a bytes value. +#[derive(Clone, Debug)] +pub struct BytesSource { + /// The bytes. + data: Bytes, + + /// The current read position in the data. + pos: usize, + + /// The offset for the reported position. + /// + /// This is the value reported by `Source::pos` when `self.pos` is zero. + offset: Pos, +} + +impl BytesSource { + /// Creates a new bytes source from a bytes values. + pub fn new(data: Bytes) -> Self { + BytesSource { data, pos: 0, offset: 0.into() } + } + + /// Creates a new bytes source with an explicit offset. + /// + /// When this function is used to create a bytes source, `Source::pos` + /// will report a value increates by `offset`. + pub fn with_offset(data: Bytes, offset: Pos) -> Self { + BytesSource { data, pos: 0, offset } + } + + /// Returns the remaining length of data. + pub fn len(&self) -> usize { + self.data.len() + } + + /// Returns whether there is any data remaining. + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// Splits the first `len` bytes off the source and returns them. + /// + /// # Panics + /// + /// This method panics of `len` is larger than `self.len()`. + pub fn split_to(&mut self, len: usize) -> Bytes { + let res = self.data.split_to(len); + self.pos += len; + res + } + + /// Converts the source into the remaining bytes. + pub fn into_bytes(self) -> Bytes { + self.data + } +} + +impl Source for BytesSource { + type Error = Infallible; + + fn pos(&self) -> Pos { + self.offset + self.pos.into() + } + + fn request(&mut self, _len: usize) -> Result { + Ok(self.data.len()) } fn slice(&self) -> &[u8] { - self + self.data.as_ref() } fn bytes(&self, start: usize, end: usize) -> Bytes { - Bytes::copy_from_slice(&self[start..end]) + self.data.slice(start..end) + } + + fn advance(&mut self, len: usize) { + assert!(len <= self.data.len()); + bytes::Buf::advance(&mut self.data, len); + self.pos += len; } } -impl<'a, T: Source> Source for &'a mut T { - type Err = T::Err; +impl IntoSource for Bytes { + type Source = BytesSource; - fn request(&mut self, len: usize) -> Result { - Source::request(*self, len) + fn into_source(self) -> Self::Source { + BytesSource::new(self) } - - fn advance(&mut self, len: usize) -> Result<(), Self::Err> { - Source::advance(*self, len) +} + + +//------------ SliceSource --------------------------------------------------- + +#[derive(Clone, Copy, Debug)] +pub struct SliceSource<'a> { + data: &'a [u8], + pos: usize +} + +impl<'a> SliceSource<'a> { + /// Creates a new bytes source from a slice. + pub fn new(data: &'a [u8]) -> Self { + SliceSource { data, pos: 0 } + } + + /// Returns the remaining length of data. + pub fn len(&self) -> usize { + self.data.len() + } + + /// Returns whether there is any data remaining. + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// Splits the first `len` bytes off the source and returns them. + /// + /// # Panics + /// + /// This method panics of `len` is larger than `self.len()`. + pub fn split_to(&mut self, len: usize) -> &'a [u8] { + let (left, right) = self.data.split_at(len); + self.data = right; + self.pos += len; + left + } +} + +impl<'a> Source for SliceSource<'a> { + type Error = Infallible; + + fn pos(&self) -> Pos { + self.pos.into() + } + + fn request(&mut self, _len: usize) -> Result { + Ok(self.data.len()) + } + + fn advance(&mut self, len: usize) { + assert!(len <= self.data.len()); + self.data = &self.data[len..]; + self.pos += len; } fn slice(&self) -> &[u8] { - Source::slice(*self) + self.data } fn bytes(&self, start: usize, end: usize) -> Bytes { - Source::bytes(*self, start, end) + Bytes::copy_from_slice(&self.data[start..end]) + } +} + +impl<'a> IntoSource for &'a [u8] { + type Source = SliceSource<'a>; + + fn into_source(self) -> Self::Source { + SliceSource::new(self) } } @@ -273,9 +462,13 @@ impl LimitedSource { /// If there currently is no limit, the method will panic. Otherwise it /// will simply advance to the end of the limit which may be something /// the underlying source doesn’t like and thus produce an error. - pub fn skip_all(&mut self) -> Result<(), S::Err> { + pub fn skip_all(&mut self) -> Result<(), DecodeError> { let limit = self.limit.unwrap(); - self.advance(limit) + if self.request(limit)? < limit { + return Err(self.content_err("unexpected end of data")) + } + self.advance(limit); + Ok(()) } /// Returns a bytes value containing all octets until the current limit. @@ -283,14 +476,17 @@ impl LimitedSource { /// If there currently is no limit, the method will panic. Otherwise it /// tries to acquire a bytes value for the octets from the current /// position to the end of the limit and advance to the end of the limit. - /// This may result in an error by the underlying source. - pub fn take_all(&mut self) -> Result { + /// + /// This will result in a source error if the underlying source returns + /// an error. It will result in a content error if the underlying source + /// ends before the limit is reached. + pub fn take_all(&mut self) -> Result> { let limit = self.limit.unwrap(); if self.request(limit)? < limit { - return Err(Error::Malformed.into()) + return Err(self.content_err("unexpected end of data")) } let res = self.bytes(0, limit); - self.advance(limit)?; + self.advance(limit); Ok(res) } @@ -303,18 +499,19 @@ impl LimitedSource { /// If there is no limit set, the method will try to access one single /// octet and return a malformed error if that is actually possible, i.e., /// if there are octets left in the underlying source. - pub fn exhausted(&mut self) -> Result<(), S::Err> { + /// + /// Any source errors are passed through. If there the data is not + /// exhausted as described above, a content error is created. + pub fn exhausted(&mut self) -> Result<(), DecodeError> { match self.limit { Some(0) => Ok(()), - Some(_limit) => { - xerr!(Err(Error::Malformed.into())) - } + Some(_limit) => Err(self.content_err("trailing data")), None => { if self.source.request(1)? == 0 { Ok(()) } else { - xerr!(Err(Error::Malformed.into())) + Err(self.content_err("trailing data")) } } } @@ -322,9 +519,13 @@ impl LimitedSource { } impl Source for LimitedSource { - type Err = S::Err; + type Error = S::Error; - fn request(&mut self, len: usize) -> Result { + fn pos(&self) -> Pos { + self.source.pos() + } + + fn request(&mut self, len: usize) -> Result { if let Some(limit) = self.limit { Ok(min(limit, self.source.request(min(limit, len))?)) } @@ -333,11 +534,12 @@ impl Source for LimitedSource { } } - fn advance(&mut self, len: usize) -> Result<(), Self::Err> { + fn advance(&mut self, len: usize) { if let Some(limit) = self.limit { - if len > limit { - xerr!(return Err(Error::Malformed.into())) - } + assert!( + len <= limit, + "advanced past end of limit" + ); self.limit = Some(limit - len); } self.source.advance(len) @@ -376,7 +578,13 @@ impl Source for LimitedSource { /// /// [`Constructed::capture`]: struct.Constructed.html#method.capture pub struct CaptureSource<'a, S: 'a> { + /// The wrapped real source. source: &'a mut S, + + /// The number of bytes the source has promised to have for us. + len: usize, + + /// The position in the source our view starts at. pos: usize, } @@ -385,7 +593,8 @@ impl<'a, S: Source> CaptureSource<'a, S> { pub fn new(source: &'a mut S) -> Self { CaptureSource { source, - pos: 0 + len: 0, + pos: 0, } } @@ -400,26 +609,20 @@ impl<'a, S: Source> CaptureSource<'a, S> { /// /// Advances the underlying source to the end of the captured bytes. pub fn skip(self) { - assert!( - !self.source.advance(self.pos).is_err(), - "failed to advance capture source" - ); + self.source.advance(self.pos) } } impl<'a, S: Source + 'a> Source for CaptureSource<'a, S> { - type Err = S::Err; + type Error = S::Error; - fn request(&mut self, len: usize) -> Result { - self.source.request(self.pos + len).map(|res| res - self.pos) + fn pos(&self) -> Pos { + self.source.pos() + self.pos.into() } - fn advance(&mut self, len: usize) -> Result<(), Self::Err> { - if self.request(len)? < len { - return Err(Error::Malformed.into()) - } - self.pos += len; - Ok(()) + fn request(&mut self, len: usize) -> Result { + self.len = self.source.request(self.pos + len)?; + Ok(self.len - self.pos) } fn slice(&self) -> &[u8] { @@ -427,7 +630,25 @@ impl<'a, S: Source + 'a> Source for CaptureSource<'a, S> { } fn bytes(&self, start: usize, end: usize) -> Bytes { - self.source.bytes(start + self.pos, end + self.pos) + let start = start + self.pos; + let end = end + self.pos; + assert!( + self.len >= start, + "start past the end of data" + ); + assert!( + self.len >= end, + "end past the end of data" + ); + self.source.bytes(start, end) + } + + fn advance(&mut self, len: usize) { + assert!( + self.len >= self.pos + len, + "advanced past the end of data" + ); + self.pos += len; } } @@ -440,77 +661,70 @@ mod test { #[test] fn take_u8() { - let mut source = &b"123"[..]; - assert_eq!(source.take_u8(), Ok(b'1')); - assert_eq!(source.take_u8(), Ok(b'2')); - assert_eq!(source.take_u8(), Ok(b'3')); - assert_eq!(source.take_u8(), Err(Error::Malformed)); + let mut source = b"123".into_source(); + assert_eq!(source.take_u8().unwrap(), b'1'); + assert_eq!(source.take_u8().unwrap(), b'2'); + assert_eq!(source.take_u8().unwrap(), b'3'); + assert!(source.take_u8().is_err()) } #[test] fn take_opt_u8() { - let mut source = &b"123"[..]; - assert_eq!(source.take_opt_u8(), Ok(Some(b'1'))); - assert_eq!(source.take_opt_u8(), Ok(Some(b'2'))); - assert_eq!(source.take_opt_u8(), Ok(Some(b'3'))); - assert_eq!(source.take_opt_u8(), Ok(None)); + let mut source = b"123".into_source(); + assert_eq!(source.take_opt_u8().unwrap(), Some(b'1')); + assert_eq!(source.take_opt_u8().unwrap(), Some(b'2')); + assert_eq!(source.take_opt_u8().unwrap(), Some(b'3')); + assert_eq!(source.take_opt_u8().unwrap(), None); } #[test] fn bytes_impl() { - let mut bytes = Bytes::from_static(b"1234567890"); + let mut bytes = Bytes::from_static(b"1234567890").into_source(); assert!(bytes.request(4).unwrap() >= 4); assert!(&Source::slice(&bytes)[..4] == b"1234"); assert_eq!(bytes.bytes(2, 4), Bytes::from_static(b"34")); - Source::advance(&mut bytes, 4).unwrap(); + Source::advance(&mut bytes, 4); assert!(bytes.request(4).unwrap() >= 4); assert!(&Source::slice(&bytes)[..4] == b"5678"); - Source::advance(&mut bytes, 4).unwrap(); + Source::advance(&mut bytes, 4); assert_eq!(bytes.request(4).unwrap(), 2); - assert!(&Source::slice(&bytes)[..] == b"90"); - assert_eq!( - Source::advance(&mut bytes, 4).unwrap_err(), - Error::Malformed - ); + assert!(&Source::slice(&bytes) == b"90"); + bytes.advance(2); + assert_eq!(bytes.request(4).unwrap(), 0); } #[test] fn slice_impl() { - let mut bytes = &b"1234567890"[..]; + let mut bytes = b"1234567890".into_source(); assert!(bytes.request(4).unwrap() >= 4); - assert!(&Source::slice(&bytes)[..4] == b"1234"); + assert!(&bytes.slice()[..4] == b"1234"); assert_eq!(bytes.bytes(2, 4), Bytes::from_static(b"34")); - Source::advance(&mut bytes, 4).unwrap(); + bytes.advance(4); assert!(bytes.request(4).unwrap() >= 4); - assert!(&Source::slice(&bytes)[..4] == b"5678"); - Source::advance(&mut bytes, 4).unwrap(); + assert!(&bytes.slice()[..4] == b"5678"); + bytes.advance(4); assert_eq!(bytes.request(4).unwrap(), 2); - assert!(&Source::slice(&bytes)[..] == b"90"); - assert_eq!( - Source::advance(&mut bytes, 4).unwrap_err(), - Error::Malformed - ); + assert!(&bytes.slice() == b"90"); + bytes.advance(2); + assert_eq!(bytes.request(4).unwrap(), 0); } #[test] fn limited_source() { - let mut the_source = LimitedSource::new(&b"12345678"[..]); + let mut the_source = LimitedSource::new( + b"12345678".into_source() + ); the_source.set_limit(Some(4)); let mut source = the_source.clone(); - assert_eq!(source.exhausted(), Err(Error::Malformed)); + assert!(source.exhausted().is_err()); assert_eq!(source.request(6).unwrap(), 4); - source.advance(2).unwrap(); - assert_eq!(source.exhausted(), Err(Error::Malformed)); + source.advance(2); + assert!(source.exhausted().is_err()); assert_eq!(source.request(6).unwrap(), 2); - source.advance(2).unwrap(); + source.advance(2); source.exhausted().unwrap(); assert_eq!(source.request(6).unwrap(), 0); - source.advance(0).unwrap(); - assert_eq!(source.advance(2).unwrap_err(), Error::Malformed); - - let mut source = the_source.clone(); - assert_eq!(source.advance(5).unwrap_err(), Error::Malformed); let mut source = the_source.clone(); source.skip_all().unwrap(); @@ -524,30 +738,46 @@ mod test { assert_eq!(source.slice(), b"5678"); } + #[test] + #[should_panic] + fn limited_source_far_advance() { + let mut source = LimitedSource::new( + b"12345678".into_source() + ); + source.set_limit(Some(4)); + assert_eq!(source.request(6).unwrap(), 4); + source.advance(4); + assert_eq!(source.request(6).unwrap(), 0); + source.advance(6); // panics + } + #[test] #[should_panic] fn limit_further() { - let mut source = LimitedSource::new(&b"12345"); + let mut source = LimitedSource::new(b"12345".into_source()); source.set_limit(Some(4)); - source.limit_further(Some(5)); + source.limit_further(Some(5)); // panics } #[test] fn capture_source() { - let mut source = &b"1234567890"[..]; + let mut source = b"1234567890".into_source(); { let mut capture = CaptureSource::new(&mut source); - capture.advance(4).unwrap(); + assert_eq!(capture.request(4).unwrap(), 10); + capture.advance(4); assert_eq!(capture.into_bytes(), Bytes::from_static(b"1234")); } - assert_eq!(source, b"567890"); + assert_eq!(source.data, b"567890"); - let mut source = &b"1234567890"[..]; + let mut source = b"1234567890".into_source(); { let mut capture = CaptureSource::new(&mut source); - capture.advance(4).unwrap(); + assert_eq!(capture.request(4).unwrap(), 10); + capture.advance(4); capture.skip(); } - assert_eq!(source, b"567890"); + assert_eq!(source.data, b"567890"); } } + diff --git a/src/encode/values.rs b/src/encode/values.rs index 3218764..e541525 100644 --- a/src/encode/values.rs +++ b/src/encode/values.rs @@ -48,7 +48,7 @@ pub trait Values { fn to_captured(&self, mode: Mode) -> Captured { let mut target = Vec::new(); self.write_encoded(mode, &mut target).unwrap(); - Captured::new(target.into(), mode) + Captured::new(target.into(), mode, Default::default()) } } diff --git a/src/guide/decode.rs b/src/guide/decode.rs index bd4ad5d..77134f7 100644 --- a/src/guide/decode.rs +++ b/src/guide/decode.rs @@ -19,7 +19,7 @@ //! take them as function arguments such as closures. //! //! An example will make this concept more clear. Let’s say we have the -//! following ASN.1 specifiction: +//! following ASN.1 specification: //! //! ```text //! EncapsulatedContentInfo ::= SEQUENCE { @@ -49,16 +49,17 @@ //! # use bcder::{Oid, OctetString}; //! use bcder::Tag; //! use bcder::decode; +//! use bcder::decode::DecodeError; //! //! # pub struct EncapsulatedContentInfo { //! # content_type: Oid, //! # content: Option, //! # } -//! # +//! # //! impl EncapsulatedContentInfo { //! pub fn take_from( //! cons: &mut decode::Constructed -//! ) -> Result { +//! ) -> Result> { //! cons.take_sequence(|cons| { //! Ok(EncapsulatedContentInfo { //! content_type: Oid::take_from(cons)?, @@ -89,16 +90,17 @@ //! # use bcder::{Oid, OctetString}; //! use bcder::Tag; //! use bcder::decode; +//! use bcder::decode::DecodeError; //! //! # pub struct EncapsulatedContentInfo { //! # content_type: Oid, //! # content: Option, //! # } -//! # +//! # //! impl EncapsulatedContentInfo { //! pub fn from_constructed( //! cons: &mut decode::Constructed -//! ) -> Result { +//! ) -> Result> { //! Ok(EncapsulatedContentInfo { //! content_type: Oid::take_from(cons)?, //! content: cons.take_opt_constructed_if(Tag::ctx(0), |cons| { @@ -110,4 +112,3 @@ //! ``` //! //! _TODO: Elaborate._ - diff --git a/src/int.rs b/src/int.rs index 160f12e..c695e02 100644 --- a/src/int.rs +++ b/src/int.rs @@ -13,11 +13,11 @@ //! [`Integer`]: struct.Integer.html //! [`Unsigned`]: struct.Unsigned.html -use std::{cmp, fmt, hash, io, mem}; +use std::{cmp, error, fmt, hash, io, mem}; use std::convert::TryFrom; use bytes::Bytes; use crate::decode; -use crate::decode::Source; +use crate::decode::{DecodeError, Source}; use crate::encode::PrimitiveContent; use crate::mode::Mode; use crate::tag::Tag; @@ -35,7 +35,7 @@ macro_rules! slice_to_builtin { ( signed, $slice:expr, $type:ident, $err:expr) => {{ const LEN: usize = mem::size_of::<$type>(); if $slice.len() > LEN { - Err($err) + $err } else { // Start with all zeros if positive or all 0xFF if negative @@ -55,7 +55,7 @@ macro_rules! slice_to_builtin { // out if the sign bit is set. const LEN: usize = mem::size_of::<$type>(); if $slice[0] & 0x80 != 0 { - Err($err) + $err } else { let val = if $slice[0] == 0 { &$slice[1..] } @@ -64,7 +64,7 @@ macro_rules! slice_to_builtin { Ok(0) } else if val.len() > LEN { - Err($err) + $err } else { let mut res = [0; LEN]; @@ -81,7 +81,8 @@ macro_rules! decode_builtin { let res = { let slice = $prim.slice_all()?; slice_to_builtin!( - $flavor, slice, $type, decode::Malformed.into() + $flavor, slice, $type, + Err($prim.content_err("invalid integer")) )? }; $prim.skip_all()?; @@ -108,7 +109,7 @@ macro_rules! builtin_from { fn try_from(val: &'a $from) -> Result<$to, Self::Error> { let val = val.as_slice(); - slice_to_builtin!($flavor, val, $to, OverflowError(())) + slice_to_builtin!($flavor, val, $to, Err(OverflowError(()))) } } @@ -161,24 +162,24 @@ impl Integer { /// a correctly encoded signed integer. pub fn take_from( cons: &mut decode::Constructed - ) -> Result { + ) -> Result> { cons.take_primitive_if(Tag::INTEGER, Self::from_primitive) } /// Constructs a signed integer from the content of a primitive value. pub fn from_primitive( prim: &mut decode::Primitive - ) -> Result { + ) -> Result> { let res = prim.take_all()?; - match (res.get(0), res.get(1).map(|x| x & 0x80 != 0)) { + match (res.first(), res.get(1).map(|x| x & 0x80 != 0)) { (Some(0), Some(false)) => { - xerr!(return Err(decode::Error::Malformed.into())) + return Err(prim.content_err("invalid integer")) } (Some(0xFF), Some(true)) => { - xerr!(return Err(decode::Error::Malformed.into())) + return Err(prim.content_err("invalid integer")) } (None, _) => { - xerr!(return Err(decode::Error::Malformed.into())) + return Err(prim.content_err("invalid integer")) } _ => { } } @@ -188,7 +189,7 @@ impl Integer { /// Constructs an `i8` from the content of a primitive value. pub fn i8_from_primitive( prim: &mut decode::Primitive - ) -> Result { + ) -> Result> { Self::check_head(prim)?; prim.take_u8().map(|x| x as i8) } @@ -196,28 +197,28 @@ impl Integer { /// Constructs an `i16` from the content of a primitive value. pub fn i16_from_primitive( prim: &mut decode::Primitive - ) -> Result { + ) -> Result> { decode_builtin!(signed, prim, i16) } /// Constructs an `i32` from the content of a primitive value. pub fn i32_from_primitive( prim: &mut decode::Primitive - ) -> Result { + ) -> Result> { decode_builtin!(signed, prim, i32) } /// Constructs an `i64` from the content of a primitive value. pub fn i64_from_primitive( prim: &mut decode::Primitive - ) -> Result { + ) -> Result> { decode_builtin!(signed, prim, i64) } /// Constructs an `i128` from the content of a primitive value. pub fn i128_from_primitive( prim: &mut decode::Primitive - ) -> Result { + ) -> Result> { decode_builtin!(signed, prim, i128) } @@ -232,17 +233,17 @@ impl Integer { /// for equality comparision. fn check_head( prim: &mut decode::Primitive - ) -> Result<(), S::Err> { + ) -> Result<(), DecodeError> { if prim.request(2)? == 0 { - xerr!(return Err(decode::Error::Malformed.into())) + return Err(prim.content_err("invalid integer")) } let slice = prim.slice(); - match (slice.get(0), slice.get(1).map(|x| x & 0x80 != 0)) { + match (slice.first(), slice.get(1).map(|x| x & 0x80 != 0)) { (Some(0), Some(false)) => { - xerr!(Err(decode::Error::Malformed.into())) + Err(prim.content_err("invalid integer")) } (Some(0xFF), Some(true)) => { - xerr!(Err(decode::Error::Malformed.into())) + Err(prim.content_err("invalid integer")) } _ => Ok(()) } @@ -450,8 +451,8 @@ impl Unsigned { /// /// # Errors /// - /// Will return `Error::Malformed` if the given slice is empty. - pub fn from_slice(slice: &[u8]) -> Result { + /// Will return a malformed error if the given slice is empty. + pub fn from_slice(slice: &[u8]) -> Result { Self::from_bytes(Bytes::copy_from_slice(slice)) } @@ -459,14 +460,16 @@ impl Unsigned { /// /// # Errors /// - /// Will return `Error::Malformed` if the given Bytes value is empty. - pub fn from_bytes(bytes: Bytes) -> Result { + /// Will return a malformed error if the given slice is empty. + pub fn from_bytes(bytes: Bytes) -> Result { if bytes.is_empty() { - return Err(crate::decode::Error::Malformed); + return Err(InvalidInteger(())) } // Skip any leading zero bytes. - let num_leading_zero_bytes = bytes.as_ref().iter().take_while(|&&b| b == 0x00).count(); + let num_leading_zero_bytes = bytes.as_ref().iter().take_while(|&&b| { + b == 0x00 + }).count(); let value = bytes.slice(num_leading_zero_bytes..); // Create a new Unsigned integer from the given value bytes, ensuring @@ -490,39 +493,39 @@ impl Unsigned { pub fn take_from( cons: &mut decode::Constructed - ) -> Result { + ) -> Result> { cons.take_primitive_if(Tag::INTEGER, Self::from_primitive) } pub fn from_primitive( prim: &mut decode::Primitive - ) -> Result { + ) -> Result> { Self::check_head(prim)?; Integer::from_primitive(prim).map(Unsigned) } pub fn u8_from_primitive( prim: &mut decode::Primitive - ) -> Result { + ) -> Result> { Self::check_head(prim)?; match prim.remaining() { 1 => prim.take_u8(), // sign bit has been checked above. 2 => { // First byte must be 0x00, second is the result. if prim.take_u8()? != 0 { - xerr!(Err(decode::Malformed.into())) + Err(prim.content_err("invalid integer")) } else { prim.take_u8() } } - _ => xerr!(Err(decode::Malformed.into())) + _ => Err(prim.content_err("invalid integer")) } } pub fn u16_from_primitive( prim: &mut decode::Primitive - ) -> Result { + ) -> Result> { Self::check_head(prim)?; match prim.remaining() { 1 => Ok(prim.take_u8()?.into()), @@ -534,7 +537,7 @@ impl Unsigned { } 3 => { if prim.take_u8()? != 0 { - xerr!(return Err(decode::Malformed.into())); + return Err(prim.content_err("invalid integer")) } let res = { u16::from(prim.take_u8()?) << 8 | @@ -542,34 +545,31 @@ impl Unsigned { }; if res < 0x8000 { // This could have been in fewer bytes. - Err(decode::Malformed.into()) + Err(prim.content_err("invalid integer")) } else { Ok(res) } } - _ => xerr!(Err(decode::Malformed.into())) + _ => Err(prim.content_err("invalid integer")) } } pub fn u32_from_primitive( prim: &mut decode::Primitive - ) -> Result { - Self::check_head(prim)?; + ) -> Result> { decode_builtin!(unsigned, prim, u32) } pub fn u64_from_primitive( prim: &mut decode::Primitive - ) -> Result { - Self::check_head(prim)?; + ) -> Result> { decode_builtin!(unsigned, prim, u64) } pub fn u128_from_primitive( prim: &mut decode::Primitive - ) -> Result { - Self::check_head(prim)?; + ) -> Result> { decode_builtin!(unsigned, prim, u128) } @@ -579,10 +579,10 @@ impl Unsigned { /// sign bit is not set. fn check_head( prim: &mut decode::Primitive - ) -> Result<(), S::Err> { + ) -> Result<(), DecodeError> { Integer::check_head(prim)?; - if prim.slice().get(0).unwrap() & 0x80 != 0 { - xerr!(Err(decode::Error::Malformed.into())) + if prim.slice().first().unwrap() & 0x80 != 0 { + Err(prim.content_err("invalid integer")) } else { Ok(()) @@ -626,8 +626,8 @@ builtin_from!(unsigned, Unsigned, u64); builtin_from!(unsigned, Unsigned, u128); -impl<'a> TryFrom for Unsigned { - type Error = crate::decode::Error; +impl TryFrom for Unsigned { + type Error = InvalidInteger; fn try_from(value: Bytes) -> Result { Unsigned::from_bytes(value) @@ -674,6 +674,21 @@ impl<'a> PrimitiveContent for &'a Unsigned { } +//------------ InvalidInteger ------------------------------------------------ + +/// A octets slice does not contain a validly encoded integer. +#[derive(Clone, Copy, Debug)] +pub struct InvalidInteger(()); + +impl fmt::Display for InvalidInteger { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "invalid integer") + } +} + +impl error::Error for InvalidInteger { } + + //------------ OverflowError ------------------------------------------------- #[derive(Clone, Copy, Debug)] @@ -685,6 +700,8 @@ impl fmt::Display for OverflowError { } } +impl error::Error for OverflowError { } + //============ Tests ========================================================= @@ -709,15 +726,15 @@ mod test { let pos = [0xF7, 0xF744, 0xF74402]; for &i in &neg { - assert_eq!(Integer::from(i).is_positive(), false, "{}", i); - assert_eq!(Integer::from(i).is_negative(), true, "{}", i); + assert!(!Integer::from(i).is_positive(), "{}", i); + assert!(Integer::from(i).is_negative(), "{}", i); } for &i in &pos { - assert_eq!(Integer::from(i).is_positive(), true, "{}", i); - assert_eq!(Integer::from(i).is_negative(), false, "{}", i); + assert!(Integer::from(i).is_positive(), "{}", i); + assert!(!Integer::from(i).is_negative(), "{}", i); } - assert_eq!(Integer::from(0).is_positive(), false); - assert_eq!(Integer::from(0).is_negative(), false); + assert!(!Integer::from(0).is_positive()); + assert!(!Integer::from(0).is_negative()); } #[test] @@ -750,12 +767,11 @@ mod test { ).unwrap(), 0x7f ); - assert_eq!( + assert!( Primitive::decode_slice( b"\x80".as_ref(), Mode::Der, |prim| Unsigned::u8_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); assert_eq!( Primitive::decode_slice( @@ -786,19 +802,17 @@ mod test { ).unwrap(), 0xA234 ); - assert_eq!( + assert!( Primitive::decode_slice( b"\xA2\x34".as_ref(), Mode::Der, |prim| Unsigned::u16_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); - assert_eq!( + assert!( Primitive::decode_slice( b"\x00\x12\x34".as_ref(), Mode::Der, |prim| Unsigned::u16_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); assert_eq!( @@ -836,19 +850,17 @@ mod test { ).unwrap(), 0xA2345678 ); - assert_eq!( + assert!( Primitive::decode_slice( b"\x00\x12\x34\x56\x78".as_ref(), Mode::Der, |prim| Unsigned::u32_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); - assert_eq!( + assert!( Primitive::decode_slice( b"\xa2\x34\x56\x78".as_ref(), Mode::Der, |prim| Unsigned::u32_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); assert_eq!( @@ -872,19 +884,17 @@ mod test { ).unwrap(), 0xa234567812345678 ); - assert_eq!( + assert!( Primitive::decode_slice( b"\0\x12\x34\x56\x78\x12\x34\x56\x78".as_ref(), Mode::Der, |prim| Unsigned::u64_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); - assert_eq!( + assert!( Primitive::decode_slice( b"\x30\x12\x34\x56\x78\x12\x34\x56\x78".as_ref(), Mode::Der, |prim| Unsigned::u64_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); assert_eq!( @@ -910,21 +920,19 @@ mod test { ).unwrap(), 0xa2345678123456781234567812345678 ); - assert_eq!( + assert!( Primitive::decode_slice( b"\0\x12\x34\x56\x78\x12\x34\x56\x78 \x12\x34\x56\x78\x12\x34\x56\x78".as_ref(), Mode::Der, |prim| Unsigned::u128_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); - assert_eq!( + assert!( Primitive::decode_slice( b"\x30\x12\x34\x56\x78\x12\x34\x56\x78 \x12\x34\x56\x78\x12\x34\x56\x78".as_ref(), Mode::Der, |prim| Unsigned::u128_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); } @@ -944,19 +952,17 @@ mod test { ).unwrap(), -1 ); - assert_eq!( + assert!( Primitive::decode_slice( b"\x00\xFF".as_ref(), Mode::Der, |prim| Integer::i8_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); - assert_eq!( + assert!( Primitive::decode_slice( b"\x40\xFF".as_ref(), Mode::Der, |prim| Integer::i8_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); assert_eq!( @@ -980,19 +986,17 @@ mod test { ).unwrap(), -32513 ); - assert_eq!( + assert!( Primitive::decode_slice( b"\x80\xFF\x32".as_ref(), Mode::Der, |prim| Integer::i16_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); - assert_eq!( + assert!( Primitive::decode_slice( b"\x00\xFF\x32".as_ref(), Mode::Der, |prim| Integer::i16_from_primitive(prim) - ).unwrap_err(), - decode::Malformed + ).is_err() ); assert_eq!( @@ -1140,28 +1144,66 @@ mod test { #[test] fn encode_variable_length_unsigned_from_slice() { - assert_eq!(Unsigned::from_slice(&[]), Err(crate::decode::Error::Malformed)); + assert!(Unsigned::from_slice(&[]).is_err()); test_der(&Unsigned::from_slice(&[0xFF]).unwrap(), b"\x00\xFF"); test_der(&Unsigned::from_slice(&[0x00, 0xFF]).unwrap(), b"\x00\xFF"); - test_der(&Unsigned::from_slice(&[0x00, 0x00, 0xFF]).unwrap(), b"\x00\xFF"); - test_der(&Unsigned::from_slice(&[0x00, 0x00, 0xDE, 0xAD, 0xBE, 0xEF]).unwrap(), b"\x00\xDE\xAD\xBE\xEF"); + test_der( + &Unsigned::from_slice(&[0x00, 0x00, 0xFF]).unwrap(), + b"\x00\xFF" + ); + test_der( + &Unsigned::from_slice( + &[0x00, 0x00, 0xDE, 0xAD, 0xBE, 0xEF] + ).unwrap(), + b"\x00\xDE\xAD\xBE\xEF" + ); } #[test] fn encode_variable_length_unsigned_from_bytes() { - assert_eq!(Unsigned::from_bytes(Bytes::new()), Err(crate::decode::Error::Malformed)); - test_der(&Unsigned::from_bytes(Bytes::from(vec![0xFF])).unwrap(), b"\x00\xFF"); - test_der(&Unsigned::from_bytes(Bytes::from(vec![0x00, 0xFF])).unwrap(), b"\x00\xFF"); - test_der(&Unsigned::from_bytes(Bytes::from(vec![0x00, 0x00, 0xFF])).unwrap(), b"\x00\xFF"); - test_der(&Unsigned::from_bytes(Bytes::from(vec![0x00, 0x00, 0xDE, 0xAD, 0xBE, 0xEF])).unwrap(), b"\x00\xDE\xAD\xBE\xEF"); + assert!(Unsigned::from_bytes(Bytes::new()).is_err()); + test_der( + &Unsigned::from_bytes(Bytes::from(vec![0xFF])).unwrap(), + b"\x00\xFF" + ); + test_der( + &Unsigned::from_bytes(Bytes::from(vec![0x00, 0xFF])).unwrap(), + b"\x00\xFF" + ); + test_der( + &Unsigned::from_bytes(Bytes::from( + vec![0x00, 0x00, 0xFF]) + ).unwrap(), + b"\x00\xFF" + ); + test_der( + &Unsigned::from_bytes(Bytes::from( + vec![0x00, 0x00, 0xDE, 0xAD, 0xBE, 0xEF] + )).unwrap(), + b"\x00\xDE\xAD\xBE\xEF" + ); } #[test] fn encode_variable_length_unsigned_try_from_bytes() { - assert_eq!(Unsigned::try_from(Bytes::new()), Err(crate::decode::Error::Malformed)); - test_der(&Unsigned::try_from(Bytes::from(vec![0xFF])).unwrap(), b"\x00\xFF"); - test_der(&Unsigned::try_from(Bytes::from(vec![0x00, 0xFF])).unwrap(), b"\x00\xFF"); - test_der(&Unsigned::try_from(Bytes::from(vec![0x00, 0x00, 0xFF])).unwrap(), b"\x00\xFF"); - test_der(&Unsigned::try_from(Bytes::from(vec![0x00, 0x00, 0xDE, 0xAD, 0xBE, 0xEF])).unwrap(), b"\x00\xDE\xAD\xBE\xEF"); + assert!(Unsigned::try_from(Bytes::new()).is_err()); + test_der( + &Unsigned::try_from(Bytes::from(vec![0xFF])).unwrap(), + b"\x00\xFF" + ); + test_der( + &Unsigned::try_from(Bytes::from(vec![0x00, 0xFF])).unwrap(), + b"\x00\xFF" + ); + test_der( + &Unsigned::try_from(Bytes::from(vec![0x00, 0x00, 0xFF])).unwrap(), + b"\x00\xFF" + ); + test_der( + &Unsigned::try_from(Bytes::from( + vec![0x00, 0x00, 0xDE, 0xAD, 0xBE, 0xEF] + )).unwrap(), + b"\x00\xDE\xAD\xBE\xEF" + ); } } diff --git a/src/length.rs b/src/length.rs index e2ecba4..5c36ff6 100644 --- a/src/length.rs +++ b/src/length.rs @@ -3,7 +3,7 @@ //! This is a private module. Its public tiems are re-exported by the parent. use std::io; -use crate::decode; +use crate::decode::{DecodeError, Source}; use crate::mode::Mode; @@ -52,10 +52,10 @@ pub enum Length { impl Length { /// Takes a length value from the beginning of a source. - pub fn take_from( + pub fn take_from( source: &mut S, mode: Mode - ) -> Result { + ) -> Result> { match source.take_u8()? { // Bit 7 clear: other bits are the length n if (n & 0x80) == 0 => Ok(Length::Definite(n as usize)), @@ -70,7 +70,7 @@ impl Length { Ok(Length::Definite(len)) } else { - Err(decode::Error::Malformed.into()) + Err(source.content_err("invalid length")) } } 0x82 => { @@ -81,7 +81,7 @@ impl Length { Ok(Length::Definite(len)) } else { - Err(decode::Error::Malformed.into()) + Err(source.content_err("invalid length")) } } 0x83 => { @@ -93,7 +93,7 @@ impl Length { Ok(Length::Definite(len)) } else { - Err(decode::Error::Malformed.into()) + Err(source.content_err("invalid length")) } } 0x84 => { @@ -106,12 +106,14 @@ impl Length { Ok(Length::Definite(len)) } else { - Err(decode::Error::Malformed.into()) + Err(source.content_err("invalid length")) } } _ => { // We only implement up to two length bytes for now. - Err(decode::Error::Unimplemented.into()) + Err(source.content_err( + "lengths over 4 bytes not implemented" + )) } } } diff --git a/src/lib.rs b/src/lib.rs index 79974f8..de5084b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,8 +51,6 @@ pub use self::tag::Tag; //--- Public modules -#[macro_use] pub mod debug; - pub mod decode; pub mod encode; diff --git a/src/mode.rs b/src/mode.rs index 9ce8eeb..7fd2bed 100644 --- a/src/mode.rs +++ b/src/mode.rs @@ -3,6 +3,7 @@ //! This is a private module. It’s public items are re-exported by the parent. use crate::decode; +use crate::decode::DecodeError; //------------ Mode ---------------------------------------------------------- @@ -50,10 +51,14 @@ impl Mode { /// by this value. The closure `op` will be given the content of the /// source as a sequence of values. The closure does not need to process /// all values in the source. - pub fn decode(self, source: S, op: F) -> Result + pub fn decode( + self, source: S, op: F, + ) -> Result::Error>> where - S: decode::Source, - F: FnOnce(&mut decode::Constructed) -> Result + S: decode::IntoSource, + F: FnOnce( + &mut decode::Constructed + ) -> Result::Error>>, { decode::Constructed::decode(source, self, op) } diff --git a/src/oid.rs b/src/oid.rs index 7bcda9f..a783bae 100644 --- a/src/oid.rs +++ b/src/oid.rs @@ -8,8 +8,8 @@ use std::{fmt, hash, io}; use bytes::Bytes; -use crate::{decode, encode}; -use crate::decode::Source; +use crate::encode; +use crate::decode::{Constructed, DecodeError, Source}; use crate::mode::Mode; use crate::tag::Tag; @@ -44,7 +44,7 @@ use crate::tag::Tag; /// and produces the `u8` array for their encoded value. You can install /// this binary via `cargo install ber`. #[derive(Clone, Debug)] -pub struct Oid=Bytes>(pub T); +pub struct Oid = Bytes>(pub T); /// A type alias for `Oid<&'static [u8]>. /// @@ -60,9 +60,9 @@ impl Oid { /// If the source has reached its end, if the next value does not have /// the `Tag::OID`, or if it is not a primitive value, returns a malformed /// error. - pub fn skip_in( - cons: &mut decode::Constructed - ) -> Result<(), S::Err> { + pub fn skip_in( + cons: &mut Constructed + ) -> Result<(), DecodeError> { cons.take_primitive_if(Tag::OID, |prim| prim.skip_all()) } @@ -71,9 +71,9 @@ impl Oid { /// If the source has reached its end of if the next value does not have /// the `Tag::OID`, returns `Ok(None)`. If the next value has the right /// tag but is not a primitive value, returns a malformed error. - pub fn skip_opt_in( - cons: &mut decode::Constructed - ) -> Result, S::Err> { + pub fn skip_opt_in( + cons: &mut Constructed + ) -> Result, DecodeError> { cons.take_opt_primitive_if(Tag::OID, |prim| prim.skip_all()) } @@ -82,9 +82,9 @@ impl Oid { /// If the source has reached its end, if the next value does not have /// the `Tag::OID`, or if it is not a primitive value, returns a malformed /// error. - pub fn take_from( - constructed: &mut decode::Constructed - ) -> Result { + pub fn take_from( + constructed: &mut Constructed + ) -> Result> { constructed.take_primitive_if(Tag::OID, |content| { content.take_all().map(Oid) }) @@ -95,9 +95,9 @@ impl Oid { /// If the source has reached its end of if the next value does not have /// the `Tag::OID`, returns `Ok(None)`. If the next value has the right /// tag but is not a primitive value, returns a malformed error. - pub fn take_opt_from( - constructed: &mut decode::Constructed - ) -> Result, S::Err> { + pub fn take_opt_from( + constructed: &mut Constructed + ) -> Result, DecodeError> { constructed.take_opt_primitive_if(Tag::OID, |content| { content.take_all().map(Oid) }) @@ -106,10 +106,9 @@ impl Oid { impl> Oid { /// Skip over an object identifier if it matches `self`. - pub fn skip_if( - &self, - constructed: &mut decode::Constructed - ) -> Result<(), S::Err> { + pub fn skip_if( + &self, constructed: &mut Constructed, + ) -> Result<(), DecodeError> { constructed.take_primitive_if(Tag::OID, |content| { let len = content.remaining(); content.request(len)?; @@ -118,7 +117,7 @@ impl> Oid { Ok(()) } else { - xerr!(Err(decode::Error::Malformed.into())) + Err(content.content_err("object identifier mismatch")) } }) } @@ -177,7 +176,7 @@ impl> fmt::Display for Oid { // I can’t be bothered to figure out how to convert a seven // bit integer into decimal. let mut components = self.iter(); - // There’s at least one and it is always an valid u32. + // There’s at least one and it is always a valid u32. write!(f, "{}", components.next().unwrap().to_u32().unwrap())?; for component in components { if let Some(val) = component.to_u32() { diff --git a/src/string/bit.rs b/src/string/bit.rs index b998551..0c5624e 100644 --- a/src/string/bit.rs +++ b/src/string/bit.rs @@ -5,7 +5,7 @@ use std::io; use bytes::Bytes; use crate::{decode, encode}; -use crate::decode::Source; +use crate::decode::{DecodeError, Source}; use crate::length::Length; use crate::mode::Mode; use crate::tag::Tag; @@ -129,25 +129,27 @@ impl BitString { /// Takes a single bit string value from constructed content. pub fn take_from( constructed: &mut decode::Constructed - ) -> Result { + ) -> Result> { constructed.take_value_if(Tag::BIT_STRING, Self::from_content) } /// Skip over a single bit string value inside constructed content. pub fn skip_in( cons: &mut decode::Constructed - ) -> Result<(), S::Err> { + ) -> Result<(), DecodeError> { cons.take_value_if(Tag::BIT_STRING, Self::skip_content) } /// Parses the content octets of a bit string value. pub fn from_content( content: &mut decode::Content - ) -> Result { + ) -> Result> { match *content { decode::Content::Primitive(ref mut inner) => { if inner.mode() == Mode::Cer && inner.remaining() > 1000 { - xerr!(return Err(decode::Error::Malformed.into())) + return Err(content.content_err( + "long bit string component in CER mode" + )) } Ok(BitString { unused: inner.take_u8()?, @@ -156,10 +158,14 @@ impl BitString { } decode::Content::Constructed(ref inner) => { if inner.mode() == Mode::Der { - xerr!(Err(decode::Error::Malformed.into())) + Err(content.content_err( + "constructed bit string in DER mode" + )) } else { - xerr!(Err(decode::Error::Unimplemented.into())) + Err(content.content_err( + "constructed bit string not implemented" + )) } } } @@ -168,20 +174,26 @@ impl BitString { /// Skips over the content octets of a bit string value. pub fn skip_content( content: &mut decode::Content - ) -> Result<(), S::Err> { + ) -> Result<(), DecodeError> { match *content { decode::Content::Primitive(ref mut inner) => { if inner.mode() == Mode::Cer && inner.remaining() > 1000 { - xerr!(return Err(decode::Error::Malformed.into())) + return Err(content.content_err( + "long bit string component in CER mode" + )) } inner.skip_all() } decode::Content::Constructed(ref inner) => { if inner.mode() == Mode::Der { - xerr!(Err(decode::Error::Malformed.into())) + Err(content.content_err( + "constructed bit string in DER mode" + )) } else { - xerr!(Err(decode::Error::Unimplemented.into())) + Err(content.content_err( + "constructed bit string not implemented" + )) } } } diff --git a/src/string/octet.rs b/src/string/octet.rs index 68a5b66..a39a000 100644 --- a/src/string/octet.rs +++ b/src/string/octet.rs @@ -7,9 +7,13 @@ //! CER mode. An implementation of that is TODO. use std::{cmp, hash, io, mem}; +use std::convert::Infallible; use bytes::{BytesMut, Bytes}; use crate::captured::Captured; use crate::{decode, encode}; +use crate::decode::{ + BytesSource, DecodeError, IntoSource, Pos, SliceSource, Source +}; use crate::mode::Mode; use crate::length::Length; use crate::tag::Tag; @@ -82,7 +86,9 @@ impl OctetString { OctetStringIter(Inner::Primitive(inner.as_ref())) } Inner::Constructed(ref inner) => { - OctetStringIter(Inner::Constructed(inner.as_ref())) + OctetStringIter( + Inner::Constructed(inner.as_slice().into_source()) + ) } } } @@ -154,15 +160,6 @@ impl OctetString { } !self.iter().any(|s| !s.is_empty()) } - - /// Creates a source that can be used to decode the string’s content. - /// - /// The returned value contains a clone of the string (which, because of - /// the use of `Bytes` is rather cheap) that implements the `Source` - /// trait and thus can be used to decode the string’s content. - pub fn to_source(&self) -> OctetStringSource { - OctetStringSource::new(self) - } } @@ -176,7 +173,7 @@ impl OctetString { /// octet string, a malformed error is returned. pub fn take_from( cons: &mut decode::Constructed - ) -> Result { + ) -> Result> { cons.take_value_if(Tag::OCTET_STRING, Self::from_content) } @@ -189,18 +186,20 @@ impl OctetString { /// malformed error is returned. pub fn take_opt_from( cons: &mut decode::Constructed - ) -> Result, S::Err> { + ) -> Result, DecodeError> { cons.take_opt_value_if(Tag::OCTET_STRING, Self::from_content) } /// Takes an octet string value from content. pub fn from_content( content: &mut decode::Content - ) -> Result { + ) -> Result> { match *content { decode::Content::Primitive(ref mut inner) => { if inner.mode() == Mode::Cer && inner.remaining() > 1000 { - xerr!(return Err(decode::Error::Malformed.into())) + return Err(content.content_err( + "long string component in CER mode" + )) } Ok(OctetString(Inner::Primitive(inner.take_all()?))) } @@ -209,7 +208,9 @@ impl OctetString { Mode::Ber => Self::take_constructed_ber(inner), Mode::Cer => Self::take_constructed_cer(inner), Mode::Der => { - xerr!(Err(decode::Error::Malformed.into())) + Err(content.content_err( + "constructed string in DER mode" + )) } } } @@ -221,14 +222,14 @@ impl OctetString { /// It consists octet string values either primitive or constructed. fn take_constructed_ber( cons: &mut decode::Constructed - ) -> Result { + ) -> Result> { cons.capture(|cons| { while cons.skip_opt(|tag, _, _| if tag == Tag::OCTET_STRING { Ok(()) } else { - xerr!(Err(decode::Malformed.into())) + Err("expected octet string".into()) } )?.is_some() { } Ok(()) @@ -241,17 +242,21 @@ impl OctetString { /// values each except for the last one exactly 1000 octets long. fn take_constructed_cer( constructed: &mut decode::Constructed - ) -> Result { + ) -> Result> { let mut short = false; constructed.capture(|con| { while let Some(()) = con.take_opt_primitive_if(Tag::OCTET_STRING, |primitive| { if primitive.remaining() > 1000 { - xerr!(return Err(decode::Error::Malformed.into())); + return Err(primitive.content_err( + "long string component in CER mode" + )); } if primitive.remaining() < 1000 { if short { - xerr!(return Err(decode::Error::Malformed.into())); + return Err(primitive.content_err( + "short non-terminal string component in CER mode" + )); } short = true } @@ -467,11 +472,20 @@ impl<'a> IntoIterator for &'a OctetString { } +//--- IntoSource + +impl IntoSource for OctetString { + type Source = OctetStringSource; + + fn into_source(self) -> Self::Source { + OctetStringSource::new(self) + } +} + + //------------ OctetStringSource --------------------------------------------- -/// A source atop an octet string. -/// -/// You can get a value of this type by calling `OctetString::source`. +/// A decode source atop an octet string. // // Assuming we have a correctly encoded octet string, its content is a // sequence of value headers (i.e., tag and length octets) and actual string @@ -487,34 +501,52 @@ pub struct OctetStringSource { current: Bytes, /// The remainder of the value after the value in `current`. - remainder: Bytes, + remainder: BytesSource, + + /// The current position in the string. + pos: Pos, } impl OctetStringSource { /// Creates a new source atop an existing octet string. - fn new(from: &OctetString) -> Self { + fn new(from: OctetString) -> Self { + Self::with_offset(from, Pos::default()) + } + + /// Creates a new source with a given start position. + fn with_offset(from: OctetString, offset: Pos) -> Self { match from.0 { - Inner::Primitive(ref inner) => { + Inner::Primitive(inner) => { OctetStringSource { - current: inner.clone(), - remainder: Bytes::new(), + current: inner, + remainder: Bytes::new().into_source(), + pos: offset, } } - Inner::Constructed(ref inner) => { + Inner::Constructed(inner) => { OctetStringSource { current: Bytes::new(), - remainder: inner.clone().into_bytes() + remainder: inner.into_bytes().into_source(), + pos: offset, } } } } - /// Returns the bytes of the next primitive value in the string. + /// Returns the next value for `self.current`. + /// + /// This is the content of the first primitive value found in the + /// remainder. /// /// Returns `None` if we are done. - fn next_primitive(&mut self) -> Option { - while !self.remainder.is_empty() { - let (tag, cons) = Tag::take_from(&mut self.remainder).unwrap(); + fn next_current(&mut self) -> Option { + // Unwrapping here is okay. The only error that can happen is that + // the tag is longer that we support. However, we already checked that + // there’s only OctetString or End of Value tags which we _do_ + // support. + while let Some((tag, cons)) = Tag::take_opt_from( + &mut self.remainder + ).unwrap() { let length = Length::take_from( &mut self.remainder, Mode::Ber ).unwrap(); @@ -538,15 +570,19 @@ impl OctetStringSource { } impl decode::Source for OctetStringSource { - type Err = decode::Error; + type Error = Infallible; - fn request(&mut self, len: usize) -> Result { + fn pos(&self) -> Pos { + self.pos + } + + fn request(&mut self, len: usize) -> Result { if self.current.len() < len && !self.remainder.is_empty() { // Make a new current that is at least `len` long. let mut current = BytesMut::with_capacity(self.current.len()); current.extend_from_slice(&self.current.clone()); while current.len() < len { - if let Some(bytes) = self.next_primitive() { + if let Some(bytes) = self.next_current() { current.extend_from_slice(bytes.as_ref()) } else { @@ -558,18 +594,10 @@ impl decode::Source for OctetStringSource { Ok(self.current.len()) } - fn advance(&mut self, mut len: usize) -> Result<(), decode::Error> { - while len > self.current.len() { - len -= self.current.len(); - self.current = match self.next_primitive() { - Some(value) => value, - None => { - xerr!(return Err(decode::Error::Malformed)) - } - } - } - self.current.advance(len)?; - Ok(()) + fn advance(&mut self, len: usize) { + assert!(len <= self.current.len()); + self.pos = self.pos + len.into(); + bytes::Buf::advance(&mut self.current, len) } fn slice(&self) -> &[u8] { @@ -589,7 +617,7 @@ impl decode::Source for OctetStringSource { /// You can get a value of this type by calling `OctetString::iter` or relying /// on the `IntoIterator` impl for a `&OctetString`. #[derive(Clone, Debug)] -pub struct OctetStringIter<'a>(Inner<&'a [u8], &'a [u8]>); +pub struct OctetStringIter<'a>(Inner<&'a [u8], SliceSource<'a>>); impl<'a> Iterator for OctetStringIter<'a> { type Item = &'a [u8]; @@ -617,9 +645,7 @@ impl<'a> Iterator for OctetStringIter<'a> { Length::Definite(len) => len, _ => unreachable!() }; - let res = &inner[..length]; - *inner = &inner[length..]; - return Some(res) + return Some(inner.split_to(length)) } Tag::END_OF_VALUE => continue, _ => unreachable!() @@ -888,7 +914,7 @@ mod tests { assert_eq!( decode::Constructed::decode( b"\x24\x04\ - \x04\x02ab".as_ref(), + \x04\x02ab".into_source(), Mode::Ber, |cons| { OctetString::take_from(cons) @@ -902,7 +928,7 @@ mod tests { decode::Constructed::decode( b"\x24\x06\ \x04\x01a\ - \x04\x01b".as_ref(), + \x04\x01b".into_source(), Mode::Ber, |cons| { OctetString::take_from(cons) @@ -917,7 +943,7 @@ mod tests { b"\x24\x08\ \x24\x80\ \x04\x02ab\ - \0\0".as_ref(), + \0\0".into_source(), Mode::Ber, |cons| { OctetString::take_from(cons) @@ -926,13 +952,12 @@ mod tests { "ab" ); - println!("lllllll"); // I(p) assert_eq!( decode::Constructed::decode( b"\x24\x80\ \x04\x02ab\ - \0\0".as_ref(), + \0\0".into_source(), Mode::Ber, |cons| { OctetString::take_from(cons) @@ -948,7 +973,7 @@ mod tests { \x04\x01a\ \x24\x80\ \x04\x01b\ - \0\0".as_ref(), + \0\0".into_source(), Mode::Ber, |cons| { OctetString::take_from(cons) diff --git a/src/string/restricted.rs b/src/string/restricted.rs index d7ce1a9..0c3e28e 100644 --- a/src/string/restricted.rs +++ b/src/string/restricted.rs @@ -8,6 +8,7 @@ use std::borrow::Cow; use std::marker::PhantomData; use bytes::Bytes; use crate::{decode, encode}; +use crate::decode::DecodeError; use crate::tag::Tag; use super::octet::{OctetString, OctetStringIter, OctetStringOctets}; @@ -158,16 +159,16 @@ impl RestrictedString { /// returned. pub fn take_from( cons: &mut decode::Constructed - ) -> Result { + ) -> Result> { cons.take_value_if(L::TAG, Self::from_content) } /// Takes a character set from content. pub fn from_content( content: &mut decode::Content - ) -> Result { + ) -> Result> { let os = OctetString::from_content(content)?; - Self::new(os).map_err(|_| decode::Error::Malformed.into()) + Self::new(os).map_err(|_| content.content_err("invalid character")) } /// Returns a value encoder for the character string with the natural tag. @@ -542,6 +543,7 @@ mod test { use super::*; use bytes::Bytes; + use crate::decode::IntoSource; use crate::mode::Mode; use crate::encode::Values; @@ -554,7 +556,7 @@ mod test { ps.encode_ref().write_encoded(Mode::Der, &mut v).unwrap(); let decoded = Mode::Der.decode( - v.as_slice(), + v.as_slice().into_source(), PrintableString::take_from ).unwrap(); diff --git a/src/tag.rs b/src/tag.rs index 06f074f..36aefa4 100644 --- a/src/tag.rs +++ b/src/tag.rs @@ -3,7 +3,7 @@ //! This is a private module. Its public items are re-exported by the parent. use std::{fmt, io}; -use crate::decode; +use crate::decode::{DecodeError, Source}; //------------ Tag ----------------------------------------------------------- @@ -22,8 +22,8 @@ use crate::decode; /// /// # Limitations /// -/// We can only decode up to four identifier octets. That is, we only support tag -/// numbers between 0 and 1fffff. +/// We can only decode up to four identifier octets. That is, we only support +/// tag numbers between 0 and 1fffff. /// /// [`Primitive`]: decode/struct.Primitive.html /// [`Constructed`]: decode/struct.Constructed.html @@ -316,7 +316,9 @@ impl Tag { /// Returns the number of the tag. pub fn number(self) -> u32 { - if (Tag::SINGLEBYTE_DATA_MASK & self.0[0]) != Tag::SINGLEBYTE_DATA_MASK { + if (Tag::SINGLEBYTE_DATA_MASK & self.0[0]) + != Tag::SINGLEBYTE_DATA_MASK + { // It's a single byte identifier u32::from(Tag::SINGLEBYTE_DATA_MASK & self.0[0]) } else if Tag::LAST_OCTET_MASK & self.0[1] == 0 { @@ -335,15 +337,17 @@ impl Tag { } } - /// Takes a tag from the beginning of a source. + /// Takes an optional tag from the beginning of a source. /// /// Upon success, returns both the tag and whether the value is - /// constructed. If there are no more octets available in the source, - /// an error is returned. - pub fn take_from( + /// constructed. + pub fn take_opt_from( source: &mut S, - ) -> Result<(Self, bool), S::Err> { - let byte = source.take_u8()?; + ) -> Result, DecodeError> { + let byte = match source.take_opt_u8()? { + Some(byte) => byte, + None => return Ok(None) + }; // clear constructed bit let mut data = [byte & !Tag::CONSTRUCTED_MASK, 0, 0, 0]; let constructed = byte & Tag::CONSTRUCTED_MASK != 0; @@ -351,13 +355,29 @@ impl Tag { for i in 1..=3 { data[i] = source.take_u8()?; if data[i] & Tag::LAST_OCTET_MASK == 0 { - return Ok((Tag(data), constructed)); + return Ok(Some((Tag(data), constructed))); } } } else { - return Ok((Tag(data), constructed)); + return Ok(Some((Tag(data), constructed))); + } + Err(source.content_err( + "tag values longer than 4 bytes not implemented" + )) + } + + /// Takes a tag from the beginning of a source. + /// + /// Upon success, returns both the tag and whether the value is + /// constructed. If there are no more octets available in the source, + /// an error is returned. + pub fn take_from( + source: &mut S, + ) -> Result<(Self, bool), DecodeError> { + match Self::take_opt_from(source)? { + Some(res) => Ok(res), + None => Err(source.content_err("additional values expected")) } - xerr!(Err(decode::Error::Unimplemented.into())) } /// Takes a tag from the beginning of a resource if it matches this tag. @@ -365,10 +385,10 @@ impl Tag { /// If there is no more data available in the source or if the tag is /// something else, returns `Ok(None)`. If the tag matches `self`, returns /// whether the value is constructed. - pub fn take_from_if( + pub fn take_from_if( self, source: &mut S, - ) -> Result, S::Err> { + ) -> Result, DecodeError> { if source.request(1)? == 0 { return Ok(None) } @@ -380,7 +400,7 @@ impl Tag { loop { if source.request(i + 1)? == 0 { // Not enough data for a complete tag. - xerr!(return Err(decode::Error::Malformed.into())) + return Err(source.content_err("short tag value")) } data[i] = source.slice()[i]; if data[i] & Tag::LAST_OCTET_MASK == 0 { @@ -388,14 +408,16 @@ impl Tag { } // We don’t support tags larger than 4 bytes. if i == 3 { - xerr!(return Err(decode::Error::Unimplemented.into())) + return Err(source.content_err( + "tag values longer than 4 bytes not implemented" + )) } i += 1; } } let (tag, compressed) = (Tag(data), byte & Tag::CONSTRUCTED_MASK != 0); if tag == self { - source.advance(tag.encoded_len())?; + source.advance(tag.encoded_len()); Ok(Some(compressed)) } else { @@ -490,6 +512,7 @@ impl fmt::Debug for Tag { #[cfg(test)] mod test { use super::*; + use crate::decode::IntoSource; const TYPES: &[u8] = &[Tag::UNIVERSAL, Tag::APPLICATION, Tag::CONTEXT_SPECIFIC, Tag::PRIVATE]; @@ -503,10 +526,15 @@ mod test { for i in range.clone() { let tag = Tag::new(typ, i); let expected = Tag([typ | i as u8, 0, 0, 0]); - let decoded = Tag::take_from(&mut &tag.0[..]).unwrap(); - assert_eq!(tag.take_from_if(&mut &tag.0[..]), Ok(Some(false))); + let decoded = Tag::take_from( + &mut tag.0.into_source() + ).unwrap(); + assert_eq!( + tag.take_from_if(&mut tag.0.into_source()).unwrap(), + Some(false) + ); // The value is not constructed. - assert_eq!(decoded.1, false); + assert!(!decoded.1); // The tag is the same assert_eq!(decoded.0, expected); // We get the same number back. @@ -532,10 +560,15 @@ mod test { let expected = Tag([ Tag::SINGLEBYTE_DATA_MASK | typ, i as u8, 0, 0 ]); - let decoded = Tag::take_from(&mut &tag.0[..]).unwrap(); - assert_eq!(tag.take_from_if(&mut &tag.0[..]), Ok(Some(false))); + let decoded = Tag::take_from( + &mut tag.0.into_source() + ).unwrap(); + assert_eq!( + tag.take_from_if(&mut tag.0.into_source()).unwrap(), + Some(false) + ); // The value is not constructed. - assert_eq!(decoded.1, false); + assert!(!decoded.1); // The tag is the same assert_eq!(decoded.0, expected); assert_eq!(tag.number(), i); @@ -561,10 +594,15 @@ mod test { i as u8 & !Tag::LAST_OCTET_MASK, 0 ]); - let decoded = Tag::take_from(&mut &tag.0[..]).unwrap(); - assert_eq!(tag.take_from_if(&mut &tag.0[..]), Ok(Some(false))); + let decoded = Tag::take_from( + &mut tag.0.into_source() + ).unwrap(); + assert_eq!( + tag.take_from_if(&mut tag.0.into_source()).unwrap(), + Some(false) + ); // The value is not constructed. - assert_eq!(decoded.1, false); + assert!(!decoded.1); // The tag is the same assert_eq!(decoded.0, expected); assert_eq!(tag.number(), i); @@ -590,10 +628,15 @@ mod test { (i >> 7) as u8 | Tag::LAST_OCTET_MASK, i as u8 & !Tag::LAST_OCTET_MASK ]); - let decoded = Tag::take_from(&mut &tag.0[..]).unwrap(); - assert_eq!(tag.take_from_if(&mut &tag.0[..]), Ok(Some(false))); + let decoded = Tag::take_from( + &mut tag.0.into_source() + ).unwrap(); + assert_eq!( + tag.take_from_if(&mut tag.0.into_source()).unwrap(), + Some(false) + ); // The value is not constructed. - assert_eq!(decoded.1, false); + assert!(!decoded.1); // The tag is the same assert_eq!(decoded.0, expected); assert_eq!(tag.number(), i); @@ -607,14 +650,12 @@ mod test { let large_tag = [ 0b1111_1111, 0b1000_0000, 0b1000_0000, 0b1000_0000, 0b1000_0000 ]; - assert_eq!( - Tag::take_from(&mut &large_tag[..]), - Err(decode::Error::Unimplemented) + assert!( + Tag::take_from(&mut large_tag.into_source()).is_err() ); let short_tag = [0b1111_1111, 0b1000_0000]; - assert_eq!( - Tag::take_from(&mut &short_tag[..]), - Err(decode::Error::Malformed) + assert!( + Tag::take_from(&mut short_tag.into_source()).is_err() ); } }