From 36101dc7e3483c5e189ebd64d9a1edca011555ee Mon Sep 17 00:00:00 2001 From: kaivol Date: Sun, 22 Dec 2024 23:27:02 +0100 Subject: [PATCH] Override component model Lower::store_list and Lift::load_list for f32/f64 --- .../src/runtime/component/func/typed.rs | 78 ++++++++++++++++++- 1 file changed, 75 insertions(+), 3 deletions(-) diff --git a/crates/wasmtime/src/runtime/component/func/typed.rs b/crates/wasmtime/src/runtime/component/func/typed.rs index 95fb52186c2a..275821799189 100644 --- a/crates/wasmtime/src/runtime/component/func/typed.rs +++ b/crates/wasmtime/src/runtime/component/func/typed.rs @@ -12,6 +12,7 @@ use core::marker; use core::mem::{self, MaybeUninit}; use core::ptr::NonNull; use core::str; +use std::iter; use wasmtime_environ::component::{ CanonicalAbiInfo, ComponentTypes, InterfaceType, StringEncoding, VariantInfo, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, @@ -874,7 +875,7 @@ integers! { } macro_rules! floats { - ($($float:ident/$get_float:ident = $ty:ident with abi:$abi:ident)*) => ($(const _: () = { + ($($float:ident/$get_float:ident = $ty:ident with abi:$abi:ident $integer:ident)*) => ($(const _: () = { unsafe impl ComponentType for $float { type Lower = ValRaw; @@ -914,6 +915,44 @@ macro_rules! floats { *ptr = self.to_bits().to_le_bytes(); Ok(()) } + + fn store_list( + cx: &mut LowerContext<'_, T>, + ty: InterfaceType, + offset: usize, + items: &[Self], + ) -> Result<()> { + debug_assert!(matches!(ty, InterfaceType::$ty)); + + // Double-check that the CM alignment is at least the host's + // alignment for this type which should be true for all + // platforms. + assert!((Self::ALIGN32 as usize) >= mem::align_of::()); + + // Slice `cx`'s memory to the window that we'll be modifying. + // This should all have already been verified in terms of + // alignment and sizing meaning that these assertions here are + // not truly necessary but are instead double-checks. + // + // Note that we're casting a `[u8]` slice to `[$integer]` (with + // $integer having the same size as Self) with `align_to_mut` + // which is not safe in general but is safe in our specific + // case as all `u8` patterns are valid `$integer` patterns + // since `$integer` is an integral type. + let dst = &mut cx.as_slice_mut()[offset..][..items.len() * Self::SIZE32]; + let (before, middle, end) = unsafe { dst.align_to_mut::<$integer>() }; + assert!(before.is_empty() && end.is_empty()); + assert_eq!(middle.len(), items.len()); + + // And with all that out of the way perform the copying loop. + // This is not a `copy_from_slice` because endianness needs to + // be handled here, but LLVM should pretty easily transform this + // into a memcpy on little-endian platforms. + for (dst, src) in iter::zip(middle, items) { + *dst = src.to_bits().to_le(); + } + Ok(()) + } } unsafe impl Lift for $float { @@ -929,13 +968,46 @@ macro_rules! floats { debug_assert!((bytes.as_ptr() as usize) % Self::SIZE32 == 0); Ok($float::from_le_bytes(bytes.try_into().unwrap())) } + + fn load_list(cx: &mut LiftContext<'_>, list: &WasmList) -> Result> where Self: Sized { + // See comments in `WasmList::get` for the panicking indexing + let byte_size = list.len * mem::size_of::<$integer>(); + let bytes = &cx.memory()[list.ptr..][..byte_size]; + + // The canonical ABI requires that everything is aligned to its + // own size, so this should be an aligned array. Furthermore the + // alignment of primitive integers for hosts should be smaller + // than or equal to the size of the primitive itself, meaning + // that a wasm canonical-abi-aligned list is also aligned for + // the host. That should mean that the head/tail slices here are + // empty. + // + // Also note that the `unsafe` here is needed since the type + // we're aligning to isn't guaranteed to be valid, but in our + // case it's just integers and bytes so this should be safe. + + let slice = unsafe { + let (head, body, tail) = bytes.align_to::<$integer>(); + assert!(head.is_empty() && tail.is_empty()); + body + }; + + // Copy the resulting slice to a new Vec, handling endianness + // in the process + Ok( + slice + .iter() + .map(|i| $float::from_bits($integer::from_le(*i))) + .collect() + ) + } } };)*) } floats! { - f32/get_f32 = Float32 with abi:SCALAR4 - f64/get_f64 = Float64 with abi:SCALAR8 + f32/get_f32 = Float32 with abi:SCALAR4 u32 + f64/get_f64 = Float64 with abi:SCALAR8 u64 } unsafe impl ComponentType for bool {