From a9a9b9bf243862bd2afbf6853fca97f30dc4f620 Mon Sep 17 00:00:00 2001 From: Christoph Otter Date: Tue, 23 Jan 2024 16:31:55 +0100 Subject: [PATCH] Add recursion limit --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/de/errors.rs | 4 ++++ src/de/mod.rs | 59 ++++++++++++++++++++++++++++++++++++++---------- src/lib.rs | 25 ++++++++++++++++++++ 5 files changed, 78 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a3386ade..bcac5e78 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -40,7 +40,7 @@ checksum = "ce31e24b01e1e524df96f1c2fdd054405f8d7376249a5110886fb4b658484789" [[package]] name = "serde-json-wasm" -version = "0.5.1" +version = "0.5.2" dependencies = [ "serde", "serde_derive", diff --git a/Cargo.toml b/Cargo.toml index a56dbfa5..9b5ad058 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ license = "MIT OR Apache-2.0" name = "serde-json-wasm" readme = "README.md" repository = "https://github.com/CosmWasm/serde-json-wasm" -version = "0.5.1" +version = "0.5.2" exclude = [ ".cargo/", ".github/", diff --git a/src/de/errors.rs b/src/de/errors.rs index 47876acf..0f9e640b 100644 --- a/src/de/errors.rs +++ b/src/de/errors.rs @@ -68,6 +68,9 @@ pub enum Error { /// JSON has a comma after the last value in an array or map. TrailingComma, + /// JSON is nested too deeply, exceeeded the recursion limit. + RecursionLimitExceeded, + /// Custom error message from serde Custom(String), } @@ -132,6 +135,7 @@ impl fmt::Display for Error { value." } Error::TrailingComma => "JSON has a comma after the last value in an array or map.", + Error::RecursionLimitExceeded => "JSON is nested too deeply, exceeeded the recursion limit.", Error::Custom(msg) => msg, } ) diff --git a/src/de/mod.rs b/src/de/mod.rs index 1c6fd50d..bbeb0edb 100644 --- a/src/de/mod.rs +++ b/src/de/mod.rs @@ -20,6 +20,9 @@ use std::str::from_utf8; pub struct Deserializer<'b> { slice: &'b [u8], index: usize, + + /// Remaining depth until we hit the recursion limit + remaining_depth: u8, } enum StringLike<'a> { @@ -29,7 +32,11 @@ enum StringLike<'a> { impl<'a> Deserializer<'a> { fn new(slice: &'a [u8]) -> Deserializer<'_> { - Deserializer { slice, index: 0 } + Deserializer { + slice, + index: 0, + remaining_depth: 128, + } } fn eat_char(&mut self) { @@ -286,16 +293,22 @@ impl<'a, 'de> de::Deserializer<'de> for &'a mut Deserializer<'de> { } } b'[' => { - self.eat_char(); - let ret = visitor.visit_seq(SeqAccess::new(self))?; + check_recursion! { + self.eat_char(); + let ret = visitor.visit_seq(SeqAccess::new(self)); + } + let ret = ret?; self.end_seq()?; Ok(ret) } b'{' => { - self.eat_char(); - let ret = visitor.visit_map(MapAccess::new(self))?; + check_recursion! { + self.eat_char(); + let ret = visitor.visit_map(MapAccess::new(self)); + } + let ret = ret?; self.end_map()?; @@ -548,8 +561,11 @@ impl<'a, 'de> de::Deserializer<'de> for &'a mut Deserializer<'de> { { match self.parse_whitespace().ok_or(Error::EofWhileParsingValue)? { b'[' => { - self.eat_char(); - let ret = visitor.visit_seq(SeqAccess::new(self))?; + check_recursion! { + self.eat_char(); + let ret = visitor.visit_seq(SeqAccess::new(self)); + } + let ret = ret?; self.end_seq()?; @@ -585,9 +601,11 @@ impl<'a, 'de> de::Deserializer<'de> for &'a mut Deserializer<'de> { let peek = self.parse_whitespace().ok_or(Error::EofWhileParsingValue)?; if peek == b'{' { - self.eat_char(); - - let ret = visitor.visit_map(MapAccess::new(self))?; + check_recursion! { + self.eat_char(); + let ret = visitor.visit_map(MapAccess::new(self)); + } + let ret = ret?; self.end_map()?; @@ -623,8 +641,11 @@ impl<'a, 'de> de::Deserializer<'de> for &'a mut Deserializer<'de> { b'"' => visitor.visit_enum(UnitVariantAccess::new(self)), // if it is a struct enum b'{' => { - self.eat_char(); - visitor.visit_enum(StructVariantAccess::new(self)) + check_recursion! { + self.eat_char(); + let value = visitor.visit_enum(StructVariantAccess::new(self)); + } + value } _ => Err(Error::ExpectedSomeIdent), } @@ -684,6 +705,20 @@ where from_slice(s.as_bytes()) } +macro_rules! check_recursion { + ($this:ident $($body:tt)*) => { + $this.remaining_depth -= 1; + if $this.remaining_depth == 0 { + return Err($crate::de::Error::RecursionLimitExceeded); + } + + $this $($body)* + + $this.remaining_depth += 1; + }; +} +pub(crate) use check_recursion; + #[cfg(test)] mod tests { use super::from_str; diff --git a/src/lib.rs b/src/lib.rs index 4ce4e51e..4d377163 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -214,4 +214,29 @@ mod test { item ); } + + #[test] + fn no_stack_overflow() { + const AMOUNT: usize = 2000; + let mut json = String::from(r#"{"":"#); + + #[derive(Debug, Deserialize, Serialize)] + pub struct Person { + name: String, + age: u8, + phones: Vec, + } + + for _ in 0..AMOUNT { + json.push('['); + } + for _ in 0..AMOUNT { + json.push(']'); + } + + json.push_str(r#"] }[[[[[[[[[[[[[[[[[[[[[ ""","age":35,"phones":["#); + + let err = from_str::(&json).unwrap_err(); + assert_eq!(err, crate::de::Error::RecursionLimitExceeded); + } }