diff --git a/packages/libs/deer/Cargo.toml b/packages/libs/deer/Cargo.toml index b7c26569854..0e6176c015d 100644 --- a/packages/libs/deer/Cargo.toml +++ b/packages/libs/deer/Cargo.toml @@ -25,4 +25,4 @@ std = ['serde/std', 'error-stack/std'] arbitrary-precision = [] [workspace] -members = ['.', 'macros', 'json'] +members = ['.', 'macros', 'json', 'desert'] diff --git a/packages/libs/deer/desert/Cargo.toml b/packages/libs/deer/desert/Cargo.toml new file mode 100644 index 00000000000..72d1f130b85 --- /dev/null +++ b/packages/libs/deer/desert/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "deer-desert" +version = "0.0.0" +edition = "2021" +# NOTE: THIS PACKAGE IS NEVER INTENDED TO BE PUBLISHED +publish = false + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +deer = { path = ".." } +error-stack = { version = "0.2.4", default_features = false } +serde_json = { version = "1.0.91", default_features = false, features = ['alloc'] } +bitvec = { version = "1", default_features = false, features = ['alloc', 'atomic'] } diff --git a/packages/libs/deer/desert/README.md b/packages/libs/deer/desert/README.md new file mode 100644 index 00000000000..6e229ad5b71 --- /dev/null +++ b/packages/libs/deer/desert/README.md @@ -0,0 +1,6 @@ +# deer-desert + +desert is the internal-only deserialization testing framework used throughout the integration tests and should never +be published. 
+ +`desert` = `deser` (`deserialization`) + `t` (`test`) diff --git a/packages/libs/deer/desert/src/array.rs b/packages/libs/deer/desert/src/array.rs new file mode 100644 index 00000000000..b82824a1e0a --- /dev/null +++ b/packages/libs/deer/desert/src/array.rs @@ -0,0 +1,149 @@ +use deer::{ + error::{ + ArrayAccessError, ArrayLengthError, BoundedContractViolationError, ExpectedLength, + ReceivedLength, Variant, + }, + Deserialize, Deserializer as _, +}; +use error_stack::{Report, Result, ResultExt}; + +use crate::{ + deserializer::{Deserializer, DeserializerNone}, + token::Token, +}; + +pub struct ArrayAccess<'a, 'b, 'de: 'a> { + deserializer: &'a mut Deserializer<'b, 'de>, + + length: Option, + remaining: Option, + consumed: usize, +} + +impl<'a, 'b, 'de> ArrayAccess<'a, 'b, 'de> { + pub fn new(deserializer: &'a mut Deserializer<'b, 'de>, length: Option) -> Self { + Self { + deserializer, + consumed: 0, + length, + remaining: None, + } + } + + fn scan_end(&self) -> Option { + let mut objects: usize = 0; + let mut arrays: usize = 0; + + let mut n = 0; + + loop { + let token = self.deserializer.peek_n(n)?; + + match token { + Token::Array { .. } => arrays += 1, + Token::ArrayEnd if arrays == 0 && objects == 0 => { + // we're at the outer layer, meaning we can know where we end + return Some(n); + } + Token::ArrayEnd => arrays = arrays.saturating_sub(1), + Token::Object { .. 
} => objects += 1, + Token::ObjectEnd => objects = objects.saturating_sub(1), + _ => {} + } + + n += 1; + } + } +} + +impl<'de> deer::ArrayAccess<'de> for ArrayAccess<'_, '_, 'de> { + fn set_bounded(&mut self, length: usize) -> Result<(), ArrayAccessError> { + if self.consumed > 0 { + return Err( + Report::new(BoundedContractViolationError::SetDirty.into_error()) + .change_context(ArrayAccessError), + ); + } + + if self.remaining.is_some() { + return Err(Report::new( + BoundedContractViolationError::SetCalledMultipleTimes.into_error(), + ) + .change_context(ArrayAccessError)); + } + + self.remaining = Some(length); + + Ok(()) + } + + fn next(&mut self) -> Option> + where + T: Deserialize<'de>, + { + self.consumed += 1; + + if self.deserializer.peek() == Token::ArrayEnd { + // we have reached the ending, if `self.remaining` is set we use the `DeserializerNone` + // to deserialize any values that require `None` + if let Some(remaining) = &mut self.remaining { + if *remaining == 0 { + return None; + } + + *remaining = remaining.saturating_sub(1); + + let value = T::deserialize(DeserializerNone { + context: self.deserializer.context(), + }); + + Some(value.change_context(ArrayAccessError)) + } else { + None + } + } else { + let value = T::deserialize(&mut *self.deserializer); + Some(value.change_context(ArrayAccessError)) + } + } + + fn size_hint(&self) -> Option { + self.length + } + + fn end(self) -> Result<(), ArrayAccessError> { + let mut result = Ok(()); + + // ensure that we consume the last token, if it is the wrong token error out + if self.deserializer.peek() != Token::ArrayEnd { + let mut error = Report::new(ArrayLengthError.into_error()) + .attach(ExpectedLength::new(self.consumed)); + + if let Some(length) = self.size_hint() { + error = error.attach(ReceivedLength::new(length)); + } + + result = Err(error); + } + + // bump until the very end, which ensures that deserialize calls after this might succeed! 
+ let bump = self + .scan_end() + .unwrap_or_else(|| self.deserializer.tape().remaining()); + self.deserializer.tape_mut().bump_n(bump); + + if let Some(remaining) = self.remaining { + if remaining > 0 { + let error = + Report::new(BoundedContractViolationError::EndRemainingItems.into_error()); + + match &mut result { + Err(result) => result.extend_one(error), + result => *result = Err(error), + } + } + } + + result.change_context(ArrayAccessError) + } +} diff --git a/packages/libs/deer/desert/src/assert.rs b/packages/libs/deer/desert/src/assert.rs new file mode 100644 index 00000000000..825760ae9ba --- /dev/null +++ b/packages/libs/deer/desert/src/assert.rs @@ -0,0 +1,50 @@ +use core::fmt::Debug; + +use deer::{error::ReportExt, Context, Deserialize}; +use serde_json::to_value; + +use crate::{deserializer::Deserializer, token::Token}; + +pub fn assert_tokens_with_context<'de, T>(expected: &T, tokens: &'de [Token], context: &Context) +where + T: Deserialize<'de> + PartialEq + Debug, +{ + let mut de = Deserializer::new(tokens, context); + let received = T::deserialize(&mut de).expect("should deserialize"); + + if de.remaining() > 0 { + panic!("{} remaining tokens", de.remaining()); + } + + assert_eq!(received, *expected); +} + +pub fn assert_tokens<'de, T>(value: &T, tokens: &'de [Token]) +where + T: Deserialize<'de> + PartialEq + Debug, +{ + assert_tokens_with_context(value, tokens, &Context::new()); +} + +pub fn assert_tokens_with_context_error<'de, T>( + error: &serde_json::Value, + tokens: &'de [Token], + context: &Context, +) where + T: Deserialize<'de> + Debug, +{ + let mut de = Deserializer::new(tokens, context); + let received = T::deserialize(&mut de).expect_err("value of type T should fail serialization"); + + let received = received.export(); + let received = to_value(received).expect("error should serialize"); + + assert_eq!(received, *error) +} + +pub fn assert_tokens_error<'de, T>(error: &serde_json::Value, tokens: &'de [Token]) +where + T: 
Deserialize<'de> + Debug, +{ + assert_tokens_with_context_error::(error, tokens, &Context::new()); +} diff --git a/packages/libs/deer/desert/src/deserializer.rs b/packages/libs/deer/desert/src/deserializer.rs new file mode 100644 index 00000000000..cfd7c451d22 --- /dev/null +++ b/packages/libs/deer/desert/src/deserializer.rs @@ -0,0 +1,145 @@ +use alloc::borrow::ToOwned; +use core::ops::Range; + +use deer::{error::DeserializerError, Context, Visitor}; +use error_stack::{Result, ResultExt}; + +use crate::{array::ArrayAccess, object::ObjectAccess, tape::Tape, token::Token}; + +macro_rules! forward { + ($($method:ident),*) => { + $( + fn $method(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_any(visitor) + } + )* + }; +} + +#[derive(Debug)] +pub struct Deserializer<'a, 'de> { + context: &'a Context, + tape: Tape<'a, 'de>, +} + +impl<'a, 'de> Deserializer<'a, 'de> { + pub(crate) fn erase(&mut self, range: Range) { + self.tape.set_trivia(range); + } +} + +impl<'a, 'de> deer::Deserializer<'de> for &mut Deserializer<'a, 'de> { + forward!( + deserialize_null, + deserialize_bool, + deserialize_number, + deserialize_char, + deserialize_string, + deserialize_str, + deserialize_bytes, + deserialize_bytes_buffer, + deserialize_array, + deserialize_object + ); + + fn context(&self) -> &Context { + self.context + } + + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let token = self.next(); + + match token { + Token::Bool(value) => visitor.visit_bool(value), + Token::Number(value) => visitor.visit_number(value.clone()), + Token::Char(value) => visitor.visit_char(value), + Token::Str(value) => visitor.visit_str(value), + Token::BorrowedStr(value) => visitor.visit_borrowed_str(value), + Token::String(value) => visitor.visit_string(value.to_owned()), + Token::Bytes(value) => visitor.visit_bytes(value), + Token::BorrowedBytes(value) => visitor.visit_borrowed_bytes(value), + Token::BytesBuf(value) => 
visitor.visit_bytes_buffer(value.to_vec()), + Token::Array { length } => visitor.visit_array(ArrayAccess::new(self, length)), + Token::Object { length } => visitor.visit_object(ObjectAccess::new(self, length)), + _ => { + panic!("Deserializer did not expect {token}"); + } + } + .change_context(DeserializerError) + } +} + +impl<'a, 'de> Deserializer<'a, 'de> { + pub(crate) fn new_bare(tape: Tape<'a, 'de>, context: &'a Context) -> Self { + Self { tape, context } + } + + pub fn new(tokens: &'de [Token], context: &'a Context) -> Self { + Self::new_bare(tokens.into(), context) + } + + pub(crate) fn peek(&self) -> Token { + self.tape.peek().expect("should have token to deserialize") + } + + pub(crate) fn peek_n(&self, n: usize) -> Option { + self.tape.peek_n(n) + } + + pub(crate) fn next(&mut self) -> Token { + self.tape.next().expect("should have token to deserialize") + } + + pub(crate) fn tape(&self) -> &Tape<'a, 'de> { + &self.tape + } + + pub(crate) fn tape_mut(&mut self) -> &mut Tape<'a, 'de> { + &mut self.tape + } + + pub fn remaining(&self) -> usize { + self.tape.remaining() + } + + pub fn is_empty(&self) -> bool { + self.tape.is_empty() + } +} + +#[derive(Debug)] +pub(crate) struct DeserializerNone<'a> { + pub(crate) context: &'a Context, +} + +impl<'de> deer::Deserializer<'de> for DeserializerNone<'_> { + forward!( + deserialize_null, + deserialize_bool, + deserialize_number, + deserialize_char, + deserialize_string, + deserialize_str, + deserialize_bytes, + deserialize_bytes_buffer, + deserialize_array, + deserialize_object + ); + + fn context(&self) -> &Context { + self.context + } + + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_none().change_context(DeserializerError) + } +} diff --git a/packages/libs/deer/desert/src/lib.rs b/packages/libs/deer/desert/src/lib.rs new file mode 100644 index 00000000000..7c7948124cf --- /dev/null +++ b/packages/libs/deer/desert/src/lib.rs @@ -0,0 +1,16 @@ +#![no_std] + +extern 
crate alloc; + +pub(crate) mod array; +mod assert; +mod deserializer; +pub(crate) mod object; +pub(crate) mod tape; +mod token; + +pub use assert::{ + assert_tokens, assert_tokens_error, assert_tokens_with_context, + assert_tokens_with_context_error, +}; +pub use token::Token; diff --git a/packages/libs/deer/desert/src/object.rs b/packages/libs/deer/desert/src/object.rs new file mode 100644 index 00000000000..1d5575463d7 --- /dev/null +++ b/packages/libs/deer/desert/src/object.rs @@ -0,0 +1,273 @@ +use deer::{ + error::{ + BoundedContractViolationError, ExpectedLength, ObjectAccessError, ObjectLengthError, + ReceivedLength, Variant, + }, + Deserialize, Deserializer as _, +}; +use error_stack::{Report, Result, ResultExt}; + +use crate::{ + deserializer::{Deserializer, DeserializerNone}, + tape::Tape, + token::Token, +}; + +pub struct ObjectAccess<'a, 'b, 'de: 'a> { + deserializer: &'a mut Deserializer<'b, 'de>, + + length: Option, + remaining: Option, + consumed: usize, +} + +impl<'a, 'b, 'de: 'a> ObjectAccess<'a, 'b, 'de> { + pub fn new(deserializer: &'a mut Deserializer<'b, 'de>, length: Option) -> Self { + Self { + deserializer, + length, + remaining: None, + consumed: 0, + } + } + + // This assumes that Str and such are atomic, meaning `Str Str` as a deserialize value is + // considered invalid, as that should use `ArrayAccess` instead. + fn scan(&self, key: &str) -> Option { + let mut objects: usize = 0; + let mut arrays: usize = 0; + let mut n = 0; + + #[derive(Copy, Clone, Eq, PartialEq)] + enum State { + Key, + Value, + } + + impl State { + fn flip(&mut self) { + match *self { + State::Key => *self = State::Value, + State::Value => *self = State::Key, + } + } + } + + let mut state = State::Key; + + loop { + let next = self.deserializer.peek_n(n)?; + + match next { + Token::Array { .. } => arrays += 1, + Token::ArrayEnd => arrays = arrays.saturating_sub(1), + Token::Object { .. 
} => objects += 1, + Token::ObjectEnd if objects == 0 && arrays == 0 => { + // this is for the outer layer (that's us), therefore we can abort our linear + // search + return None; + } + Token::ObjectEnd => objects = objects.saturating_sub(1), + Token::Str(value) | Token::BorrowedStr(value) | Token::String(value) + if objects == 0 && arrays == 0 && value == key && state == State::Key => + { + // we found an element that matches the element value that is next in line + return Some(n); + } + _ => {} + } + + if arrays == 0 && objects == 0 { + // we're dependent on the fact if something is a key or value, if we're not nested + // then we can switch the state. + state.flip(); + } + + n += 1; + } + } + + fn scan_end(&self) -> Option { + let mut objects: usize = 0; + let mut arrays: usize = 0; + + let mut n = 0; + + loop { + let token = self.deserializer.peek_n(n)?; + + match token { + Token::Array { .. } => arrays += 1, + Token::ArrayEnd => arrays = arrays.saturating_sub(1), + Token::Object { .. 
} => objects += 1, + Token::ObjectEnd if arrays == 0 && objects == 0 => { + // we're at the outer layer, meaning we can know where we end + return Some(n); + } + Token::ObjectEnd => objects = objects.saturating_sub(1), + _ => {} + } + + n += 1; + } + } +} + +impl<'de> deer::ObjectAccess<'de> for ObjectAccess<'_, '_, 'de> { + fn set_bounded(&mut self, length: usize) -> Result<(), ObjectAccessError> { + if self.consumed > 0 { + return Err( + Report::new(BoundedContractViolationError::SetDirty.into_error()) + .change_context(ObjectAccessError), + ); + } + + if self.remaining.is_some() { + return Err(Report::new( + BoundedContractViolationError::SetCalledMultipleTimes.into_error(), + ) + .change_context(ObjectAccessError)); + } + + self.remaining = Some(length); + + Ok(()) + } + + fn value(&mut self, key: &str) -> Result + where + T: Deserialize<'de>, + { + if self.remaining == Some(0) { + return T::deserialize(DeserializerNone { + context: self.deserializer.context(), + }) + .change_context(ObjectAccessError); + } + + self.consumed += 1; + + if let Some(remaining) = &mut self.remaining { + *remaining = remaining.saturating_sub(1); + } + + match self.scan(key) { + Some(offset) => { + // now we need to figure out which values are used, we can do this through offset + // calculations + let remaining = self.deserializer.remaining() - offset; + + let tape = self.deserializer.tape().view(offset + 1..); + + let mut deserializer = Deserializer::new_bare( + tape.unwrap_or_else(Tape::empty), + self.deserializer.context(), + ); + + let value = T::deserialize(&mut deserializer); + + let erase = remaining - deserializer.remaining(); + drop(deserializer); + + self.deserializer.erase(offset..offset + erase); + + value + } + None => T::deserialize(DeserializerNone { + context: self.deserializer.context(), + }), + } + .change_context(ObjectAccessError) + } + + fn next(&mut self) -> Option> + where + K: Deserialize<'de>, + V: Deserialize<'de>, + { + if self.remaining == Some(0) { + 
return None; + } + + self.consumed += 1; + + if let Some(remaining) = &mut self.remaining { + *remaining = remaining.saturating_sub(1); + } + + let (key, value) = if self.deserializer.peek() == Token::ObjectEnd { + // we're not in bounded mode, which means we need to signal that we're done + self.remaining?; + + if self.remaining.is_some() { + let key = K::deserialize(DeserializerNone { + context: self.deserializer.context(), + }); + let value = V::deserialize(DeserializerNone { + context: self.deserializer.context(), + }); + + (key, value) + } else { + return None; + } + } else { + let key = K::deserialize(&mut *self.deserializer); + let value = V::deserialize(&mut *self.deserializer); + + (key, value) + }; + + let result = match (key, value) { + (Err(mut key), Err(value)) => { + key.extend_one(value); + + Err(key.change_context(ObjectAccessError)) + } + (Err(error), _) | (_, Err(error)) => Err(error.change_context(ObjectAccessError)), + (Ok(key), Ok(value)) => Ok((key, value)), + }; + + Some(result) + } + + fn size_hint(&self) -> Option { + self.length + } + + fn end(self) -> Result<(), ObjectAccessError> { + let mut result = Ok(()); + + // ensure that we consume the last token, if it is the wrong token error out + if self.deserializer.peek() != Token::ObjectEnd { + let mut error = Report::new(ObjectLengthError.into_error()) + .attach(ExpectedLength::new(self.consumed)); + + if let Some(length) = self.size_hint() { + error = error.attach(ReceivedLength::new(length)); + } + + result = Err(error); + } + + // bump until the very end, which ensures that deserialize calls after this might succeed! 
+ let bump = self + .scan_end() + .unwrap_or_else(|| self.deserializer.tape().remaining()); + self.deserializer.tape_mut().bump_n(bump); + + if let Some(remaining) = self.remaining { + if remaining > 0 { + let error = + Report::new(BoundedContractViolationError::EndRemainingItems.into_error()); + + match &mut result { + Err(result) => result.extend_one(error), + result => *result = Err(error), + } + } + } + + result.change_context(ObjectAccessError) + } +} diff --git a/packages/libs/deer/desert/src/tape.rs b/packages/libs/deer/desert/src/tape.rs new file mode 100644 index 00000000000..20ad97a2215 --- /dev/null +++ b/packages/libs/deer/desert/src/tape.rs @@ -0,0 +1,180 @@ +use core::{ + ops::{Deref, Range}, + slice::SliceIndex, +}; + +use bitvec::{ + boxed::BitBox, + order::Lsb0, + prelude::{BitSlice, BitVec}, + slice::BitSliceIndex, +}; + +use crate::token::Token; + +#[derive(Debug)] +enum Trivia<'a> { + Owned(BitBox), + Slice(&'a BitSlice), +} + +impl<'a> Deref for Trivia<'a> { + type Target = BitSlice; + + fn deref(&self) -> &Self::Target { + match self { + Trivia::Owned(value) => value.as_bitslice(), + Trivia::Slice(value) => value, + } + } +} + +impl<'a> Trivia<'a> { + fn to_mut(&mut self) -> &mut BitSlice { + match self { + Trivia::Owned(value) => value.as_mut_bitslice(), + Trivia::Slice(value) => { + let owned = BitBox::from_bitslice(*value); + *self = Self::Owned(owned); + + self.to_mut() + } + } + } +} + +#[derive(Debug)] +pub struct Tape<'a, 'de> { + tokens: &'de [Token], + trivia: Trivia<'a>, +} + +impl Tape<'_, '_> { + pub(crate) fn empty() -> Self { + Self { + tokens: &[], + trivia: Trivia::Slice(BitSlice::empty()), + } + } +} + +impl<'a, 'de> Tape<'a, 'de> { + // also includes trivia + fn peek_all_n(&self, n: usize) -> Option { + self.tokens.get(n).copied() + } + + fn is_trivia_n(&self, n: usize) -> Option { + self.trivia.get(n).as_deref().copied() + } + + /// ## Panics + /// + /// if range.start > range.end + pub(crate) fn set_trivia(&mut self, mut 
range: Range) { + // ensure that the start of the range is smaller than or equal to the end of the range + // doing this we can ensure that `0..1` is valid, but `1..0` is not. + assert!(range.start <= range.end); + + // automatically adjust so that we're able to always index to the end, even if the end + // is out of bounds + if range.end > self.tokens.len() { + range.end = self.tokens.len(); + } + + // we have already asserted that `range.start <= range.end`, therefore if range.start is out + // of bounds, range.end must be out of bounds as well, in that case we do not need to fill + // the slice, as `.get_mut` will return `None` + if range.start >= self.tokens.len() { + return; + } + + if let Some(slice) = self.trivia.to_mut().get_mut(range) { + slice.fill(true); + } + } + + pub(crate) fn peek_n(&self, n: usize) -> Option { + let mut offset = 0; + let mut m = 0; + + while m != n { + if !self.is_trivia_n(offset)? { + m += 1; + } + + offset += 1; + } + + self.peek_all_n(m) + } + + pub(crate) fn peek(&self) -> Option { + let mut n = 0; + + while self.is_trivia_n(n)? 
{ + n += 1; + } + + self.peek_all_n(n) + } + + fn bump(&mut self) -> Option<(Token, bool)> { + // naive version of bump, which just takes the token and returns it with the status + let (token, tokens) = self.tokens.split_first()?; + let is_trivia = *self.trivia.get(0)?; + // use trivia like a feed tape, this avoid reallocation + self.trivia.to_mut().shift_left(1); + self.tokens = tokens; + + Some((*token, is_trivia)) + } + + pub(crate) fn bump_n(&mut self, i: usize) { + for _ in 0..i { + self.bump(); + } + } + + pub(crate) fn next(&mut self) -> Option { + loop { + let (token, is_trivia) = self.bump()?; + + if !is_trivia { + return Some(token); + } + } + } + + pub(crate) fn remaining(&self) -> usize { + self.tokens.len() + } + + pub(crate) fn is_empty(&self) -> bool { + self.tokens.is_empty() + } + + pub(crate) fn view<'b, B>(&'b self, n: B) -> Option> + where + B: BitSliceIndex<'b, usize, Lsb0, Immut = &'b BitSlice> + + SliceIndex<[Token], Output = [Token]> + + Clone, + { + let tokens = self.tokens.get(n.clone())?; + let trivia = self.trivia.get(n)?; + + Some(Tape { + tokens, + trivia: Trivia::Slice(trivia), + }) + } +} + +impl<'de> From<&'de [Token]> for Tape<'_, 'de> { + fn from(value: &'de [Token]) -> Self { + Self { + tokens: value, + trivia: Trivia::Owned(BitVec::repeat(false, value.len()).into_boxed_bitslice()), + } + } +} diff --git a/packages/libs/deer/desert/src/token.rs b/packages/libs/deer/desert/src/token.rs new file mode 100644 index 00000000000..60351ee86bb --- /dev/null +++ b/packages/libs/deer/desert/src/token.rs @@ -0,0 +1,75 @@ +use core::fmt::{Debug, Display, Formatter}; + +use deer::Number; + +// TODO: test +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum Token { + /// A serialized `bool` + /// + /// ``` + /// # use error_stack::ResultExt; + /// use deer::{ + /// error::{DeserializeError, VisitorError}, + /// Deserialize, Deserializer, Document, Reflection, Schema, Visitor, + /// }; + /// use deer_desert::{assert_tokens, Token}; + /// + /// 
#[derive(Debug, PartialEq)] + /// struct Bool(bool); + /// + /// impl Reflection for Bool { + /// fn schema(_: &mut Document) -> Schema { + /// Schema::new("boolean") + /// } + /// } + /// + /// impl<'de> Deserialize<'de> for Bool { + /// type Reflection = Self; + /// + /// fn deserialize>(de: D) -> error_stack::Result { + /// struct BoolVisitor; + /// + /// impl<'de> Visitor<'de> for BoolVisitor { + /// type Value = Bool; + /// + /// fn expecting(&self) -> Document { + /// Bool::reflection() + /// } + /// + /// fn visit_bool(self, v: bool) -> error_stack::Result { + /// Ok(Bool(v)) + /// } + /// } + /// + /// de.deserialize_bool(BoolVisitor) + /// .change_context(DeserializeError) + /// } + /// } + /// + /// assert_tokens(&Bool(true), &[Token::Bool(true)]) + /// ``` + Bool(bool), + Number(&'static Number), + Char(char), + Str(&'static str), + BorrowedStr(&'static str), + String(&'static str), + Bytes(&'static [u8]), + BorrowedBytes(&'static [u8]), + BytesBuf(&'static [u8]), + Array { + length: Option, + }, + ArrayEnd, + Object { + length: Option, + }, + ObjectEnd, +} + +impl Display for Token { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + Debug::fmt(self, f) + } +} diff --git a/packages/libs/deer/json/src/error.rs b/packages/libs/deer/json/src/error.rs index 261db2642d4..04112b96a7b 100644 --- a/packages/libs/deer/json/src/error.rs +++ b/packages/libs/deer/json/src/error.rs @@ -61,33 +61,3 @@ impl Variant for OverflowError { Ok(()) } } - -#[derive(Debug)] -pub(crate) enum SetBoundedError { - Dirty, - CalledMultipleTimes, -} - -impl Display for SetBoundedError { - fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { - match self { - Self::Dirty => f.write_str("unable to set bounds after calling `.next()`"), - Self::CalledMultipleTimes => f.write_str("cannot call set_bounds() multiple times"), - } - } -} - -impl Variant for SetBoundedError { - type Properties = (Location,); - - const ID: Id = id!["internal", "access", "set_bounds"]; - const 
NAMESPACE: Namespace = NAMESPACE; - - fn message<'a>( - &self, - fmt: &mut Formatter, - _: &::Value<'a>, - ) -> core::fmt::Result { - Display::fmt(&self, fmt) - } -} diff --git a/packages/libs/deer/json/src/lib.rs b/packages/libs/deer/json/src/lib.rs index 7d7e2ae753d..04d8e573cf3 100644 --- a/packages/libs/deer/json/src/lib.rs +++ b/packages/libs/deer/json/src/lib.rs @@ -32,16 +32,17 @@ use std::any::Demand; use deer::{ error::{ - ArrayAccessError, ArrayLengthError, DeserializeError, DeserializerError, ExpectedLength, - ExpectedType, MissingError, ObjectAccessError, ObjectItemsExtraError, ReceivedKey, - ReceivedLength, ReceivedType, ReceivedValue, TypeError, ValueError, Variant, + ArrayAccessError, ArrayLengthError, BoundedContractViolationError, DeserializeError, + DeserializerError, ExpectedLength, ExpectedType, MissingError, ObjectAccessError, + ObjectItemsExtraError, ReceivedKey, ReceivedLength, ReceivedType, ReceivedValue, TypeError, + ValueError, Variant, }, Context, Deserialize, DeserializeOwned, Document, Reflection, Schema, Visitor, }; use error_stack::{IntoReport, Report, Result, ResultExt}; use serde_json::{Map, Value}; -use crate::error::{BytesUnsupportedError, OverflowError, SetBoundedError}; +use crate::error::{BytesUnsupportedError, OverflowError}; #[cfg(not(feature = "arbitrary-precision"))] fn serde_to_deer_number(number: &serde_json::Number) -> Option { @@ -399,15 +400,16 @@ impl<'a, 'de> deer::ArrayAccess<'de> for ArrayAccess<'a> { fn set_bounded(&mut self, length: usize) -> Result<(), ArrayAccessError> { if self.dirty { return Err( - Report::new(SetBoundedError::Dirty.into_error()).change_context(ArrayAccessError) + Report::new(BoundedContractViolationError::SetDirty.into_error()) + .change_context(ArrayAccessError), ); } if self.remaining.is_some() { - return Err( - Report::new(SetBoundedError::CalledMultipleTimes.into_error()) - .change_context(ArrayAccessError), - ); + return Err(Report::new( + 
BoundedContractViolationError::SetCalledMultipleTimes.into_error(), + ) + .change_context(ArrayAccessError)); } self.remaining = Some(length); @@ -453,6 +455,7 @@ impl<'a, 'de> deer::ArrayAccess<'de> for ArrayAccess<'a> { } fn end(self) -> Result<(), ArrayAccessError> { + // TODO: error if self.remaining isn't Some(0) or None let count = self.inner.count(); if count == 0 { Ok(()) @@ -492,15 +495,16 @@ impl<'a, 'de> deer::ObjectAccess<'de> for ObjectAccess<'a> { fn set_bounded(&mut self, length: usize) -> Result<(), ObjectAccessError> { if self.dirty { return Err( - Report::new(SetBoundedError::Dirty.into_error()).change_context(ObjectAccessError) + Report::new(BoundedContractViolationError::SetDirty.into_error()) + .change_context(ObjectAccessError), ); } if self.remaining.is_some() { - return Err( - Report::new(SetBoundedError::CalledMultipleTimes.into_error()) - .change_context(ObjectAccessError), - ); + return Err(Report::new( + BoundedContractViolationError::SetCalledMultipleTimes.into_error(), + ) + .change_context(ObjectAccessError)); } self.remaining = Some(length); diff --git a/packages/libs/deer/src/context.rs b/packages/libs/deer/src/context.rs index fd111cf68f9..09063e95c75 100644 --- a/packages/libs/deer/src/context.rs +++ b/packages/libs/deer/src/context.rs @@ -1,6 +1,7 @@ use alloc::{boxed::Box, collections::BTreeMap}; use core::any::{Any, TypeId}; +#[derive(Debug)] pub struct Context { inner: BTreeMap>, } diff --git a/packages/libs/deer/src/error/extra.rs b/packages/libs/deer/src/error/extra.rs index 0ff1bdcf14f..a0609f0315e 100644 --- a/packages/libs/deer/src/error/extra.rs +++ b/packages/libs/deer/src/error/extra.rs @@ -78,6 +78,52 @@ impl Display for ObjectItemsExtraError { } } +#[derive(Debug)] +pub struct ObjectLengthError; + +impl Variant for ObjectLengthError { + type Properties = (Location, ExpectedLength, ReceivedLength); + + const ID: Id = id!["object", "length"]; + const NAMESPACE: Namespace = NAMESPACE; + + fn message<'a>( + &self, + fmt: 
&mut Formatter, + properties: &::Value<'a>, + ) -> fmt::Result { + // expected object of length {expected}, but received object of length {received} + let (_, expected, received) = properties; + + let has_expected = expected.is_some(); + let has_received = received.is_some(); + + if let Some(ExpectedLength(length)) = expected { + fmt.write_fmt(format_args!("expected object of length {length}"))?; + } + + if has_expected && has_received { + fmt.write_str(", but ")?; + } + + if let Some(ReceivedLength(length)) = received { + fmt.write_fmt(format_args!("received object of length {length}"))?; + } + + if !has_expected && !has_received { + Display::fmt(self, fmt)?; + } + + Ok(()) + } +} + +impl Display for ObjectLengthError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str("received more items than expected") + } +} + #[derive(serde::Serialize)] pub struct ExpectedLength(usize); @@ -240,7 +286,65 @@ mod tests { } #[test] - fn object() { + fn object_length() { + // we simulate that the error happens in: + // [..., {field1: {_: _, _: _, _: _} <- here}] + let error = Report::new(Error::new(ObjectLengthError)) + .attach(Location::Field("field1")) + .attach(Location::Array(1)) + .attach(ExpectedLength::new(2)) + .attach(ReceivedLength::new(3)); + + let value = to_json::(&error); + + assert_eq!( + value, + json!({ + "location": [ + {"type": "array", "value": 1}, + {"type": "field", "value": "field1"} + ], + "expected": 2, + "received": 3 + }) + ); + } + + #[test] + fn object_length_message() { + assert_eq!( + to_message::(&Report::new(ObjectLengthError.into_error())), + "received more items than expected" + ); + + assert_eq!( + to_message::( + &Report::new(ObjectLengthError.into_error()) // + .attach(ReceivedLength::new(3)) + ), + "received object of length 3" + ); + + assert_eq!( + to_message::( + &Report::new(ObjectLengthError.into_error()) // + .attach(ExpectedLength::new(2)) + ), + "expected object of length 2" + ); + + assert_eq!( + to_message::( + 
&Report::new(ObjectLengthError.into_error()) + .attach(ExpectedLength::new(2)) + .attach(ReceivedLength::new(3)) + ), + "expected object of length 2, but received object of length 3" + ); + } + + #[test] + fn object_extra() { // we simulate that the error happens in: // [..., {field1: [...], field2: [...]} <- here] let error = Report::new(ObjectItemsExtraError.into_error()) @@ -261,7 +365,7 @@ mod tests { } #[test] - fn object_message() { + fn object_extra_message() { assert_eq!( to_message::(&Report::new(ObjectItemsExtraError.into_error())), "received unexpected keys" diff --git a/packages/libs/deer/src/error/internal.rs b/packages/libs/deer/src/error/internal.rs new file mode 100644 index 00000000000..700c899a98e --- /dev/null +++ b/packages/libs/deer/src/error/internal.rs @@ -0,0 +1,43 @@ +use core::fmt::{Display, Formatter}; + +use crate::{ + error::{ErrorProperties, Id, Location, Namespace, Variant, NAMESPACE}, + id, +}; + +// TODO: name set_size? +#[derive(Debug)] +pub enum BoundedContractViolationError { + SetDirty, + SetCalledMultipleTimes, + EndRemainingItems, +} + +impl Display for BoundedContractViolationError { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + match self { + Self::SetDirty => f.write_str("unable to set bounds after calling `.next()`"), + Self::SetCalledMultipleTimes => { + f.write_str("cannot call `set_bounded()` multiple times") + } + Self::EndRemainingItems => { + f.write_str("`.next()` was not called exactly `n` times before calling `.end()`") + } + } + } +} + +impl Variant for BoundedContractViolationError { + type Properties = (Location,); + + const ID: Id = id!["internal", "access", "bounded"]; + const NAMESPACE: Namespace = NAMESPACE; + + fn message<'a>( + &self, + fmt: &mut Formatter, + _: &::Value<'a>, + ) -> core::fmt::Result { + Display::fmt(&self, fmt) + } +} diff --git a/packages/libs/deer/src/error/mod.rs b/packages/libs/deer/src/error/mod.rs index a0ab28e490b..ea1d29b98f7 100644 --- 
a/packages/libs/deer/src/error/mod.rs +++ b/packages/libs/deer/src/error/mod.rs @@ -65,8 +65,10 @@ use core::{ use error_stack::{Context, Frame, IntoReport, Report, Result}; pub use extra::{ - ArrayLengthError, ExpectedLength, ObjectItemsExtraError, ReceivedKey, ReceivedLength, + ArrayLengthError, ExpectedLength, ObjectItemsExtraError, ObjectLengthError, ReceivedKey, + ReceivedLength, }; +pub use internal::BoundedContractViolationError; pub use location::Location; use serde::ser::SerializeMap; pub use r#type::{ExpectedType, ReceivedType, TypeError}; @@ -79,6 +81,7 @@ pub use value::{MissingError, ReceivedValue, ValueError}; use crate::error::serialize::{impl_serialize, Export}; mod extra; +mod internal; mod location; mod macros; mod serialize;