Avoid stack overflows via configurable with_recursion_limit
#764
@@ -37,6 +37,7 @@ use crate::tokenizer::*;
pub enum ParserError {
    TokenizerError(String),
    ParserError(String),
    RecursionLimitExceeded,
}

// Use `Parser::expected` instead, if possible

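Since `RecursionLimitExceeded` is an ordinary enum variant, downstream code can match on it like any other `ParserError`. As a rough sketch (not part of this PR: `parse_with_limit` and its limit values are made up for illustration, and it uses the builder API introduced further down in this diff):

```rust
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::{Parser, ParserError};

// Hypothetical helper: parse with a caller-chosen recursion limit and
// translate the new error variant into a friendlier message.
fn parse_with_limit(sql: &str, limit: usize) -> Result<usize, String> {
    let dialect = GenericDialect {};
    match Parser::new(&dialect)
        .with_recursion_limit(limit)
        .try_with_sql(sql)
        .and_then(|mut parser| parser.parse_statements())
    {
        Ok(statements) => Ok(statements.len()),
        // The variant added by this PR: the query nested deeper than `limit`
        Err(ParserError::RecursionLimitExceeded) => {
            Err(format!("query too deeply nested (limit {limit})"))
        }
        Err(other) => Err(other.to_string()),
    }
}

fn main() {
    // A flat query parses fine with the default-sized limit.
    assert_eq!(parse_with_limit("SELECT 1", 50), Ok(1));
    // The nested query from the `with_recursion_limit` doctest below trips a limit of 1.
    assert!(parse_with_limit("SELECT * FROM foo WHERE (a OR (b OR (c OR d)))", 1).is_err());
}
```
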
@@ -55,6 +56,92 @@ macro_rules! return_ok_if_some {
    }};
}

#[cfg(feature = "std")]
/// Implementation of [`RecursionCounter`] when std is available
mod recursion {
    use core::sync::atomic::{AtomicUsize, Ordering};
    use std::rc::Rc;

    use super::ParserError;

    /// Tracks remaining recursion depth. This value is decremented on
    /// each call to `try_decrease()`; when it reaches 0 an error will
    /// be returned.
    ///
    /// Note: Uses an `Rc` and `AtomicUsize` in order to satisfy the Rust
    /// borrow checker so the automatic [`DepthGuard`] can hold a
    /// reference to the counter. The actual value is not modified
    /// concurrently.
    pub(crate) struct RecursionCounter {
        remaining_depth: Rc<AtomicUsize>,
    }

[Review comment] The Rc approach is from @46bit -- if anyone has ideas about how to avoid it, I would love a PR to help.

    impl RecursionCounter {
        /// Creates a [`RecursionCounter`] with the specified maximum
        /// depth
        pub fn new(remaining_depth: usize) -> Self {
            Self {
                remaining_depth: Rc::new(remaining_depth.into()),
            }
        }

        /// Decreases the remaining depth by 1.
        ///
        /// Returns `Err` if the remaining depth falls to 0.
        ///
        /// Returns a [`DepthGuard`] which adds 1 to the
        /// remaining depth upon drop.
        pub fn try_decrease(&self) -> Result<DepthGuard, ParserError> {
            let old_value = self.remaining_depth.fetch_sub(1, Ordering::SeqCst);
            // ran out of space
            if old_value == 0 {
                Err(ParserError::RecursionLimitExceeded)
            } else {
                Ok(DepthGuard::new(Rc::clone(&self.remaining_depth)))
            }
        }
    }

    /// Guard that increases the remaining depth by 1 on drop
    pub struct DepthGuard {
        remaining_depth: Rc<AtomicUsize>,
    }

    impl DepthGuard {
        fn new(remaining_depth: Rc<AtomicUsize>) -> Self {
            Self { remaining_depth }
        }
    }
    impl Drop for DepthGuard {
        fn drop(&mut self) {
            self.remaining_depth.fetch_add(1, Ordering::SeqCst);
        }
    }
}

#[cfg(not(feature = "std"))]
mod recursion {
    /// Implementation of [`RecursionCounter`] when std is NOT available (and does not
    /// guard against stack overflow).
    ///
    /// Has the same API as the std RecursionCounter implementation
    /// but does not actually limit stack depth.
    pub(crate) struct RecursionCounter {}

    impl RecursionCounter {
        pub fn new(_remaining_depth: usize) -> Self {
            Self {}
        }
        pub fn try_decrease(&self) -> Result<DepthGuard, super::ParserError> {
            Ok(DepthGuard {})
        }
    }

    pub struct DepthGuard {}
}

use recursion::RecursionCounter;

#[derive(PartialEq, Eq)]
pub enum IsOptional {
    Optional,

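To make the guard behaviour concrete: because a [`DepthGuard`] gives the budget back when it is dropped, sibling calls at the same depth each see the full remaining depth rather than sharing it, and only the currently active chain of calls counts toward the limit. Here is a minimal standalone sketch of the same drop-guard pattern (illustrative types only, not the parser's `RecursionCounter`; it uses `Cell` instead of `AtomicUsize` and checks before decrementing, purely for brevity):

```rust
use std::cell::Cell;
use std::rc::Rc;

// Stand-ins for RecursionCounter / DepthGuard: a shared budget that is
// decremented on entry and restored when the guard goes out of scope.
struct Budget(Rc<Cell<usize>>);
struct Guard(Rc<Cell<usize>>);

impl Budget {
    fn try_decrease(&self) -> Result<Guard, &'static str> {
        let old = self.0.get();
        if old == 0 {
            return Err("recursion limit exceeded");
        }
        self.0.set(old - 1);
        Ok(Guard(Rc::clone(&self.0)))
    }
}

impl Drop for Guard {
    // Give the budget back so sibling calls are not penalized for
    // recursion that has already unwound.
    fn drop(&mut self) {
        self.0.set(self.0.get() + 1);
    }
}

fn descend(budget: &Budget, depth: usize) -> Result<(), &'static str> {
    let _guard = budget.try_decrease()?; // held until this call returns
    if depth > 0 {
        descend(budget, depth - 1)?;
    }
    Ok(())
}

fn main() {
    let budget = Budget(Rc::new(Cell::new(3)));
    assert!(descend(&budget, 2).is_ok()); // three nested calls fit in a budget of 3
    assert!(descend(&budget, 2).is_ok()); // the budget was restored, so this fits too
    assert!(descend(&budget, 10).is_err()); // too deep: an error instead of a stack overflow
}
```
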
@@ -96,6 +183,7 @@ impl fmt::Display for ParserError {
            match self {
                ParserError::TokenizerError(s) => s,
                ParserError::ParserError(s) => s,
                ParserError::RecursionLimitExceeded => "recursion limit exceeded",
            }
        )
    }

@@ -104,22 +192,78 @@ impl fmt::Display for ParserError {
#[cfg(feature = "std")]
impl std::error::Error for ParserError {}

// By default, allow expressions up to this deep before erroring
const DEFAULT_REMAINING_DEPTH: usize = 50;

pub struct Parser<'a> {
    tokens: Vec<TokenWithLocation>,
    /// The index of the first unprocessed token in `self.tokens`
    index: usize,
    /// The current dialect to use
    dialect: &'a dyn Dialect,
    /// ensure the stack does not overflow by limiting recursion depth
    recursion_counter: RecursionCounter,
}

impl<'a> Parser<'a> {
    /// Parse the specified tokens
    /// To avoid breaking backwards compatibility, this function accepts
    /// bare tokens.
    pub fn new(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
        Parser::new_without_locations(tokens, dialect)
    /// Create a parser for a [`Dialect`]
    ///
    /// See also [`Parser::parse_sql`]
    ///
    /// Example:
    /// ```
    /// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
    /// # fn main() -> Result<(), ParserError> {
    /// let dialect = GenericDialect{};
    /// let statements = Parser::new(&dialect)
    ///     .try_with_sql("SELECT * FROM foo")?
    ///     .parse_statements()?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new(dialect: &'a dyn Dialect) -> Self {

[Review comment] the signature here has changed - I tried to illustrate the intended usage with doc comments

        Self {
            tokens: vec![],
            index: 0,
            dialect,
            recursion_counter: RecursionCounter::new(DEFAULT_REMAINING_DEPTH),
        }
    }

    /// Specify the maximum recursion limit while parsing.
    ///
    /// [`Parser`] prevents stack overflows by returning
    /// [`ParserError::RecursionLimitExceeded`] if the parser exceeds
    /// this depth while processing the query.
    ///
    /// Example:
    /// ```
    /// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
    /// # fn main() -> Result<(), ParserError> {
    /// let dialect = GenericDialect{};
    /// let result = Parser::new(&dialect)
    ///     .with_recursion_limit(1)
    ///     .try_with_sql("SELECT * FROM foo WHERE (a OR (b OR (c OR d)))")?
    ///     .parse_statements();
    /// assert_eq!(result, Err(ParserError::RecursionLimitExceeded));
    /// # Ok(())
    /// # }
    /// ```
    pub fn with_recursion_limit(mut self, recursion_limit: usize) -> Self {
        self.recursion_counter = RecursionCounter::new(recursion_limit);
        self
    }

    /// Reset this parser to parse the specified token stream
    pub fn with_tokens_with_locations(mut self, tokens: Vec<TokenWithLocation>) -> Self {
        self.tokens = tokens;
        self.index = 0;
        self
    }

    pub fn new_without_locations(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
    /// Reset this parser state to parse the specified tokens
    pub fn with_tokens(self, tokens: Vec<Token>) -> Self {
        // Put in dummy locations
        let tokens_with_locations: Vec<TokenWithLocation> = tokens
            .into_iter()

@@ -128,49 +272,84 @@ impl<'a> Parser<'a> {
                location: Location { line: 0, column: 0 },
            })
            .collect();
        Parser::new_with_locations(tokens_with_locations, dialect)
        self.with_tokens_with_locations(tokens_with_locations)
    }

    /// Parse the specified tokens
    pub fn new_with_locations(tokens: Vec<TokenWithLocation>, dialect: &'a dyn Dialect) -> Self {
        Parser {
            tokens,
            index: 0,
            dialect,
        }
    /// Tokenize the sql string and set this [`Parser`]'s state to
    /// parse the resulting tokens
    ///
    /// Returns an error if there was an error tokenizing the SQL string.
    ///
    /// See the example on [`Parser::new()`]
    pub fn try_with_sql(self, sql: &str) -> Result<Self, ParserError> {
        debug!("Parsing sql '{}'...", sql);
        let mut tokenizer = Tokenizer::new(self.dialect, sql);
        let tokens = tokenizer.tokenize()?;
        Ok(self.with_tokens(tokens))
    }

    /// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
    pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {

[Review comment] github mangles the diff -- this function still exists with the same signature. It now also has docstrings

        let mut tokenizer = Tokenizer::new(dialect, sql);
        let tokens = tokenizer.tokenize()?;
        let mut parser = Parser::new(tokens, dialect);
    /// Parse potentially multiple statements
    ///
    /// Example
    /// ```
    /// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
    /// # fn main() -> Result<(), ParserError> {
    /// let dialect = GenericDialect{};
    /// let statements = Parser::new(&dialect)
    ///     // Parse a SQL string with 2 separate statements
    ///     .try_with_sql("SELECT * FROM foo; SELECT * FROM bar;")?
    ///     .parse_statements()?;
    /// assert_eq!(statements.len(), 2);
    /// # Ok(())
    /// # }
    /// ```
    pub fn parse_statements(&mut self) -> Result<Vec<Statement>, ParserError> {

[Review comment] This code was factored out of `parse_sql`

        let mut stmts = Vec::new();
        let mut expecting_statement_delimiter = false;
        debug!("Parsing sql '{}'...", sql);
        loop {
            // ignore empty statements (between successive statement delimiters)
            while parser.consume_token(&Token::SemiColon) {
            while self.consume_token(&Token::SemiColon) {
                expecting_statement_delimiter = false;
            }

            if parser.peek_token() == Token::EOF {
            if self.peek_token() == Token::EOF {
                break;
            }
            if expecting_statement_delimiter {
                return parser.expected("end of statement", parser.peek_token());
                return self.expected("end of statement", self.peek_token());
            }

            let statement = parser.parse_statement()?;
            let statement = self.parse_statement()?;
            stmts.push(statement);
            expecting_statement_delimiter = true;
        }
        Ok(stmts)
    }

    /// Convenience method to parse a string with one or more SQL
    /// statements into an Abstract Syntax Tree (AST).
    ///
    /// Example
    /// ```
    /// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
    /// # fn main() -> Result<(), ParserError> {
    /// let dialect = GenericDialect{};
    /// let statements = Parser::parse_sql(
    ///     &dialect, "SELECT * FROM foo"
    /// )?;
    /// assert_eq!(statements.len(), 1);
    /// # Ok(())
    /// # }
    /// ```
    pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
        Parser::new(dialect).try_with_sql(sql)?.parse_statements()
    }

    /// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.),
    /// stopping before the statement separator, if any.
    pub fn parse_statement(&mut self) -> Result<Statement, ParserError> {
        let _guard = self.recursion_counter.try_decrease()?;

        // allow the dialect to override statement parsing
        if let Some(statement) = self.dialect.parse_statement(self) {
            return statement;

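A side note on the `let _guard = ...` lines added in this and the following hunks: binding the returned [`DepthGuard`] to a named `_`-prefixed variable keeps it alive until the end of the function, so the depth is only restored once the recursive call has fully unwound; `let _ = ...` would drop the guard immediately. A tiny sketch of that general Rust behaviour (unrelated illustrative types):

```rust
struct Noisy(&'static str);

impl Drop for Noisy {
    fn drop(&mut self) {
        println!("dropped: {}", self.0);
    }
}

fn main() {
    // `_guard` is a real binding: the value lives to the end of the scope,
    // which is what parse_statement / parse_expr / parse_query rely on.
    let _guard = Noisy("held for the whole block");

    // `_` is a wildcard pattern: the value is dropped immediately,
    // which would hand the recursion budget back too early.
    let _ = Noisy("dropped right away");

    println!("end of block");
    // Printed order:
    //   dropped: dropped right away
    //   end of block
    //   dropped: held for the whole block
}
```
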
@@ -364,6 +543,7 @@ impl<'a> Parser<'a> {

    /// Parse a new expression
    pub fn parse_expr(&mut self) -> Result<Expr, ParserError> {
        let _guard = self.recursion_counter.try_decrease()?;
        self.parse_subexpr(0)
    }

@@ -4454,6 +4634,7 @@ impl<'a> Parser<'a> {
    /// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't
    /// expect the initial keyword to be already consumed
    pub fn parse_query(&mut self) -> Result<Query, ParserError> {
        let _guard = self.recursion_counter.try_decrease()?;
        let with = if self.parse_keyword(Keyword::WITH) {
            Some(With {
                recursive: self.parse_keyword(Keyword::RECURSIVE),

[Review comment] This is the key change for actually limiting recursion depth
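
For an end-to-end view of the three guarded entry points, here is a sketch of the kind of check one could run against this branch (the nesting depth and the custom limit of 10 are illustrative values, not from the PR):

```rust
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::{Parser, ParserError};

fn main() -> Result<(), ParserError> {
    let dialect = GenericDialect {};

    // Build a pathologically nested expression: ((((...1...))))
    let depth = 1000;
    let sql = format!(
        "SELECT * FROM t WHERE {}1{}",
        "(".repeat(depth),
        ")".repeat(depth)
    );

    // With the default limit (DEFAULT_REMAINING_DEPTH = 50), parsing fails
    // gracefully instead of overflowing the stack.
    let result = Parser::new(&dialect).try_with_sql(&sql)?.parse_statements();
    assert_eq!(result, Err(ParserError::RecursionLimitExceeded));

    // A shallow query still parses fine, even with a small custom limit.
    let statements = Parser::new(&dialect)
        .with_recursion_limit(10)
        .try_with_sql("SELECT a FROM t WHERE b = 1")?
        .parse_statements()?;
    assert_eq!(statements.len(), 1);

    Ok(())
}
```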