Skip to content

Commit

Permalink
Document the recursion limit
Browse files Browse the repository at this point in the history
Signed-off-by: Nick Cameron <nrc@ncameron.org>
  • Loading branch information
nrc committed May 23, 2019
1 parent 7e9eea4 commit 0582d99
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 14 deletions.
70 changes: 56 additions & 14 deletions src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,41 +77,41 @@ where
}
}

/// Additional information passed to every decode/merge function.
#[derive(Clone, Debug)]
pub struct DecodeContext {
#[cfg(feature = "recursion-limit")]
/// How many times we can recurse in the current decode stack before we hit
/// the recursion limit.
///
/// The recursion limit is defined by `RECURSION_LIMIT` and cannot be
/// customized. The recursion limit can be ignored by building the Prost
/// crate without the `recursion-limit` feature (which is set by default).
recurse_count: u32,
}

pub(crate) struct RecursionGuard {
#[cfg(feature = "recursion-limit")]
ctx: *mut DecodeContext,
}

impl Default for DecodeContext {
#[cfg(feature = "recursion-limit")]
fn default() -> DecodeContext {
DecodeContext {
recurse_count: crate::RECURSION_LIMIT,
}
}

#[cfg(not(feature = "recursion-limit"))]
fn default() -> DecodeContext {
DecodeContext {}
}
}

#[cfg(feature = "recursion-limit")]
impl Drop for RecursionGuard {
fn drop(&mut self) {
unsafe {
(*self.ctx).recurse_count += 1;
}
}
}

impl DecodeContext {
#[cfg(feature = "recursion-limit")]
/// Call this function before recursively decoding.
///
/// This function returns a guard object which will automatically restore
/// the recursion counter when it is destroyed by going out of scope.
///
/// See the safety note on `RecursionGuard` for important information.
pub(crate) fn enter_recursion(&mut self) -> RecursionGuard {
self.recurse_count -= 1;
RecursionGuard { ctx: self }
Expand All @@ -124,8 +124,49 @@ impl DecodeContext {
}
}

/// RAII guard created by `DecodeContext::enter_recursion` to ensure recursion is
/// tracked correctly.
///
/// ## Safety note
///
/// This object uses a raw pointer and unsafe code to avoid dynamic ownership
/// checking of it's reference to a `DecodeContext`. This could be implemented
/// using `Rc` and `RefCell` or closures, but that could be expensive since we
/// might expect decoding to be in a program's hot path.
///
/// Usage is safe under normal usage patterns:
///
/// ```ignore
/// fn foo(..., ctx: &mut DecodeContext) {
/// let _guard = ctx.enter_recursion(); // `_guard` keeps a mutable reference to ctx.
/// some_recusive_fn(..., ctx); // `ctx` is passed to `some_recusive_fn`.
/// }
/// ```
///
/// In the above scenario, `ctx.recurse_count` must be the same before and after
/// the call to `some_recusive_fn`. In particular, it must still be valid memory.
/// `_guard` should not be passed out of `foo` nor should it be stored.
pub(crate) struct RecursionGuard {
#[cfg(feature = "recursion-limit")]
ctx: *mut DecodeContext,
}

#[cfg(feature = "recursion-limit")]
impl Drop for RecursionGuard {
fn drop(&mut self) {
unsafe {
(*self.ctx).recurse_count += 1;
}
}
}

impl RecursionGuard {
#[cfg(feature = "recursion-limit")]
/// Checks whether the recursion limit has been reached in the stack of
/// decodes described by the `DecodeContext` at `self.ctx`.
///
/// Returns `Ok<()>` if it is ok to continue recursing.
/// Returns `Err<DecodeError>` if the recursion limit has been reached.
pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> {
unsafe {
if (*self.ctx).recurse_count == 0 {
Expand All @@ -142,6 +183,7 @@ impl RecursionGuard {
Ok(())
}
}

/// Decodes a LEB128-encoded variable length integer from the slice, returning the value and the
/// number of bytes read.
///
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use bytes::{BufMut, IntoBuf};

use crate::encoding::{decode_varint, encode_varint, encoded_len_varint};

// See `encoding::DecodeContext` for more info.
// 100 is the default recursion limit in the C++ implementation.
#[cfg(feature = "recursion-limit")]
const RECURSION_LIMIT: u32 = 100;
Expand Down

0 comments on commit 0582d99

Please sign in to comment.