mpk: restore PKRU state when a fiber resumes execution (#7789)
* mpk: restore PKRU state when a fiber resumes execution

Previously, when a fiber was suspended, other computation could change
the PKRU state on the current CPU, so the fiber could be resumed with a
different PKRU state than the one it was suspended with. This is
problematic: the fiber could end up able to access more memory slots
than it should, or unable to access even its own memory slots.

This change saves the PKRU state prior to a fiber being suspended. When
the fiber resumes execution, that PKRU state is restored.
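
In sketch form, the fix brackets the suspension point with a save and
restore of the PKRU-backed protection mask. The snippet below is a
simplified illustration of the `AsyncCx::block_on` change shown in the
`crates/wasmtime/src/runtime/store.rs` hunk further down, not the
verbatim patch:

    // Save this fiber's PKRU state, open up access while suspended (other
    // fibers may run and change PKRU in the meantime), and restore the
    // saved state once the fiber resumes.
    let previous_mask = mpk::current_mask();
    mpk::allow(ProtectionMask::all());
    suspend.suspend(())?;
    mpk::allow(previous_mask);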

* mpk: check correct PKRU switching on async suspension

This adds a test that alternately polls two Wasm instances in a loop.
Since the instances are async, we can use epoch interruption to suspend
each fiber on every iteration of the loop. Because we alternate between
the two instances, the test checks that `AsyncCx::block_on` correctly
restores the PKRU bits; otherwise we would see test failures. In the
process of writing this test I discovered #7942, which can be solved
separately (it has to do with the interaction between memory images,
_not_ used here, and MPK).

prtest:full

* fix: condition the PKRU context switches

Not all stores have protection keys, and MPK is not always enabled. This
change checks both conditions before context-switching the PKRU bits.
abrown authored Feb 15, 2024
1 parent 83a5a1a commit 2aaeddb
Showing 5 changed files with 109 additions and 5 deletions.
4 changes: 4 additions & 0 deletions crates/runtime/src/mpk/disabled.rs
@@ -13,6 +13,10 @@ pub fn keys(_: usize) -> &'static [ProtectionKey] {
}
pub fn allow(_: ProtectionMask) {}

pub fn current_mask() -> ProtectionMask {
    ProtectionMask
}

#[derive(Clone, Copy, Debug)]
pub enum ProtectionKey {}
impl ProtectionKey {
5 changes: 5 additions & 0 deletions crates/runtime/src/mpk/enabled.rs
@@ -59,6 +59,11 @@ pub fn allow(mask: ProtectionMask) {
    log::trace!("PKRU change: {:#034b} => {:#034b}", previous, pkru::read());
}

/// Retrieve the current protection mask.
pub fn current_mask() -> ProtectionMask {
    ProtectionMask(pkru::read())
}

/// An MPK protection key.
///
/// The expected usage is:
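
For context, the `pkru` module used here wraps the x86_64 PKRU register and
is not part of this diff. As an assumption about its implementation (a
hypothetical `read_pkru` helper, not the actual
`crates/runtime/src/mpk/pkru.rs` source), a minimal sketch of a read: the
RDPKRU instruction returns the 32-bit PKRU value in EAX, clears EDX, and
requires ECX to be zero.

    // Hypothetical sketch of reading PKRU on x86_64; the real implementation
    // lives in crates/runtime/src/mpk/pkru.rs and may differ.
    #[cfg(target_arch = "x86_64")]
    fn read_pkru() -> u32 {
        let pkru: u32;
        unsafe {
            core::arch::asm!(
                "rdpkru",
                out("eax") pkru,
                out("edx") _,
                in("ecx") 0,
                options(nomem, nostack, preserves_flags),
            );
        }
        pkru
    }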
4 changes: 2 additions & 2 deletions crates/runtime/src/mpk/mod.rs
@@ -34,10 +34,10 @@ cfg_if::cfg_if! {
        mod enabled;
        mod pkru;
        mod sys;
        pub use enabled::{allow, is_supported, keys, ProtectionKey, ProtectionMask};
        pub use enabled::{allow, current_mask, is_supported, keys, ProtectionKey, ProtectionMask};
    } else {
        mod disabled;
        pub use disabled::{allow, is_supported, keys, ProtectionKey, ProtectionMask};
        pub use disabled::{allow, current_mask, is_supported, keys, ProtectionKey, ProtectionMask};
    }
}

23 changes: 20 additions & 3 deletions crates/wasmtime/src/runtime/store.rs
@@ -95,10 +95,11 @@ use std::ptr;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use std::task::{Context, Poll};
use wasmtime_runtime::mpk::{self, ProtectionKey, ProtectionMask};
use wasmtime_runtime::{
    mpk::ProtectionKey, ExportGlobal, InstanceAllocationRequest, InstanceAllocator, InstanceHandle,
    ModuleInfo, OnDemandInstanceAllocator, SignalHandler, StoreBox, StorePtr, VMContext,
    VMExternRef, VMExternRefActivationsTable, VMFuncRef, VMRuntimeLimits, WasmFault,
    ExportGlobal, InstanceAllocationRequest, InstanceAllocator, InstanceHandle, ModuleInfo,
    OnDemandInstanceAllocator, SignalHandler, StoreBox, StorePtr, VMContext, VMExternRef,
    VMExternRefActivationsTable, VMFuncRef, VMRuntimeLimits, WasmFault,
};

mod context;
@@ -1401,6 +1402,7 @@ impl StoreOpaque {
        Some(AsyncCx {
            current_suspend: self.async_state.current_suspend.get(),
            current_poll_cx: poll_cx_box_ptr,
            track_pkey_context_switch: mpk::is_supported() && self.pkey.is_some(),
        })
    }

@@ -1938,6 +1940,7 @@ impl<T> StoreContextMut<'_, T> {
pub struct AsyncCx {
    current_suspend: *mut *const wasmtime_fiber::Suspend<Result<()>, (), Result<()>>,
    current_poll_cx: *mut *mut Context<'static>,
    track_pkey_context_switch: bool,
}

#[cfg(feature = "async")]
@@ -1998,7 +2001,21 @@ impl AsyncCx {
                Poll::Pending => {}
            }

            // In order to prevent this fiber's MPK state from being munged by
            // other fibers while it is suspended, we save it here and restore
            // it once execution resumes. Note that when MPK is not supported,
            // these are no-ops.
            let previous_mask = if self.track_pkey_context_switch {
                let previous_mask = mpk::current_mask();
                mpk::allow(ProtectionMask::all());
                previous_mask
            } else {
                ProtectionMask::all()
            };
            (*suspend).suspend(())?;
            if self.track_pkey_context_switch {
                mpk::allow(previous_mask);
            }
        }
    }
}
78 changes: 78 additions & 0 deletions tests/all/async_functions.rs
@@ -389,6 +389,84 @@ async fn async_host_func_with_pooling_stacks() -> Result<()> {
    Ok(())
}

#[tokio::test]
async fn async_mpk_protection() -> Result<()> {
    let _ = env_logger::try_init();

    // Construct a pool with MPK protection enabled; note that the MPK
    // protection is configured in `small_pool_config`.
    let mut pooling = crate::small_pool_config();
    pooling
        .total_memories(10)
        .total_stacks(2)
        .memory_pages(1)
        .table_elements(0);
    let mut config = Config::new();
    config.async_support(true);
    config.allocation_strategy(InstanceAllocationStrategy::Pooling(pooling));
    config.static_memory_maximum_size(1 << 26);
    config.epoch_interruption(true);
    let engine = Engine::new(&config)?;

    // Craft a module that loops for several iterations and checks whether it
    // has access to its memory range (0x0-0x10000).
    const WAT: &str = "
        (module
            (func $start
                (local $i i32)
                (local.set $i (i32.const 3))
                (loop $cont
                    (drop (i32.load (i32.const 0)))
                    (drop (i32.load (i32.const 0xfffc)))
                    (br_if $cont (local.tee $i (i32.sub (local.get $i) (i32.const 1))))))
            (memory 1)
            (start $start))
    ";

    // Start two instances of the module in separate fibers, `a` and `b`.
    async fn run_instance(engine: &Engine, name: &str) -> Instance {
        let mut store = Store::new(&engine, ());
        store.set_epoch_deadline(0);
        store.epoch_deadline_async_yield_and_update(0);
        let module = Module::new(store.engine(), WAT).unwrap();
        println!("[{name}] building instance");
        Instance::new_async(&mut store, &module, &[]).await.unwrap()
    }
    let mut a = Box::pin(run_instance(&engine, "a"));
    let mut b = Box::pin(run_instance(&engine, "b"));

    // Alternately poll each instance until completion. This should exercise
    // fiber suspensions requiring the `Store` to appropriately save and restore
    // the PKRU context between suspensions (see `AsyncCx::block_on`).
    for i in 0..10 {
        if i % 2 == 0 {
            match PollOnce::new(a).await {
                Ok(_) => {
                    println!("[a] done");
                    break;
                }
                Err(a_) => {
                    println!("[a] not done");
                    a = a_;
                }
            }
        } else {
            match PollOnce::new(b).await {
                Ok(_) => {
                    println!("[b] done");
                    break;
                }
                Err(b_) => {
                    println!("[b] not done");
                    b = b_;
                }
            }
        }
    }

    Ok(())
}

/// This will execute the `future` provided to completion and each invocation of
/// `poll` for the future will be executed on a separate thread.
pub async fn execute_across_threads<F>(future: F) -> F::Output
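
The test drives each instance with a `PollOnce` helper that already exists
elsewhere in `tests/all/async_functions.rs` and is not part of this diff.
As an assumption about its shape, a minimal sketch of such a helper
(requiring the future to be `Unpin`, which is why the instances above are
wrapped in `Box::pin`) could look like:

    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Polls the wrapped future exactly once: yields `Ok(output)` if it
    /// completed, or hands the future back as `Err(future)` so the caller
    /// can poll it again later.
    struct PollOnce<F>(Option<F>);

    impl<F> PollOnce<F> {
        fn new(future: F) -> PollOnce<F> {
            PollOnce(Some(future))
        }
    }

    impl<F: Future + Unpin> Future for PollOnce<F> {
        type Output = Result<F::Output, F>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let mut future = self.0.take().unwrap();
            match Pin::new(&mut future).poll(cx) {
                Poll::Ready(output) => Poll::Ready(Ok(output)),
                Poll::Pending => Poll::Ready(Err(future)),
            }
        }
    }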
