Merge pull request from GHSA-ch89-5g45-qwc7
* Fix miscompile from functions mutating `VMContext`

This commit fixes a miscompilation in Wasmtime on LLVM 16 where methods
on `Instance` which mutated the state of the internal `VMContext` were
optimized to not actually mutate the state. The root cause is a change
in LLVM that takes advantage of `noalias readonly` pointers, which is
how `&self` receivers are translated. This means that `Instance` methods
which take `&self` but actually mutate the `VMContext` are undefined
behavior from LLVM's point of view, making the writes candidates for
removal.
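
As a minimal sketch of the hazard (illustration only, not Wasmtime's
actual code; `flag` is a hypothetical field standing in for `VMContext`
state, and the function is deliberately UB):

```rust
struct Ctx {
    flag: u8,
}

impl Ctx {
    // `&self` is lowered to a `noalias readonly` parameter, so the
    // optimizer may assume nothing is written through it during the call...
    fn set_flag(&self) {
        // ...which makes this write through a pointer derived from `&self`
        // undefined behavior, and therefore a candidate for deletion.
        unsafe {
            let p = &self.flag as *const u8 as *mut u8;
            p.write(1);
        }
    }
}
```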

The fix applied here is intended to be a temporary one while a more
formal fix, ideally backed by `cargo miri` verification, is implemented
on `main`. The fix here is to change the return type of
`vmctx_plus_offset` from `*mut T` to `*const T`. This caused large
portions of the runtime code to stop compiling because mutations were
indeed happening. To cover those cases, a new `vmctx_plus_offset_mut`
method was added which notably takes `&mut self` instead of `&self`.
This forces all callers that may mutate to satisfy the `&mut self`
requirement, propagating it outwards.
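
One detail of the new accessors, visible in the `instance.rs` diff below:
they use `std::ptr::addr_of!`/`std::ptr::addr_of_mut!` rather than going
through `&self.vmctx`/`&mut self.vmctx`. Those macros turn a place
expression directly into a raw pointer without materializing an
intermediate reference, so no fresh reference-based aliasing claims are
handed to the optimizer. A rough illustration (hypothetical types, not
the real `Instance` layout):

```rust
use std::ptr;

struct VMContext {}

struct Instance {
    // In Wasmtime the actual vmctx data extends past the end of this field.
    vmctx: VMContext,
}

impl Instance {
    unsafe fn vmctx_base_mut(&mut self) -> *mut u8 {
        // `addr_of_mut!` goes from place to raw pointer in one step;
        // writing `&mut self.vmctx as *mut VMContext` would first create a
        // `&mut` reference, with all the validity and aliasing assumptions
        // a reference implies.
        ptr::addr_of_mut!(self.vmctx).cast::<u8>()
    }
}
```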

This fixes the miscompilation with LLVM 16 in the immediate future and
should be at least a meager line of defense against issues like this
going forward. It is not a long-term fix, though, since `cargo miri`
still does not like what's being done in `Instance` and with
`VMContext`. That fix is likely to be more invasive, so it's being
deferred until later.

* Update release notes

* Fix release date
alexcrichton authored Apr 27, 2023
1 parent b6bc33d commit 4b9ce0e
Showing 6 changed files with 95 additions and 82 deletions.
11 changes: 11 additions & 0 deletions RELEASES.md
@@ -1,5 +1,16 @@
 --------------------------------------------------------------------------------
 
+## 6.0.2
+
+Released 2023-04-27.
+
+### Fixed
+
+* Undefined Behavior in Rust runtime functions
+  [GHSA-ch89-5g45-qwc7](https://github.com/bytecodealliance/wasmtime/security/advisories/GHSA-ch89-5g45-qwc7)
+
+--------------------------------------------------------------------------------
+
 ## 6.0.1
 
 Released 2023-03-08.
28 changes: 15 additions & 13 deletions crates/environ/src/module.rs
@@ -241,8 +241,9 @@ impl ModuleTranslation<'_> {
         }
         let mut idx = 0;
         let ok = self.module.memory_initialization.init_memory(
+            &mut (),
             InitMemory::CompileTime(&self.module),
-            &mut |memory, init| {
+            |(), memory, init| {
                 // Currently `Static` only applies to locally-defined memories,
                 // so if a data segment references an imported memory then
                 // transitioning to a `Static` memory initializer is not
@@ -525,10 +526,11 @@ impl MemoryInitialization {
     /// question needs to be deferred to runtime, and at runtime this means
     /// that an invalid initializer has been found and a trap should be
     /// generated.
-    pub fn init_memory(
+    pub fn init_memory<T>(
         &self,
-        state: InitMemory<'_>,
-        write: &mut dyn FnMut(MemoryIndex, &StaticMemoryInitializer) -> bool,
+        state: &mut T,
+        init: InitMemory<'_, T>,
+        mut write: impl FnMut(&mut T, MemoryIndex, &StaticMemoryInitializer) -> bool,
     ) -> bool {
         let initializers = match self {
             // Fall through below to the segmented memory one-by-one
@@ -543,7 +545,7 @@ impl MemoryInitialization {
             MemoryInitialization::Static { map } => {
                 for (index, init) in map {
                     if let Some(init) = init {
-                        let result = write(index, init);
+                        let result = write(state, index, init);
                         if !result {
                             return result;
                         }
@@ -567,10 +569,10 @@ impl MemoryInitialization {
             // (e.g. this is a task happening before instantiation at
             // compile-time).
             let base = match base {
-                Some(index) => match &state {
+                Some(index) => match &init {
                     InitMemory::Runtime {
                         get_global_as_u64, ..
-                    } => get_global_as_u64(index),
+                    } => get_global_as_u64(state, index),
                     InitMemory::CompileTime(_) => return false,
                 },
                 None => 0,
@@ -585,12 +587,12 @@ impl MemoryInitialization {
                 None => return false,
             };
 
-            let cur_size_in_pages = match &state {
+            let cur_size_in_pages = match &init {
                 InitMemory::CompileTime(module) => module.memory_plans[memory_index].memory.minimum,
                 InitMemory::Runtime {
                     memory_size_in_pages,
                     ..
-                } => memory_size_in_pages(memory_index),
+                } => memory_size_in_pages(state, memory_index),
             };
 
             // Note that this `minimum` can overflow if `minimum` is
@@ -616,7 +618,7 @@ impl MemoryInitialization {
                     offset: start,
                     data: data.clone(),
                 };
-                let result = write(memory_index, &init);
+                let result = write(state, memory_index, &init);
                 if !result {
                     return result;
                 }
@@ -628,7 +630,7 @@
 
 /// Argument to [`MemoryInitialization::init_memory`] indicating the current
 /// status of the instance.
-pub enum InitMemory<'a> {
+pub enum InitMemory<'a, T> {
     /// This evaluation of memory initializers is happening at compile time.
     /// This means that the current state of memories is whatever their initial
     /// state is, and additionally globals are not available if data segments
@@ -640,10 +642,10 @@ pub enum InitMemory<'a, T> {
     /// instance's state.
     Runtime {
         /// Returns the size, in wasm pages, of the memory specified.
-        memory_size_in_pages: &'a dyn Fn(MemoryIndex) -> u64,
+        memory_size_in_pages: &'a dyn Fn(&mut T, MemoryIndex) -> u64,
         /// Returns the value of the global, as a `u64`. Note that this may
         /// involve zero-extending a 32-bit global to a 64-bit number.
-        get_global_as_u64: &'a dyn Fn(GlobalIndex) -> u64,
+        get_global_as_u64: &'a dyn Fn(&mut T, GlobalIndex) -> u64,
     },
 }

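The shape of the new `init_memory` API, reduced to a self-contained
sketch (hypothetical simplified types stand in for `MemoryIndex`,
`StaticMemoryInitializer`, and the real callbacks): threading an explicit
`state: &mut T` through the call hands each callback the one mutable
borrow as an argument, rather than having several closures each try to
capture the caller's state mutably at once. At compile time, where there
is no instance state, the caller just passes `&mut ()`, as in the
`module.rs` hunk above.

```rust
/// Pared-down model of `InitMemory<'a, T>` (illustration only).
enum InitMemory<'a, T> {
    CompileTime,
    Runtime {
        memory_size_in_pages: &'a dyn Fn(&mut T, usize) -> u64,
    },
}

/// Pared-down model of `MemoryInitialization::init_memory`.
fn init_memory<T>(
    state: &mut T,
    init: InitMemory<'_, T>,
    mut write: impl FnMut(&mut T, usize, &[u8]) -> bool,
) -> bool {
    // No callback holds a long-lived borrow: `state` is lent out for
    // each call and handed back when the call returns.
    let _pages = match &init {
        InitMemory::CompileTime => 1,
        InitMemory::Runtime {
            memory_size_in_pages,
        } => memory_size_in_pages(state, 0),
    };
    write(state, 0, &[0, 1, 2, 3])
}

fn main() {
    let mut writes = 0u32;
    let ok = init_memory(
        &mut writes,
        InitMemory::Runtime {
            memory_size_in_pages: &|_state, _memory| 1,
        },
        |state, _memory, _data| {
            *state += 1; // mutation flows through the threaded state
            true
        },
    );
    assert!(ok && writes == 1);
}
```
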
80 changes: 44 additions & 36 deletions crates/runtime/src/instance.rs
@@ -147,8 +147,14 @@ impl Instance {
 
     /// Helper function to access various locations offset from our `*mut
     /// VMContext` object.
-    unsafe fn vmctx_plus_offset<T>(&self, offset: u32) -> *mut T {
-        (self.vmctx_ptr().cast::<u8>())
+    unsafe fn vmctx_plus_offset<T>(&self, offset: u32) -> *const T {
+        (std::ptr::addr_of!(self.vmctx).cast::<u8>())
             .add(usize::try_from(offset).unwrap())
             .cast()
     }
+
+    unsafe fn vmctx_plus_offset_mut<T>(&mut self, offset: u32) -> *mut T {
+        (std::ptr::addr_of_mut!(self.vmctx).cast::<u8>())
+            .add(usize::try_from(offset).unwrap())
+            .cast()
+    }
@@ -183,20 +189,20 @@ impl Instance {
 
     /// Return the indexed `VMTableDefinition`.
     #[allow(dead_code)]
-    fn table(&self, index: DefinedTableIndex) -> VMTableDefinition {
+    fn table(&mut self, index: DefinedTableIndex) -> VMTableDefinition {
         unsafe { *self.table_ptr(index) }
     }
 
     /// Updates the value for a defined table to `VMTableDefinition`.
-    fn set_table(&self, index: DefinedTableIndex, table: VMTableDefinition) {
+    fn set_table(&mut self, index: DefinedTableIndex, table: VMTableDefinition) {
         unsafe {
             *self.table_ptr(index) = table;
         }
     }
 
     /// Return the indexed `VMTableDefinition`.
-    fn table_ptr(&self, index: DefinedTableIndex) -> *mut VMTableDefinition {
-        unsafe { self.vmctx_plus_offset(self.offsets().vmctx_vmtable_definition(index)) }
+    fn table_ptr(&mut self, index: DefinedTableIndex) -> *mut VMTableDefinition {
+        unsafe { self.vmctx_plus_offset_mut(self.offsets().vmctx_vmtable_definition(index)) }
     }
 
     /// Get a locally defined or imported memory.
@@ -238,21 +244,21 @@ impl Instance {
     }
 
     /// Return the indexed `VMGlobalDefinition`.
-    fn global(&self, index: DefinedGlobalIndex) -> &VMGlobalDefinition {
+    fn global(&mut self, index: DefinedGlobalIndex) -> &VMGlobalDefinition {
         unsafe { &*self.global_ptr(index) }
     }
 
     /// Return the indexed `VMGlobalDefinition`.
-    fn global_ptr(&self, index: DefinedGlobalIndex) -> *mut VMGlobalDefinition {
-        unsafe { self.vmctx_plus_offset(self.offsets().vmctx_vmglobal_definition(index)) }
+    fn global_ptr(&mut self, index: DefinedGlobalIndex) -> *mut VMGlobalDefinition {
+        unsafe { self.vmctx_plus_offset_mut(self.offsets().vmctx_vmglobal_definition(index)) }
     }
 
     /// Get a raw pointer to the global at the given index regardless whether it
     /// is defined locally or imported from another module.
     ///
     /// Panics if the index is out of bound or is the reserved value.
     pub(crate) fn defined_or_imported_global_ptr(
-        &self,
+        &mut self,
         index: GlobalIndex,
     ) -> *mut VMGlobalDefinition {
         if let Some(index) = self.module().defined_global_index(index) {
@@ -263,18 +269,18 @@
     }
 
     /// Return a pointer to the interrupts structure
-    pub fn runtime_limits(&self) -> *mut *const VMRuntimeLimits {
-        unsafe { self.vmctx_plus_offset(self.offsets().vmctx_runtime_limits()) }
+    pub fn runtime_limits(&mut self) -> *mut *const VMRuntimeLimits {
+        unsafe { self.vmctx_plus_offset_mut(self.offsets().vmctx_runtime_limits()) }
     }
 
     /// Return a pointer to the global epoch counter used by this instance.
-    pub fn epoch_ptr(&self) -> *mut *const AtomicU64 {
-        unsafe { self.vmctx_plus_offset(self.offsets().vmctx_epoch_ptr()) }
+    pub fn epoch_ptr(&mut self) -> *mut *const AtomicU64 {
+        unsafe { self.vmctx_plus_offset_mut(self.offsets().vmctx_epoch_ptr()) }
     }
 
     /// Return a pointer to the `VMExternRefActivationsTable`.
-    pub fn externref_activations_table(&self) -> *mut *mut VMExternRefActivationsTable {
-        unsafe { self.vmctx_plus_offset(self.offsets().vmctx_externref_activations_table()) }
+    pub fn externref_activations_table(&mut self) -> *mut *mut VMExternRefActivationsTable {
+        unsafe { self.vmctx_plus_offset_mut(self.offsets().vmctx_externref_activations_table()) }
     }
 
     /// Gets a pointer to this instance's `Store` which was originally
Expand All @@ -297,7 +303,7 @@ impl Instance {

pub unsafe fn set_store(&mut self, store: Option<*mut dyn Store>) {
if let Some(store) = store {
*self.vmctx_plus_offset(self.offsets().vmctx_store()) = store;
*self.vmctx_plus_offset_mut(self.offsets().vmctx_store()) = store;
*self.runtime_limits() = (*store).vmruntime_limits();
*self.epoch_ptr() = (*store).epoch_ptr();
*self.externref_activations_table() = (*store).externref_activations_table().0;
@@ -306,7 +312,7 @@
                 mem::size_of::<*mut dyn Store>(),
                 mem::size_of::<[*mut (); 2]>()
             );
-            *self.vmctx_plus_offset::<[*mut (); 2]>(self.offsets().vmctx_store()) =
+            *self.vmctx_plus_offset_mut::<[*mut (); 2]>(self.offsets().vmctx_store()) =
                 [ptr::null_mut(), ptr::null_mut()];
 
             *self.runtime_limits() = ptr::null_mut();
@@ -316,7 +322,7 @@
     }
 
     pub(crate) unsafe fn set_callee(&mut self, callee: Option<NonNull<VMFunctionBody>>) {
-        *self.vmctx_plus_offset(self.offsets().vmctx_callee()) =
+        *self.vmctx_plus_offset_mut(self.offsets().vmctx_callee()) =
             callee.map_or(ptr::null_mut(), |c| c.as_ptr());
     }
 
@@ -402,7 +408,7 @@ impl Instance {
     }
 
     /// Return the table index for the given `VMTableDefinition`.
-    unsafe fn table_index(&self, table: &VMTableDefinition) -> DefinedTableIndex {
+    unsafe fn table_index(&mut self, table: &VMTableDefinition) -> DefinedTableIndex {
         let index = DefinedTableIndex::new(
             usize::try_from(
                 (table as *const VMTableDefinition)
@@ -515,7 +521,7 @@ impl Instance {
     ) {
         let type_index = unsafe {
             let base: *const VMSharedSignatureIndex =
-                *self.vmctx_plus_offset(self.offsets().vmctx_signature_ids_array());
+                *self.vmctx_plus_offset_mut(self.offsets().vmctx_signature_ids_array());
             *base.add(sig.index())
         };
 
@@ -584,7 +590,7 @@ impl Instance {
         let func = &self.module().functions[index];
         let sig = func.signature;
         let anyfunc: *mut VMCallerCheckedAnyfunc = self
-            .vmctx_plus_offset::<VMCallerCheckedAnyfunc>(
+            .vmctx_plus_offset_mut::<VMCallerCheckedAnyfunc>(
                 self.offsets().vmctx_anyfunc(func.anyfunc),
             );
         self.construct_anyfunc(index, sig, anyfunc);
@@ -923,40 +929,41 @@ impl Instance {
     ) {
         assert!(std::ptr::eq(module, self.module().as_ref()));
 
-        *self.vmctx_plus_offset(offsets.vmctx_magic()) = VMCONTEXT_MAGIC;
+        *self.vmctx_plus_offset_mut(offsets.vmctx_magic()) = VMCONTEXT_MAGIC;
         self.set_callee(None);
         self.set_store(store.as_raw());
 
         // Initialize shared signatures
         let signatures = self.runtime_info.signature_ids();
-        *self.vmctx_plus_offset(offsets.vmctx_signature_ids_array()) = signatures.as_ptr();
+        *self.vmctx_plus_offset_mut(offsets.vmctx_signature_ids_array()) = signatures.as_ptr();
 
         // Initialize the built-in functions
-        *self.vmctx_plus_offset(offsets.vmctx_builtin_functions()) = &VMBuiltinFunctionsArray::INIT;
+        *self.vmctx_plus_offset_mut(offsets.vmctx_builtin_functions()) =
+            &VMBuiltinFunctionsArray::INIT;
 
         // Initialize the imports
         debug_assert_eq!(imports.functions.len(), module.num_imported_funcs);
         ptr::copy_nonoverlapping(
             imports.functions.as_ptr(),
-            self.vmctx_plus_offset(offsets.vmctx_imported_functions_begin()),
+            self.vmctx_plus_offset_mut(offsets.vmctx_imported_functions_begin()),
             imports.functions.len(),
         );
         debug_assert_eq!(imports.tables.len(), module.num_imported_tables);
         ptr::copy_nonoverlapping(
             imports.tables.as_ptr(),
-            self.vmctx_plus_offset(offsets.vmctx_imported_tables_begin()),
+            self.vmctx_plus_offset_mut(offsets.vmctx_imported_tables_begin()),
             imports.tables.len(),
         );
         debug_assert_eq!(imports.memories.len(), module.num_imported_memories);
         ptr::copy_nonoverlapping(
             imports.memories.as_ptr(),
-            self.vmctx_plus_offset(offsets.vmctx_imported_memories_begin()),
+            self.vmctx_plus_offset_mut(offsets.vmctx_imported_memories_begin()),
             imports.memories.len(),
         );
         debug_assert_eq!(imports.globals.len(), module.num_imported_globals);
         ptr::copy_nonoverlapping(
             imports.globals.as_ptr(),
-            self.vmctx_plus_offset(offsets.vmctx_imported_globals_begin()),
+            self.vmctx_plus_offset_mut(offsets.vmctx_imported_globals_begin()),
             imports.globals.len(),
         );
 
@@ -967,7 +974,7 @@
         // any state now.
 
         // Initialize the defined tables
-        let mut ptr = self.vmctx_plus_offset(offsets.vmctx_tables_begin());
+        let mut ptr = self.vmctx_plus_offset_mut(offsets.vmctx_tables_begin());
        for i in 0..module.table_plans.len() - module.num_imported_tables {
             ptr::write(ptr, self.tables[DefinedTableIndex::new(i)].vmtable());
             ptr = ptr.add(1);
@@ -978,8 +985,8 @@
         // time. Entries in `defined_memories` hold a pointer to a definition
         // (all memories) whereas the `owned_memories` hold the actual
         // definitions of memories owned (not shared) in the module.
-        let mut ptr = self.vmctx_plus_offset(offsets.vmctx_memories_begin());
-        let mut owned_ptr = self.vmctx_plus_offset(offsets.vmctx_owned_memories_begin());
+        let mut ptr = self.vmctx_plus_offset_mut(offsets.vmctx_memories_begin());
+        let mut owned_ptr = self.vmctx_plus_offset_mut(offsets.vmctx_owned_memories_begin());
         for i in 0..module.memory_plans.len() - module.num_imported_memories {
             let defined_memory_index = DefinedMemoryIndex::new(i);
             let memory_index = module.memory_index(defined_memory_index);
@@ -1051,8 +1058,9 @@ impl Instance {
 impl Drop for Instance {
     fn drop(&mut self) {
         // Drop any defined globals
-        for (idx, global) in self.module().globals.iter() {
-            let idx = match self.module().defined_global_index(idx) {
+        let module = self.module().clone();
+        for (idx, global) in module.globals.iter() {
+            let idx = match module.defined_global_index(idx) {
                 Some(idx) => idx,
                 None => continue,
             };
@@ -1165,8 +1173,8 @@ impl InstanceHandle {
     }
 
     /// Return the table index for the given `VMTableDefinition` in this instance.
-    pub unsafe fn table_index(&self, table: &VMTableDefinition) -> DefinedTableIndex {
-        self.instance().table_index(table)
+    pub unsafe fn table_index(&mut self, table: &VMTableDefinition) -> DefinedTableIndex {
+        self.instance_mut().table_index(table)
     }
 
     /// Get a table defined locally within this module.
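One non-mechanical change in this file is in `Drop for Instance`: the
loop used to iterate `self.module().globals` directly, but the accessors
its body calls (e.g. `global_ptr`, made `&mut self` above) can no longer
overlap with that shared borrow. Cloning the `Arc`-backed module first
ends the shared borrow at the cost of a reference-count bump. A minimal
sketch of the pattern, with hypothetical names:

```rust
use std::sync::Arc;

struct Module {
    globals: Vec<u32>,
}

struct Instance {
    module: Arc<Module>,
}

impl Instance {
    fn module(&self) -> &Arc<Module> {
        &self.module
    }

    // Stand-in for the `&mut self` accessors the real drop loop uses for
    // each defined global.
    fn drop_global(&mut self, _idx: usize) {}

    fn drop_globals(&mut self) {
        // Iterating `self.module().globals` directly would hold a shared
        // borrow of `self` across the `&mut self` call below; cloning the
        // `Arc` (a cheap refcount bump) ends that borrow first.
        let module = self.module().clone();
        for (idx, _global) in module.globals.iter().enumerate() {
            self.drop_global(idx);
        }
    }
}

fn main() {
    let mut instance = Instance {
        module: Arc::new(Module { globals: vec![0, 1] }),
    };
    instance.drop_globals();
}
```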