diff --git a/crates/sats/src/bsatn.rs b/crates/sats/src/bsatn.rs index f5624e755d4..cea35fdcee5 100644 --- a/crates/sats/src/bsatn.rs +++ b/crates/sats/src/bsatn.rs @@ -152,7 +152,7 @@ impl ToBsatn for ProductValue { #[cfg(test)] mod tests { - use super::to_vec; + use super::{to_vec, DecodeError}; use crate::proptest::generate_typed_value; use crate::{meta_type::MetaType, AlgebraicType, AlgebraicValue}; use proptest::prelude::*; @@ -179,5 +179,14 @@ mod tests { let val_decoded = AlgebraicValue::decode(&ty, &mut &bytes[..]).unwrap(); prop_assert_eq!(val, val_decoded); } + + #[test] + fn bsatn_non_zero_one_u8_aint_bool(val in 2u8..) { + let bytes = [val]; + prop_assert_eq!( + AlgebraicValue::decode(&AlgebraicType::Bool, &mut &bytes[..]), + Err(DecodeError::InvalidBool(val)) + ); + } } } diff --git a/crates/sats/src/bsatn/de.rs b/crates/sats/src/bsatn/de.rs index 2d4cb123f42..1424a8aa348 100644 --- a/crates/sats/src/bsatn/de.rs +++ b/crates/sats/src/bsatn/de.rs @@ -60,7 +60,12 @@ impl<'de, 'a, R: BufReader<'de>> de::Deserializer<'de> for Deserializer<'a, R> { } fn deserialize_bool(self) -> Result { - self.reader.get_u8().map(|x| x != 0) + let byte = self.reader.get_u8()?; + match byte { + 0 => Ok(false), + 1 => Ok(true), + b => Err(DecodeError::InvalidBool(b)), + } } fn deserialize_u8(self) -> Result { self.reader.get_u8() diff --git a/crates/sats/src/buffer.rs b/crates/sats/src/buffer.rs index 231e638c50d..67ccb5a3f22 100644 --- a/crates/sats/src/buffer.rs +++ b/crates/sats/src/buffer.rs @@ -20,6 +20,8 @@ pub enum DecodeError { InvalidTag { tag: u8, sum_name: Option }, /// Expected data to be UTF-8 but it wasn't. InvalidUtf8, + /// Expected the byte to be 0 or 1 to be a valid bool. + InvalidBool(u8), /// Custom error not in the other variants of `DecodeError`. Other(String), } @@ -40,6 +42,7 @@ impl fmt::Display for DecodeError { ) } DecodeError::InvalidUtf8 => f.write_str("invalid utf8"), + DecodeError::InvalidBool(byte) => write!(f, "byte {byte} not valid as `bool` (must be 0 or 1)"), DecodeError::Other(err) => f.write_str(err), } } @@ -158,7 +161,7 @@ pub trait BufReader<'de> { /// Reads and returns a byte slice of `.len() = size` advancing the cursor. #[inline] fn get_slice(&mut self, size: usize) -> Result<&'de [u8], DecodeError> { - self.get_chunk(size).ok_or(DecodeError::BufferLength { + self.get_chunk(size).ok_or_else(|| DecodeError::BufferLength { for_type: "[u8]", expected: size, given: self.remaining(), @@ -168,7 +171,7 @@ pub trait BufReader<'de> { /// Reads an array of type `[u8; N]` from the input. #[inline] fn get_array(&mut self) -> Result<&'de [u8; N], DecodeError> { - self.get_array_chunk().ok_or(DecodeError::BufferLength { + self.get_array_chunk().ok_or_else(|| DecodeError::BufferLength { for_type: "[u8; _]", expected: N, given: self.remaining(),