Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a use-after-free of trampoline code #2408

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/jit/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub mod trampoline;

pub use crate::code_memory::CodeMemory;
pub use crate::compiler::{Compilation, CompilationStrategy, Compiler};
pub use crate::instantiate::{CompilationArtifacts, CompiledModule, SetupError};
pub use crate::instantiate::{CompilationArtifacts, CompiledModule, ModuleCode, SetupError};
pub use crate::link::link_module;

/// Version number of this crate.
Expand Down
3 changes: 2 additions & 1 deletion crates/runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ pub use crate::memory::{RuntimeLinearMemory, RuntimeMemoryCreator};
pub use crate::mmap::Mmap;
pub use crate::table::{Table, TableElement};
pub use crate::traphandlers::{
catch_traps, init_traps, raise_lib_trap, raise_user_trap, resume_panic, SignalHandler, Trap,
catch_traps, init_traps, raise_lib_trap, raise_user_trap, resume_panic, with_last_info,
SignalHandler, Trap, TrapInfo,
};
pub use crate::vmcontext::{
VMCallerCheckedAnyfunc, VMContext, VMFunctionBody, VMFunctionImport, VMGlobalDefinition,
Expand Down
70 changes: 44 additions & 26 deletions crates/runtime/src/traphandlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,9 +370,7 @@ impl Trap {
/// Highly unsafe since `closure` won't have any dtors run.
pub unsafe fn catch_traps<F>(
vmctx: *mut VMContext,
max_wasm_stack: usize,
is_wasm_code: impl Fn(usize) -> bool,
signal_handler: Option<&SignalHandler>,
trap_info: &impl TrapInfo,
mut closure: F,
) -> Result<(), Trap>
where
Expand All @@ -382,7 +380,7 @@ where
#[cfg(unix)]
setup_unix_sigaltstack()?;

return CallThreadState::new(vmctx, &is_wasm_code, signal_handler).with(max_wasm_stack, |cx| {
return CallThreadState::new(vmctx, trap_info).with(|cx| {
RegisterSetjmp(
cx.jmp_buf.as_ptr(),
call_closure::<F>,
Expand All @@ -398,15 +396,46 @@ where
}
}

/// Runs `func` with the last `trap_info` object registered by `catch_traps`.
///
/// Calls `func` with `None` if `catch_traps` wasn't previously called from this
/// stack frame.
pub fn with_last_info<R>(func: impl FnOnce(Option<&dyn Any>) -> R) -> R {
tls::with(|state| func(state.map(|s| s.trap_info.as_any())))
}

/// Temporary state stored on the stack which is registered in the `tls` module
/// below for calls into wasm.
pub struct CallThreadState<'a> {
unwind: Cell<UnwindReason>,
jmp_buf: Cell<*const u8>,
vmctx: *mut VMContext,
handling_trap: Cell<bool>,
is_wasm_code: &'a (dyn Fn(usize) -> bool + 'a),
signal_handler: Option<&'a SignalHandler<'a>>,
trap_info: &'a (dyn TrapInfo + 'a),
}

/// A package of functionality needed by `catch_traps` to figure out what to do
/// when handling a trap.
///
/// Note that this is an `unsafe` trait at least because it's being run in the
/// context of a synchronous signal handler, so it needs to be careful to not
/// access too much state in answering these queries.
pub unsafe trait TrapInfo {
/// Converts this object into an `Any` to dynamically check its type.
fn as_any(&self) -> &dyn Any;

/// Returns whether the given program counter lies within wasm code,
/// indicating whether we should handle a trap or not.
fn is_wasm_code(&self, pc: usize) -> bool;

/// Uses `call` to call a custom signal handler, if one is specified.
///
/// Returns `true` if `call` returns true, otherwise returns `false`.
fn custom_signal_handler(&self, call: &dyn Fn(&SignalHandler) -> bool) -> bool;

/// Returns the maximum size, in bytes, the wasm native stack is allowed to
/// grow to.
fn max_wasm_stack(&self) -> usize;
}

enum UnwindReason {
Expand All @@ -418,27 +447,18 @@ enum UnwindReason {
}

impl<'a> CallThreadState<'a> {
fn new(
vmctx: *mut VMContext,
is_wasm_code: &'a (dyn Fn(usize) -> bool + 'a),
signal_handler: Option<&'a SignalHandler<'a>>,
) -> CallThreadState<'a> {
fn new(vmctx: *mut VMContext, trap_info: &'a (dyn TrapInfo + 'a)) -> CallThreadState<'a> {
CallThreadState {
unwind: Cell::new(UnwindReason::None),
vmctx,
jmp_buf: Cell::new(ptr::null()),
handling_trap: Cell::new(false),
is_wasm_code,
signal_handler,
trap_info,
}
}

fn with(
self,
max_wasm_stack: usize,
closure: impl FnOnce(&CallThreadState) -> i32,
) -> Result<(), Trap> {
let _reset = self.update_stack_limit(max_wasm_stack)?;
fn with(self, closure: impl FnOnce(&CallThreadState) -> i32) -> Result<(), Trap> {
let _reset = self.update_stack_limit()?;
let ret = tls::set(&self, || closure(&self));
match self.unwind.replace(UnwindReason::None) {
UnwindReason::None => {
Expand Down Expand Up @@ -498,7 +518,7 @@ impl<'a> CallThreadState<'a> {
///
/// Note that this function must be called with `self` on the stack, not the
/// heap/etc.
fn update_stack_limit(&self, max_wasm_stack: usize) -> Result<impl Drop + '_, Trap> {
fn update_stack_limit(&self) -> Result<impl Drop + '_, Trap> {
// Determine the stack pointer where, after which, any wasm code will
// immediately trap. This is checked on the entry to all wasm functions.
//
Expand All @@ -510,7 +530,7 @@ impl<'a> CallThreadState<'a> {
// to it). In any case it's expected to be at most a few hundred bytes
// of slop one way or another. When wasm is typically given a MB or so
// (a million bytes) the slop shouldn't matter too much.
let wasm_stack_limit = psm::stack_pointer() as usize - max_wasm_stack;
let wasm_stack_limit = psm::stack_pointer() as usize - self.trap_info.max_wasm_stack();

let interrupts = unsafe { &**(&*self.vmctx).instance().interrupts() };
let reset_stack_limit = match interrupts.stack_limit.compare_exchange(
Expand Down Expand Up @@ -604,14 +624,12 @@ impl<'a> CallThreadState<'a> {
// First up see if any instance registered has a custom trap handler,
// in which case run them all. If anything handles the trap then we
// return that the trap was handled.
if let Some(handler) = self.signal_handler {
if call_handler(handler) {
return 1 as *const _;
}
if self.trap_info.custom_signal_handler(&call_handler) {
return 1 as *const _;
}

// If this fault wasn't in wasm code, then it's not our problem
if !(self.is_wasm_code)(pc as usize) {
if !self.trap_info.is_wasm_code(pc as usize) {
return ptr::null();
}

Expand Down
1 change: 0 additions & 1 deletion crates/wasmtime/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ libc = "0.2"
cfg-if = "1.0"
backtrace = "0.3.42"
rustc-demangle = "0.1.16"
lazy_static = "1.4"
log = "0.4.8"
wat = { version = "1.0.18", optional = true }
smallvec = "1.4.0"
Expand Down
10 changes: 5 additions & 5 deletions crates/wasmtime/src/externals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -492,13 +492,13 @@ impl Table {
// come from different modules.

let dst_table_index = dst_table.wasmtime_table_index();
let dst_table = dst_table.instance.get_defined_table(dst_table_index);
let dst_table_index = dst_table.instance.get_defined_table(dst_table_index);

let src_table_index = src_table.wasmtime_table_index();
let src_table = src_table.instance.get_defined_table(src_table_index);
let src_table_index = src_table.instance.get_defined_table(src_table_index);

runtime::Table::copy(dst_table, src_table, dst_index, src_index, len)
.map_err(Trap::from_runtime)?;
runtime::Table::copy(dst_table_index, src_table_index, dst_index, src_index, len)
.map_err(|e| Trap::from_runtime(&dst_table.instance.store, e))?;
Ok(())
}

Expand All @@ -523,7 +523,7 @@ impl Table {
self.instance
.handle
.defined_table_fill(table_index, dst, val.into_table_element()?, len)
.map_err(Trap::from_runtime)?;
.map_err(|e| Trap::from_runtime(&self.instance.store, e))?;

Ok(())
}
Expand Down
Loading