Skip to content
This repository has been archived by the owner on Mar 25, 2024. It is now read-only.

Commit

Permalink
Merge pull request #105 from xfix/prevent-too-deep-recursion
Browse files Browse the repository at this point in the history
Prevent too deep recursion
  • Loading branch information
dtolnay authored Sep 15, 2018
2 parents 5911699 + b93aff6 commit 41d5823
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 8 deletions.
32 changes: 24 additions & 8 deletions src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ struct Deserializer<'a> {
aliases: &'a BTreeMap<usize, usize>,
pos: &'a mut usize,
path: Path<'a>,
remaining_depth: u8,
}

impl<'a> Deserializer<'a> {
Expand Down Expand Up @@ -109,6 +110,7 @@ impl<'a> Deserializer<'a> {
aliases: self.aliases,
pos: pos,
path: Path::Alias { parent: &self.path },
remaining_depth: self.remaining_depth,
})
}
None => panic!("unresolved alias: {}", *pos),
Expand Down Expand Up @@ -161,11 +163,11 @@ impl<'a> Deserializer<'a> {
where
V: Visitor<'de>,
{
let (value, len) = {
let mut seq = SeqAccess { de: self, len: 0 };
let (value, len) = self.recursion_check(|de| {
let mut seq = SeqAccess { de: de, len: 0 };
let value = visitor.visit_seq(&mut seq)?;
(value, seq.len)
};
Ok((value, seq.len))
})?;
self.end_sequence(len)?;
Ok(value)
}
Expand All @@ -174,15 +176,15 @@ impl<'a> Deserializer<'a> {
where
V: Visitor<'de>,
{
let (value, len) = {
let (value, len) = self.recursion_check(|de| {
let mut map = MapAccess {
de: &mut *self,
de: de,
len: 0,
key: None,
};
let value = visitor.visit_map(&mut map)?;
(value, map.len)
};
Ok((value, map.len))
})?;
self.end_mapping(len)?;
Ok(value)
}
Expand Down Expand Up @@ -238,6 +240,16 @@ impl<'a> Deserializer<'a> {
Err(de::Error::invalid_length(total, &ExpectedMap(len)))
}
}

fn recursion_check<F: FnOnce(&mut Self) -> Result<T>, T>(&mut self, f: F) -> Result<T> {
let previous_depth = self.remaining_depth;
self.remaining_depth = previous_depth
.checked_sub(1)
.ok_or_else(Error::recursion_limit_exceeded)?;
let result = f(self);
self.remaining_depth = previous_depth;
result
}
}

fn visit_scalar<'de, V>(
Expand Down Expand Up @@ -303,6 +315,7 @@ impl<'de, 'a, 'r> de::SeqAccess<'de> for SeqAccess<'a, 'r> {
parent: &self.de.path,
index: self.len,
},
remaining_depth: self.de.remaining_depth,
};
self.len += 1;
seed.deserialize(&mut element_de).map(Some)
Expand Down Expand Up @@ -357,6 +370,7 @@ impl<'de, 'a, 'r> de::MapAccess<'de> for MapAccess<'a, 'r> {
parent: &self.de.path,
}
},
remaining_depth: self.de.remaining_depth,
};
seed.deserialize(&mut value_de)
}
Expand Down Expand Up @@ -409,6 +423,7 @@ impl<'de, 'a, 'r> de::EnumAccess<'de> for EnumAccess<'a, 'r> {
parent: &self.de.path,
key: variant,
},
remaining_depth: self.de.remaining_depth,
};
Ok((ret, variant_visitor))
}
Expand Down Expand Up @@ -949,6 +964,7 @@ where
aliases: &loader.aliases,
pos: &mut pos,
path: Path::Root,
remaining_depth: 128,
})?;
if pos == loader.events.len() {
Ok(t)
Expand Down
12 changes: 12 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ pub enum ErrorImpl {

EndOfStream,
MoreThanOneDocument,
RecursionLimitExceeded,
}

#[derive(Debug)]
Expand Down Expand Up @@ -157,6 +158,12 @@ impl Error {
Error(Box::new(ErrorImpl::FromUtf8(err)))
}

// Not public API. Should be pub(crate).
#[doc(hidden)]
pub fn recursion_limit_exceeded() -> Error {
Error(Box::new(ErrorImpl::RecursionLimitExceeded))
}

// Not public API. Should be pub(crate).
#[doc(hidden)]
pub fn fix_marker(mut self, marker: Marker, path: Path) -> Self {
Expand All @@ -183,6 +190,7 @@ impl error::Error for Error {
ErrorImpl::MoreThanOneDocument => {
"deserializing from YAML containing more than one document is not supported"
}
ErrorImpl::RecursionLimitExceeded => "recursion limit exceeded",
}
}

Expand Down Expand Up @@ -218,6 +226,7 @@ impl Display for Error {
ErrorImpl::MoreThanOneDocument => f.write_str(
"deserializing from YAML containing more than one document is not supported",
),
ErrorImpl::RecursionLimitExceeded => f.write_str("recursion limit exceeded"),
}
}
}
Expand All @@ -241,6 +250,9 @@ impl Debug for Error {
}
ErrorImpl::EndOfStream => formatter.debug_tuple("EndOfStream").finish(),
ErrorImpl::MoreThanOneDocument => formatter.debug_tuple("MoreThanOneDocument").finish(),
ErrorImpl::RecursionLimitExceeded => {
formatter.debug_tuple("RecursionLimitExceeded").finish()
}
}
}
}
Expand Down
48 changes: 48 additions & 0 deletions tests/test_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,51 @@ fn test_invalid_scalar_type() {
let expected = "x: invalid type: unit value, expected an array of length 1 at line 2 column 1";
test_error::<S>(yaml, expected);
}

#[test]
fn test_infinite_recursion_objects() {
#[derive(Deserialize, Debug)]
struct S {
x: Option<Box<S>>,
}

let yaml = "&a {x: *a}";
let expected = "recursion limit exceeded";
test_error::<S>(yaml, expected);
}

#[test]
fn test_infinite_recursion_arrays() {
#[derive(Deserialize, Debug)]
struct S {
x: Option<Box<S>>,
}

let yaml = "&a [*a]";
let expected = "recursion limit exceeded";
test_error::<S>(yaml, expected);
}

#[test]
fn test_finite_recursion_objects() {
#[derive(Deserialize, Debug)]
struct S {
x: Option<Box<S>>,
}

let yaml = "{x:".repeat(1_000) + &"}".repeat(1_000);
let expected = "recursion limit exceeded";
test_error::<i32>(&yaml, expected);
}

#[test]
fn test_finite_recursion_arrays() {
#[derive(Deserialize, Debug)]
struct S {
x: Option<Box<S>>,
}

let yaml = "[".repeat(1_000) + &"]".repeat(1_000);
let expected = "recursion limit exceeded";
test_error::<S>(&yaml, expected);
}

0 comments on commit 41d5823

Please sign in to comment.