Skip to content

Commit

Permalink
Prevent panics from escaping SqliteAggregatorFunction
Browse files Browse the repository at this point in the history
Panics across FFI boundaries cause undefined behavior in current Rust.
The aggregator functions are callbacks invoked from C (libsqlite).
To safe-guard against panics, std::panic::catch_unwind() is used. On
panic the functions now return with an error result indicating the
unexpected panic occurred.

std::panic::catch_unwind() requires types to implement
std::panic::UnwindSafe, a marker trait indicating that care must be
taken since panics introduce control-flow that is not very visible.
Refer to https://doc.rust-lang.org/std/panic/trait.UnwindSafe.html for a
more detailed explanation.
For SqliteAggregatorFunction::step() we must use
std::panic::AssertUnwindSafe, since &mut references are never considered
UnwindSafe, and the requirement to ensure unwind-safety is documented on
the method.
Of note is that in safe Rust, even if the method is not unwind-safe
the language still guarantees memory-safety. The marker trait is mainly
to prevent logic bugs.
  • Loading branch information
z33ky committed Sep 14, 2020
1 parent cbed866 commit ee2f792
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 39 deletions.
10 changes: 8 additions & 2 deletions diesel/src/sqlite/connection/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,14 @@ pub fn register_aggregate<ArgsSqlType, RetSqlType, Args, Ret, A>(
fn_name: &str,
) -> QueryResult<()>
where
A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
A: SqliteAggregateFunction<Args, Output = Ret>
+ 'static
+ Send
+ std::panic::UnwindSafe
+ std::panic::RefUnwindSafe,
Args: FromSqlRow<ArgsSqlType, Sqlite>
+ StaticallySizedRow<ArgsSqlType, Sqlite>
+ std::panic::UnwindSafe,
Ret: ToSql<RetSqlType, Sqlite>,
Sqlite: HasSqlType<RetSqlType>,
{
Expand Down
10 changes: 8 additions & 2 deletions diesel/src/sqlite/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,14 @@ impl SqliteConnection {
fn_name: &str,
) -> QueryResult<()>
where
A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
A: SqliteAggregateFunction<Args, Output = Ret>
+ 'static
+ Send
+ std::panic::UnwindSafe
+ std::panic::RefUnwindSafe,
Args: FromSqlRow<ArgsSqlType, Sqlite>
+ StaticallySizedRow<ArgsSqlType, Sqlite>
+ std::panic::UnwindSafe,
Ret: ToSql<RetSqlType, Sqlite>,
Sqlite: HasSqlType<RetSqlType>,
{
Expand Down
71 changes: 42 additions & 29 deletions diesel/src/sqlite/connection/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,12 @@ impl RawConnection {
num_args: usize,
) -> QueryResult<()>
where
A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
Args: FromSqlRow<ArgsSqlType, Sqlite>,
A: SqliteAggregateFunction<Args, Output = Ret>
+ 'static
+ Send
+ std::panic::UnwindSafe
+ std::panic::RefUnwindSafe,
Args: FromSqlRow<ArgsSqlType, Sqlite> + std::panic::UnwindSafe,
Ret: ToSql<RetSqlType, Sqlite>,
Sqlite: HasSqlType<RetSqlType>,
{
Expand Down Expand Up @@ -294,8 +298,8 @@ extern "C" fn run_aggregator_step_function<ArgsSqlType, RetSqlType, Args, Ret, A
num_args: libc::c_int,
value_ptr: *mut *mut ffi::sqlite3_value,
) where
A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
Args: FromSqlRow<ArgsSqlType, Sqlite>,
A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::RefUnwindSafe,
Args: FromSqlRow<ArgsSqlType, Sqlite> + std::panic::UnwindSafe,
Ret: ToSql<RetSqlType, Sqlite>,
Sqlite: HasSqlType<RetSqlType>,
{
Expand All @@ -321,21 +325,19 @@ extern "C" fn run_aggregator_step_function<ArgsSqlType, RetSqlType, Args, Ret, A
// the memory will have a correct alignment.
// (Note I(weiznich): would assume that it is aligned correctly, but we
// we cannot guarantee it, so better be safe than sorry)
let aggregate_context = ffi::sqlite3_aggregate_context(
ctx,
std::mem::size_of::<OptionalAggregator<A>>() as i32,
);
let mut aggregate_context = NonNull::new(aggregate_context as *mut OptionalAggregator<A>);
let aggregator = match aggregate_context.map(|a| &mut *a.as_ptr()) {
ffi::sqlite3_aggregate_context(ctx, std::mem::size_of::<OptionalAggregator<A>>() as i32)
};
let mut aggregate_context = NonNull::new(aggregate_context as *mut OptionalAggregator<A>);
let aggregator = unsafe {
match aggregate_context.map(|a| &mut *a.as_ptr()) {
Some(&mut OptionalAggregator::Some(ref mut agg)) => agg,
Some(mut a_ptr @ &mut OptionalAggregator::None) => {
ptr::write_unaligned(a_ptr as *mut _, OptionalAggregator::Some(A::default()));
if let &mut OptionalAggregator::Some(ref mut agg) = a_ptr {
agg
} else {
unreachable!(
"We've written the aggregator above to that location, it must be there"
)
eprintln!("We've written the aggregator to the aggregate context, but it could not be retrieved");
std::process::abort();
}
}
None => {
Expand All @@ -344,28 +346,31 @@ extern "C" fn run_aggregator_step_function<ArgsSqlType, RetSqlType, Args, Ret, A
}
};

let mut f = |args: &[*mut ffi::sqlite3_value]| -> Result<(), Error> {
let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;

Ok(aggregator.step(args))
};

let args = slice::from_raw_parts(value_ptr, num_args as _);
match f(args) {
Err(e) => {
let args = build_sql_function_args::<ArgsSqlType, Args>(args);
let mut aggregator = std::panic::AssertUnwindSafe(aggregator);
let result = args
.map(|args| std::panic::catch_unwind(move || Ok(aggregator.step(args))))
.unwrap_or_else(|e| Ok(Err(e)));
match result {
Ok(Ok(())) => (),
Ok(Err(e)) => {
let msg = e.to_string();
unsafe { ffi::sqlite3_result_error(ctx, msg.as_ptr() as *const _, msg.len() as _) };
}
Err(_) => {
let msg = format!("{}::step() panicked", std::any::type_name::<A>());
ffi::sqlite3_result_error(ctx, msg.as_ptr() as *const _, msg.len() as _);
}
_ => (),
};
}
}

extern "C" fn run_aggregator_final_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
ctx: *mut ffi::sqlite3_context,
) where
A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
Args: FromSqlRow<ArgsSqlType, Sqlite>,
A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
Args: FromSqlRow<ArgsSqlType, Sqlite> + std::panic::UnwindSafe,
Ret: ToSql<RetSqlType, Sqlite>,
Sqlite: HasSqlType<RetSqlType>,
{
Expand All @@ -381,19 +386,27 @@ extern "C" fn run_aggregator_final_function<ArgsSqlType, RetSqlType, Args, Ret,
let aggregator = match aggregate_context {
Some(ref mut a) => match std::mem::replace(a.as_mut(), OptionalAggregator::None) {
OptionalAggregator::Some(agg) => Some(agg),
OptionalAggregator::None => unreachable!("We've written to the aggregator in the xStep callback. If xStep was never called, then ffi::sqlite_aggregate_context() would have returned a NULL pointer")
OptionalAggregator::None => {
eprintln!("We've written to the aggregator in the xStep callback. If xStep was never called, then ffi::sqlite_aggregate_context() would have returned a NULL pointer");
std::process::abort();
}
},
None => None,
};

let result = A::finalize(aggregator);
let result = std::panic::catch_unwind(|| A::finalize(aggregator))
.map(process_sql_function_result::<RetSqlType, Ret>);

match process_sql_function_result::<RetSqlType, Ret>(result) {
Ok(value) => value.result_of(ctx),
Err(e) => {
match result {
Ok(Ok(value)) => value.result_of(ctx),
Ok(Err(e)) => {
let msg = e.to_string();
ffi::sqlite3_result_error(ctx, msg.as_ptr() as *const _, msg.len() as _);
}
Err(_) => {
let msg = format!("{}::finalize() panicked", std::any::type_name::<A>());
ffi::sqlite3_result_error(ctx, msg.as_ptr() as *const _, msg.len() as _);
}
}
}
}
Expand Down
11 changes: 9 additions & 2 deletions diesel/src/sqlite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,19 @@ pub trait SqliteAggregateFunction<Args>: Default {
/// The result type of the SQLite aggregate function
type Output;

/// The `step()` method is called once for every record of the query
/// The `step()` method is called once for every record of the query.
///
/// This is called through a C FFI, as such panics do not propagate to the caller. Panics are
/// caught and cause a return with an error value. The implementation must still ensure that
/// state remains in a valid state (refer to [`std::panic::UnwindSafe`] for a bit more detail).
fn step(&mut self, args: Args);

/// After the last row has been processed, the `finalize()` method is
/// called to compute the result of the aggregate function. If no rows
/// were processed `aggregator` will be `None` and `finalize()` can be
/// used to specify a default result
/// used to specify a default result.
///
/// This is called through a C FFI, as such panics do not propagate to the caller. Panics are
/// caught and cause a return with an error value.
fn finalize(aggregator: Option<Self>) -> Self::Output;
}
18 changes: 14 additions & 4 deletions diesel_derives/src/sql_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,15 @@ pub(crate) fn expand(input: SqlFunctionDecl) -> Result<TokenStream, Diagnostic>
conn: &SqliteConnection
) -> QueryResult<()>
where
A: SqliteAggregateFunction<(#(#arg_name,)*)> + Send + 'static,
A: SqliteAggregateFunction<(#(#arg_name,)*)>
+ Send
+ 'static
+ ::std::panic::UnwindSafe
+ ::std::panic::RefUnwindSafe,
A::Output: ToSql<#return_type, Sqlite>,
(#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
StaticallySizedRow<(#(#arg_type,)*), Sqlite>,
StaticallySizedRow<(#(#arg_type,)*), Sqlite> +
::std::panic::UnwindSafe,
{
conn.register_aggregate_function::<(#(#arg_type,)*), #return_type, _, _, A>(#sql_name)
}
Expand All @@ -204,10 +209,15 @@ pub(crate) fn expand(input: SqlFunctionDecl) -> Result<TokenStream, Diagnostic>
conn: &SqliteConnection
) -> QueryResult<()>
where
A: SqliteAggregateFunction<#arg_name> + Send + 'static,
A: SqliteAggregateFunction<#arg_name>
+ Send
+ 'static
+ std::panic::UnwindSafe
+ std::panic::RefUnwindSafe,
A::Output: ToSql<#return_type, Sqlite>,
#arg_name: FromSqlRow<#arg_type, Sqlite> +
StaticallySizedRow<#arg_type, Sqlite>,
StaticallySizedRow<#arg_type, Sqlite> +
::std::panic::UnwindSafe,
{
conn.register_aggregate_function::<#arg_type, #return_type, _, _, A>(#sql_name)
}
Expand Down

0 comments on commit ee2f792

Please sign in to comment.