Skip to content

Commit

Permalink
Add recursion limit checks
Browse files Browse the repository at this point in the history
Signed-off-by: Nick Cameron <nrc@ncameron.org>
  • Loading branch information
nrc committed May 23, 2019
1 parent fd7abeb commit 361957d
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 2 deletions.
67 changes: 66 additions & 1 deletion src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,70 @@ where
}

#[derive(Clone, Debug)]
pub struct DecodeContext {}
pub struct DecodeContext {
#[cfg(feature = "recursion-limit")]
recurse_count: u32,
}

pub(crate) struct RecursionGuard {
#[cfg(feature = "recursion-limit")]
ctx: *mut DecodeContext,
}

impl Default for DecodeContext {
#[cfg(feature = "recursion-limit")]
fn default() -> DecodeContext {
DecodeContext {
recurse_count: crate::RECURSION_LIMIT,
}
}
#[cfg(not(feature = "recursion-limit"))]
fn default() -> DecodeContext {
DecodeContext {}
}
}

#[cfg(feature = "recursion-limit")]
impl Drop for RecursionGuard {
fn drop(&mut self) {
unsafe {
(*self.ctx).recurse_count += 1;
}
}
}

impl DecodeContext {
#[cfg(feature = "recursion-limit")]
pub(crate) fn enter_recursion(&mut self) -> RecursionGuard {
self.recurse_count -= 1;
RecursionGuard { ctx: self }
}

#[cfg(not(feature = "recursion-limit"))]
#[inline(always)]
pub(crate) fn enter_recursion(&mut self) -> RecursionGuard {
RecursionGuard {}
}
}

impl RecursionGuard {
#[cfg(feature = "recursion-limit")]
pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> {
unsafe {
if (*self.ctx).recurse_count == 0 {
Err(DecodeError::new("Recursion limit reached"))
} else {
Ok(())
}
}
}

#[cfg(not(feature = "recursion-limit"))]
#[inline(always)]
pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> {
Ok(())
}
}
/// Decodes a LEB128-encoded variable length integer from the slice, returning the value and the
/// number of bytes read.
///
Expand Down Expand Up @@ -375,6 +430,8 @@ macro_rules! merge_repeated_numeric {
{
if wire_type == WireType::LengthDelimited {
// Packed.
let recursion_guard = ctx.enter_recursion();
recursion_guard.limit_reached()?;
merge_loop(values, buf, ctx, |values, buf, ctx| {
let mut value = Default::default();
$merge($wire_type, &mut value, buf, ctx)?;
Expand All @@ -385,6 +442,8 @@ macro_rules! merge_repeated_numeric {
// Unpacked.
check_wire_type($wire_type, wire_type)?;
let mut value = Default::default();
let recursion_guard = ctx.enter_recursion();
recursion_guard.limit_reached()?;
$merge(wire_type, &mut value, buf, ctx)?;
values.push(value);
Ok(())
Expand Down Expand Up @@ -835,6 +894,8 @@ pub mod message {
B: Buf,
{
check_wire_type(WireType::LengthDelimited, wire_type)?;
let recursion_guard = ctx.enter_recursion();
recursion_guard.limit_reached()?;
merge_loop(msg, buf, ctx, |msg: &mut M, buf: &mut B, ctx| {
let (tag, wire_type) = decode_key(buf)?;
msg.merge_field(tag, wire_type, buf, ctx)
Expand Down Expand Up @@ -917,6 +978,8 @@ pub mod group {
{
check_wire_type(WireType::StartGroup, wire_type)?;

let recursion_guard = ctx.enter_recursion();
recursion_guard.limit_reached()?;
loop {
let (field_tag, field_wire_type) = decode_key(buf)?;
if field_wire_type == WireType::EndGroup {
Expand Down Expand Up @@ -1108,6 +1171,8 @@ macro_rules! map {
{
let mut key = Default::default();
let mut val = val_default;
let recursion_guard = ctx.enter_recursion();
recursion_guard.limit_reached()?;
merge_loop(
&mut (&mut key, &mut val),
buf,
Expand Down
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ use bytes::{BufMut, IntoBuf};

use crate::encoding::{decode_varint, encode_varint, encoded_len_varint};

// 100 is the default recursion limit in the C++ implementation.
#[cfg(feature = "recursion-limit")]
const RECURSION_LIMIT: u32 = 128;
const RECURSION_LIMIT: u32 = 100;

/// Encodes a length delimiter to the buffer.
///
Expand Down

0 comments on commit 361957d

Please sign in to comment.