diff --git a/crates/continuations/src/lib.rs b/crates/continuations/src/lib.rs index 56647aecabdf..a2a0ac8a3e07 100644 --- a/crates/continuations/src/lib.rs +++ b/crates/continuations/src/lib.rs @@ -1,4 +1,4 @@ -use std::ptr; +use std::{cell::UnsafeCell, ptr}; use wasmtime_fibre::Fiber; /// TODO @@ -132,6 +132,25 @@ impl StackChain { pub const CONTINUATION_DISCRIMINANT: usize = STACK_CHAIN_CONTINUATION_DISCRIMINANT; } +#[repr(transparent)] +pub struct StackChainCell(pub UnsafeCell); + +impl StackChainCell { + pub fn absent() -> Self { + StackChainCell(UnsafeCell::new(StackChain::Absent)) + } +} + +// Since `StackChainCell` and `StackLimits` objects appear in the `StoreOpaque`, +// they need to be `Send` and `Sync`. +// This is safe for the same reason it is for `VMRuntimeLimits` (see comment +// there): Both types are pod-type with no destructor, and we don't access any +// of their fields from other threads. +unsafe impl Send for StackLimits {} +unsafe impl Sync for StackLimits {} +unsafe impl Send for StackChainCell {} +unsafe impl Sync for StackChainCell {} + pub struct Payloads { /// Number of currently occupied slots. pub length: types::payloads::Length, diff --git a/crates/cranelift/src/wasmfx/optimized.rs b/crates/cranelift/src/wasmfx/optimized.rs index e3c97937e2aa..04a37ca09c96 100644 --- a/crates/cranelift/src/wasmfx/optimized.rs +++ b/crates/cranelift/src/wasmfx/optimized.rs @@ -762,7 +762,20 @@ pub(crate) mod typed_continuation_helpers { let offset = i32::try_from(env.offsets.vmctx_typed_continuations_stack_chain()).unwrap(); - StackChain::load(env, builder, base_addr, offset, self.pointer_type) + + // The `typed_continuations_stack_chain` field of the VMContext only + // contains a pointer to the `StackChainCell` in the `Store`. + // The pointer never changes through the lifetime of a `VMContext`, + // which is why this load is `readonly`. 
+ // TODO(frank-emrich) Consider turning this pointer into a global + // variable, similar to `env.vmruntime_limits_ptr`. + let memflags = ir::MemFlags::trusted().with_readonly(); + let stack_chain_ptr = + builder + .ins() + .load(self.pointer_type, memflags, base_addr, offset); + + StackChain::load(env, builder, stack_chain_ptr, 0, self.pointer_type) } /// Stores the given stack chain saved in this `VMContext`, overwriting @@ -777,7 +790,16 @@ pub(crate) mod typed_continuation_helpers { let offset = i32::try_from(env.offsets.vmctx_typed_continuations_stack_chain()).unwrap(); - stack_chain.store(env, builder, base_addr, offset) + + // Same situation as in `load_stack_chain` regarding pointer + // indirection and it being `readonly`. + let memflags = ir::MemFlags::trusted().with_readonly(); + let stack_chain_ptr = + builder + .ins() + .load(self.pointer_type, memflags, base_addr, offset); + + stack_chain.store(env, builder, stack_chain_ptr, 0) } /// Similar to `store_stack_chain`, but instead of storing an arbitrary diff --git a/crates/environ/src/vmoffsets.rs b/crates/environ/src/vmoffsets.rs index 3a358ef656af..3ee07e321f9f 100644 --- a/crates/environ/src/vmoffsets.rs +++ b/crates/environ/src/vmoffsets.rs @@ -92,11 +92,9 @@ pub struct VMOffsets

{ defined_func_refs: u32, size: u32, - // The following field stores a value of type - // `wasmtime_continuations::StackLimits`. - typed_continuations_main_stack_limits: u32, - // The following field stores a value of type - // `wasmtime_continuations::StackChain`. The head of the chain is the + // The following field stores a pointer into the StoreOpaque, to a value of + // type `wasmtime_continuations::StackChain`. + // The head of the chain is the // currently executing stack (main stack or a continuation). typed_continuations_stack_chain: u32, typed_continuations_payloads: u32, @@ -363,7 +361,6 @@ impl VMOffsets

{ calculate_sizes! { typed_continuations_payloads: "typed continuations payloads object", typed_continuations_stack_chain: "typed continuations stack chain", - typed_continuations_main_stack_limits: "typed continuations main stack limits", defined_func_refs: "module functions", defined_globals: "defined globals", owned_memories: "owned memories", @@ -416,7 +413,6 @@ impl From> for VMOffsets

{ defined_globals: 0, defined_func_refs: 0, size: 0, - typed_continuations_main_stack_limits: 0, typed_continuations_stack_chain: 0, typed_continuations_payloads: 0, }; @@ -482,14 +478,8 @@ impl From> for VMOffsets

{ ret.ptr.size_of_vm_func_ref(), ), - align(std::mem::align_of::() as u32), - size(typed_continuations_main_stack_limits) - = std::mem::size_of::() as u32, - - align(std::mem::align_of::() as u32), size(typed_continuations_stack_chain) - = std::mem::size_of::() as u32, - + = ret.ptr.size(), align(std::mem::align_of::() as u32), size(typed_continuations_payloads) = std::mem::size_of::() as u32, @@ -746,12 +736,6 @@ impl VMOffsets

{ self.builtin_functions } - /// TODO - #[inline] - pub fn vmctx_typed_continuations_main_stack_limits(&self) -> u32 { - self.typed_continuations_main_stack_limits - } - /// TODO #[inline] pub fn vmctx_typed_continuations_stack_chain(&self) -> u32 { diff --git a/crates/runtime/src/continuation.rs b/crates/runtime/src/continuation.rs index b8e7cf1351c7..acea0eeaeae4 100644 --- a/crates/runtime/src/continuation.rs +++ b/crates/runtime/src/continuation.rs @@ -7,7 +7,7 @@ use std::mem; use wasmtime_continuations::{debug_println, ENABLE_DEBUG_PRINTING}; pub use wasmtime_continuations::{ ContinuationFiber, ContinuationObject, ContinuationReference, Payloads, StackChain, - StackLimits, State, + StackChainCell, StackLimits, State, }; use wasmtime_fibre::{Fiber, FiberStack, Suspend}; @@ -187,7 +187,7 @@ pub fn resume( // SAFETY: We maintain as an invariant that the stack chain field in the // VMContext is non-null and contains a chain of zero or more // StackChain::Continuation values followed by StackChain::Main. - match unsafe { &*chain } { + match unsafe { (**chain).0.get_mut() } { StackChain::Continuation(running_contobj) => { debug_assert_eq!(contobj, *running_contobj); debug_println!( @@ -273,7 +273,7 @@ pub fn suspend(instance: &mut Instance, tag_index: u32) -> Result<(), TrapReason // SAFETY: We maintain as an invariant that the stack chain field in the // VMContext is non-null and contains a chain of zero or more // StackChain::Continuation values followed by StackChain::Main. 
- let chain = unsafe { &*chain_ptr }; + let chain = unsafe { (**chain_ptr).0.get_mut() }; let running = match chain { StackChain::Absent => Err(TrapReason::user_without_backtrace(anyhow::anyhow!( "Internal error: StackChain not initialised" diff --git a/crates/runtime/src/instance.rs b/crates/runtime/src/instance.rs index 402f4b143597..c7cc223000ee 100644 --- a/crates/runtime/src/instance.rs +++ b/crates/runtime/src/instance.rs @@ -26,6 +26,7 @@ use std::ptr::NonNull; use std::sync::atomic::AtomicU64; use std::sync::Arc; use std::{mem, ptr}; +use wasmtime_continuations::StackChainCell; use wasmtime_environ::ModuleInternedTypeIndex; use wasmtime_environ::{ packed_option::ReservedValue, DataIndex, DefinedGlobalIndex, DefinedMemoryIndex, @@ -432,6 +433,14 @@ impl Instance { unsafe { self.vmctx_plus_offset_mut(self.offsets().vmctx_runtime_limits()) } } + /// Return a pointer to the stack chain + #[inline] + pub fn stack_chain(&mut self) -> *mut *mut StackChainCell { + unsafe { + self.vmctx_plus_offset_mut(self.offsets().vmctx_typed_continuations_stack_chain()) + } + } + /// Return a pointer to the global epoch counter used by this instance. 
pub fn epoch_ptr(&mut self) -> *mut *const AtomicU64 { unsafe { self.vmctx_plus_offset_mut(self.offsets().vmctx_epoch_ptr()) } @@ -464,6 +473,7 @@ impl Instance { if let Some(store) = store { *self.vmctx_plus_offset_mut(self.offsets().vmctx_store()) = store; *self.runtime_limits() = (*store).vmruntime_limits(); + *self.stack_chain() = (*store).stack_chain(); *self.epoch_ptr() = (*store).epoch_ptr(); *self.externref_activations_table() = (*store).externref_activations_table().0; } else { @@ -1133,13 +1143,6 @@ impl Instance { *self.vmctx_plus_offset_mut(offsets.vmctx_builtin_functions()) = &VMBuiltinFunctionsArray::INIT; - let main_stack_limits_ptr = - self.vmctx_plus_offset_mut(offsets.vmctx_typed_continuations_main_stack_limits()); - *main_stack_limits_ptr = wasmtime_continuations::StackLimits::default(); - - *self.vmctx_plus_offset_mut(offsets.vmctx_typed_continuations_stack_chain()) = - wasmtime_continuations::StackChain::MainStack(main_stack_limits_ptr); - // Initialize the Payloads object to be empty let vmctx_payloads: *mut wasmtime_continuations::Payloads = self.vmctx_plus_offset_mut(offsets.vmctx_typed_continuations_payloads()); @@ -1283,18 +1286,9 @@ impl Instance { fault } - #[allow(dead_code)] - pub(crate) fn typed_continuations_main_stack_limits( - &mut self, - ) -> *mut wasmtime_continuations::StackLimits { - unsafe { - self.vmctx_plus_offset_mut(self.offsets().vmctx_typed_continuations_main_stack_limits()) - } - } - pub(crate) fn typed_continuations_stack_chain( &mut self, - ) -> *mut wasmtime_continuations::StackChain { + ) -> *mut *mut wasmtime_continuations::StackChainCell { unsafe { self.vmctx_plus_offset_mut(self.offsets().vmctx_typed_continuations_stack_chain()) } @@ -1303,7 +1297,7 @@ impl Instance { #[allow(dead_code)] pub(crate) fn set_typed_continuations_stack_chain( &mut self, - chain: *mut wasmtime_continuations::StackChain, + chain: *mut *mut wasmtime_continuations::StackChainCell, ) { unsafe { let ptr = diff --git 
a/crates/runtime/src/lib.rs b/crates/runtime/src/lib.rs index c7907408d4e9..4ba4436beeab 100644 --- a/crates/runtime/src/lib.rs +++ b/crates/runtime/src/lib.rs @@ -8,6 +8,7 @@ use std::fmt; use std::ptr::NonNull; use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; +use wasmtime_continuations::StackChainCell; use wasmtime_environ::{DefinedFuncIndex, DefinedMemoryIndex, HostPtr, VMOffsets}; mod arch; @@ -97,6 +98,10 @@ pub unsafe trait Store { /// in the `VMContext`. fn vmruntime_limits(&self) -> *mut VMRuntimeLimits; + /// Used to configure `VMContext` initialization and store the right pointer + /// in the `VMContext`. + fn stack_chain(&self) -> *mut StackChainCell; + /// Returns a pointer to the global epoch counter. /// /// Used to configure the `VMContext` on initialization. diff --git a/crates/wasmtime/src/runtime/store.rs b/crates/wasmtime/src/runtime/store.rs index db9d917b60d4..b4fde3f55289 100644 --- a/crates/wasmtime/src/runtime/store.rs +++ b/crates/wasmtime/src/runtime/store.rs @@ -95,6 +95,7 @@ use std::ptr; use std::sync::atomic::AtomicU64; use std::sync::Arc; use std::task::{Context, Poll}; +use wasmtime_runtime::continuation::{StackChain, StackChainCell, StackLimits}; use wasmtime_runtime::mpk::{self, ProtectionKey, ProtectionMask}; use wasmtime_runtime::{ ExportGlobal, InstanceAllocationRequest, InstanceAllocator, InstanceHandle, ModuleInfo, @@ -303,6 +304,21 @@ pub struct StoreOpaque { engine: Engine, runtime_limits: VMRuntimeLimits, + + // Stack information used by typed continuations instructions. See + // documentation on `wasmtime_continuations::StackChain` for details. + // + // Note that in terms of (interior) mutability, we generally follow the same + // pattern as the `VMRuntimeLimits` object above: In the case of + // `StackLimits`, all of its fields are `UnsafeCell`s. For the stack chain, + // we wrap the entire `StackChainObject` in an `UnsafeCell`. 
+ // + // Finally, observe that the stack chain adds more internal self references: + // The stack chain always contains a `MainStack` element at the end which + // has a pointer to the `main_stack_limits` field of the same `StoreOpaque`. + main_stack_limits: StackLimits, + stack_chain: StackChainCell, + instances: Vec, #[cfg(feature = "component-model")] num_component_instances: usize, @@ -492,6 +508,8 @@ impl Store { _marker: marker::PhantomPinned, engine: engine.clone(), runtime_limits: Default::default(), + main_stack_limits: Default::default(), + stack_chain: StackChainCell::absent(), instances: Vec::new(), #[cfg(feature = "component-model")] num_component_instances: 0, @@ -573,6 +591,15 @@ impl Store { instance }; + unsafe { + // NOTE(frank-emrich) The setup code for `default_caller` above + // together with the comment on the `PhantomPinned` marker inside + // `Store` indicates that `inner` is supposed to be at a stable + // location at this point, without explicitly being `Pin`-ed. + let stack_chain = inner.stack_chain.0.get(); + *stack_chain = StackChain::MainStack(inner.main_stack_limits()); + } + Self { inner: ManuallyDrop::new(inner), } @@ -1513,6 +1540,20 @@ impl StoreOpaque { &self.runtime_limits as *const VMRuntimeLimits as *mut VMRuntimeLimits } + #[inline] + pub fn main_stack_limits(&self) -> *mut StackLimits { + // NOTE(frank-emrich) This looks dodgy, but follows the same pattern as + // `vmruntime_limits()` above. + &self.main_stack_limits as *const StackLimits as *mut StackLimits + } + + #[inline] + pub fn stack_chain(&self) -> *mut StackChainCell { + // NOTE(frank-emrich) This looks dodgy, but follows the same pattern as + // `vmruntime_limits()` above. 
+ &self.stack_chain as *const StackChainCell as *mut StackChainCell + } + pub unsafe fn insert_vmexternref_without_gc(&mut self, r: VMExternRef) { self.externref_activations_table.insert_without_gc(r); } @@ -2025,6 +2066,10 @@ unsafe impl wasmtime_runtime::Store for StoreInner { ::vmruntime_limits(self) } + fn stack_chain(&self) -> *mut StackChainCell { + ::stack_chain(self) + } + fn epoch_ptr(&self) -> *const AtomicU64 { self.engine.epoch_counter() as *const _ } diff --git a/tests/all/pooling_allocator.rs b/tests/all/pooling_allocator.rs index df8cdb7c791f..8b6f82274220 100644 --- a/tests/all/pooling_allocator.rs +++ b/tests/all/pooling_allocator.rs @@ -661,12 +661,12 @@ configured maximum of 16 bytes; breakdown of allocation requirement: " } else { "\ -instance allocation for this module requires 320 bytes which exceeds the \ +instance allocation for this module requires 272 bytes which exceeds the \ configured maximum of 16 bytes; breakdown of allocation requirement: - * 50.00% - 160 bytes - instance state management - * 10.00% - 32 bytes - typed continuations payloads object - * 10.00% - 32 bytes - typed continuations main stack limits + * 58.82% - 160 bytes - instance state management + * 8.82% - 24 bytes - typed continuations payloads object + * 5.88% - 16 bytes - jit store state " }; match Module::new(&engine, "(module)") { @@ -690,11 +690,11 @@ configured maximum of 16 bytes; breakdown of allocation requirement: " } else { "\ -instance allocation for this module requires 1920 bytes which exceeds the \ +instance allocation for this module requires 1872 bytes which exceeds the \ configured maximum of 16 bytes; breakdown of allocation requirement: - * 8.33% - 160 bytes - instance state management - * 83.33% - 1600 bytes - defined globals + * 8.55% - 160 bytes - instance state management + * 85.47% - 1600 bytes - defined globals " }; match Module::new(&engine, &lots_of_globals) { diff --git a/tests/all/typed_continuations.rs b/tests/all/typed_continuations.rs 
index 1567e0d92efc..c48b57baf638 100644 --- a/tests/all/typed_continuations.rs +++ b/tests/all/typed_continuations.rs @@ -182,3 +182,76 @@ async fn sched_yield_test_async() -> Result<()> { assert_eq!(run_wasi_test_async(SCHED_YIELD_WAT).await?, 0); Ok(()) } + +/// Test that we can handle a `suspend` from another instance. Note that this +/// test is working around the fact that wasmtime does not support exporting +/// tags at the moment. Thus, instead of sharing a tag between two modules, we +/// instantiate the same module twice to share a tag. +#[test] +fn inter_instance_suspend() -> Result<()> { + let mut config = Config::default(); + config.wasm_function_references(true); + config.wasm_exceptions(true); + config.wasm_typed_continuations(true); + + let engine = Engine::new(&config)?; + + let mut store = Store::<()>::new(&engine, ()); + + let wat_other = r#" + (module + + (type $ft (func)) + (type $ct (cont $ft)) + (tag $tag) + + + (func $suspend (export "suspend") + (suspend $tag) + ) + + (func $resume (export "resume") (param $f (ref $ct)) + (block $handler (result (ref $ct)) + (resume $ct (tag $tag $handler) (local.get $f)) + (unreachable) + ) + (drop) + ) + ) + "#; + + let wat_main = r#" + (module + + (type $ft (func)) + (type $ct (cont $ft)) + + (import "other" "suspend" (func $suspend)) + (import "other" "resume" (func $resume (param (ref $ct)))) + + (elem declare func $suspend) + + + (func $entry (export "entry") + (call $resume (cont.new $ct (ref.func $suspend))) + ) + ) + "#; + + let module_other = Module::new(&engine, wat_other)?; + + let other_inst1 = Instance::new(&mut store, &module_other, &[])?; + let other_inst2 = Instance::new(&mut store, &module_other, &[])?; + + // Crucially, suspend and resume are from two instances of the same module. 
+ let suspend = other_inst1.get_func(&mut store, "suspend").unwrap(); + let resume = other_inst2.get_func(&mut store, "resume").unwrap(); + + let module_main = Module::new(&engine, wat_main)?; + let main_instance = Instance::new(&mut store, &module_main, &[suspend.into(), resume.into()])?; + let entry_func = main_instance.get_func(&mut store, "entry").unwrap(); + + entry_func.call(&mut store, &[], &mut [])?; + + Ok(()) +}