diff --git a/Cargo.lock b/Cargo.lock index a5398f234778..8cebb981d88b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3333,6 +3333,7 @@ dependencies = [ "wasmparser", "wasmtime-cache", "wasmtime-component-macro", + "wasmtime-component-util", "wasmtime-cranelift", "wasmtime-environ", "wasmtime-fiber", @@ -3472,8 +3473,13 @@ dependencies = [ "proc-macro2", "quote", "syn", + "wasmtime-component-util", ] +[[package]] +name = "wasmtime-component-util" +version = "0.40.0" + [[package]] name = "wasmtime-cranelift" version = "0.40.0" diff --git a/crates/component-macro/Cargo.toml b/crates/component-macro/Cargo.toml index 6fadf5e65e11..2997bb76c470 100644 --- a/crates/component-macro/Cargo.toml +++ b/crates/component-macro/Cargo.toml @@ -17,6 +17,7 @@ proc-macro = true proc-macro2 = "1.0" quote = "1.0" syn = { version = "1.0", features = ["extra-traits"] } +wasmtime-component-util = { path = "../component-util", version = "=0.40.0" } [badges] maintenance = { status = "actively-developed" } diff --git a/crates/component-macro/src/lib.rs b/crates/component-macro/src/lib.rs index e3f290922d85..e782ff96a0e0 100644 --- a/crates/component-macro/src/lib.rs +++ b/crates/component-macro/src/lib.rs @@ -5,6 +5,7 @@ use std::fmt; use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; use syn::{braced, parse_macro_input, parse_quote, Data, DeriveInput, Error, Result, Token}; +use wasmtime_component_util::{DiscriminantSize, FlagsSize}; #[derive(Debug, Copy, Clone)] enum VariantStyle { @@ -147,64 +148,6 @@ fn add_trait_bounds(generics: &syn::Generics, bound: syn::TypeParamBound) -> syn generics } -#[derive(Debug, Copy, Clone)] -enum DiscriminantSize { - Size1, - Size2, - Size4, -} - -impl DiscriminantSize { - fn quote(self, discriminant: usize) -> TokenStream { - match self { - Self::Size1 => { - let discriminant = u8::try_from(discriminant).unwrap(); - quote!(#discriminant) - } - Self::Size2 => { - let discriminant = u16::try_from(discriminant).unwrap(); - quote!(#discriminant) - } - Self::Size4 => { - let discriminant = u32::try_from(discriminant).unwrap(); - quote!(#discriminant) - } - } - } -} - -impl From for u32 { - fn from(size: DiscriminantSize) -> u32 { - match size { - DiscriminantSize::Size1 => 1, - DiscriminantSize::Size2 => 2, - DiscriminantSize::Size4 => 4, - } - } -} - -impl From for usize { - fn from(size: DiscriminantSize) -> usize { - match size { - DiscriminantSize::Size1 => 1, - DiscriminantSize::Size2 => 2, - DiscriminantSize::Size4 => 4, - } - } -} - -fn discriminant_size(case_count: usize) -> Option { - if case_count <= 0xFF { - Some(DiscriminantSize::Size1) - } else if case_count <= 0xFFFF { - Some(DiscriminantSize::Size2) - } else if case_count <= 0xFFFF_FFFF { - Some(DiscriminantSize::Size4) - } else { - None - } -} - struct VariantCase<'a> { attrs: &'a [syn::Attribute], ident: &'a syn::Ident, @@ -288,7 +231,7 @@ fn expand_variant( )); } - let discriminant_size = discriminant_size(body.variants.len()).ok_or_else(|| { + let discriminant_size = DiscriminantSize::from_count(body.variants.len()).ok_or_else(|| { Error::new( input.ident.span(), "`enum`s with more than 2^32 variants are not supported", @@ -417,7 +360,7 @@ fn expand_record_for_component_type( const SIZE32: usize = { let mut size = 0; #sizes - size + #internal::align_to(size, Self::ALIGN32) }; const ALIGN32: u32 = { @@ -439,6 +382,23 @@ fn expand_record_for_component_type( Ok(quote!(const _: () = { #expanded };)) } +fn quote(size: DiscriminantSize, discriminant: usize) -> TokenStream { + match size { + DiscriminantSize::Size1 => { + let discriminant = u8::try_from(discriminant).unwrap(); + quote!(#discriminant) + } + DiscriminantSize::Size2 => { + let discriminant = u16::try_from(discriminant).unwrap(); + quote!(#discriminant) + } + DiscriminantSize::Size4 => { + let discriminant = u32::try_from(discriminant).unwrap(); + quote!(#discriminant) + } + } +} + #[proc_macro_derive(Lift, attributes(component))] pub fn lift(input: proc_macro::TokenStream) -> proc_macro::TokenStream { expand(&LiftExpander, &parse_macro_input!(input as DeriveInput)) @@ -523,7 +483,7 @@ impl Expander for LiftExpander { for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() { let index_u32 = u32::try_from(index).unwrap(); - let index_quoted = discriminant_size.quote(index); + let index_quoted = quote(discriminant_size, index); if let Some(ty) = ty { lifts.extend( @@ -666,7 +626,7 @@ impl Expander for LowerExpander { for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() { let index_u32 = u32::try_from(index).unwrap(); - let index_quoted = discriminant_size.quote(index); + let index_quoted = quote(discriminant_size, index); let discriminant_size = usize::from(discriminant_size); @@ -989,19 +949,6 @@ impl Parse for Flags { } } -enum FlagsSize { - /// Flags can fit in a u8 - Size1, - /// Flags can fit in a u16 - Size2, - /// Flags can fit in a specified number of u32 fields - Size4Plus(usize), -} - -fn ceiling_divide(n: usize, d: usize) -> usize { - (n + d - 1) / d -} - #[proc_macro] pub fn flags(input: proc_macro::TokenStream) -> proc_macro::TokenStream { expand_flags(&parse_macro_input!(input as Flags)) @@ -1010,13 +957,7 @@ pub fn flags(input: proc_macro::TokenStream) -> proc_macro::TokenStream { } fn expand_flags(flags: &Flags) -> Result { - let size = if flags.flags.len() <= 8 { - FlagsSize::Size1 - } else if flags.flags.len() <= 16 { - FlagsSize::Size2 - } else { - FlagsSize::Size4Plus(ceiling_divide(flags.flags.len(), 32)) - }; + let size = FlagsSize::from_count(flags.flags.len()); let ty; let eq; diff --git a/crates/component-util/Cargo.toml b/crates/component-util/Cargo.toml new file mode 100644 index 000000000000..6de2d646042c --- /dev/null +++ b/crates/component-util/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "wasmtime-component-util" +version = "0.40.0" +authors = ["The Wasmtime Project Developers"] +description = "Utility types and functions to support the component model in Wasmtime" +license = "Apache-2.0 WITH LLVM-exception" +repository = "https://github.com/bytecodealliance/wasmtime" +documentation = "https://docs.rs/wasmtime-component-util/" +categories = ["wasm"] +keywords = ["webassembly", "wasm"] +edition = "2021" diff --git a/crates/component-util/src/lib.rs b/crates/component-util/src/lib.rs new file mode 100644 index 000000000000..c59c39040fb9 --- /dev/null +++ b/crates/component-util/src/lib.rs @@ -0,0 +1,75 @@ +/// Represents the possible sizes in bytes of the discriminant of a variant type in the component model +#[derive(Debug, Copy, Clone)] +pub enum DiscriminantSize { + /// 8-bit discriminant + Size1, + /// 16-bit discriminant + Size2, + /// 32-bit discriminant + Size4, +} + +impl DiscriminantSize { + /// Calculate the size of discriminant needed to represent a variant with the specified number of cases. + pub fn from_count(count: usize) -> Option { + if count <= 0xFF { + Some(Self::Size1) + } else if count <= 0xFFFF { + Some(Self::Size2) + } else if count <= 0xFFFF_FFFF { + Some(Self::Size4) + } else { + None + } + } +} + +impl From for u32 { + /// Size of the discriminant as a `u32` + fn from(size: DiscriminantSize) -> u32 { + match size { + DiscriminantSize::Size1 => 1, + DiscriminantSize::Size2 => 2, + DiscriminantSize::Size4 => 4, + } + } +} + +impl From for usize { + /// Size of the discriminant as a `usize` + fn from(size: DiscriminantSize) -> usize { + match size { + DiscriminantSize::Size1 => 1, + DiscriminantSize::Size2 => 2, + DiscriminantSize::Size4 => 4, + } + } +} + +/// Represents the number of bytes required to store a flags value in the component model +pub enum FlagsSize { + /// Flags can fit in a u8 + Size1, + /// Flags can fit in a u16 + Size2, + /// Flags can fit in a specified number of u32 fields + Size4Plus(usize), +} + +impl FlagsSize { + /// Calculate the size needed to represent a value with the specified number of flags. + pub fn from_count(count: usize) -> FlagsSize { + if count <= 8 { + FlagsSize::Size1 + } else if count <= 16 { + FlagsSize::Size2 + } else { + FlagsSize::Size4Plus(ceiling_divide(count, 32)) + } + } +} + +/// Divide `n` by `d`, rounding up in the case of a non-zero remainder. +fn ceiling_divide(n: usize, d: usize) -> usize { + (n + d - 1) / d +} diff --git a/crates/wasmtime/Cargo.toml b/crates/wasmtime/Cargo.toml index 3c250b36d15c..325f595b7cb2 100644 --- a/crates/wasmtime/Cargo.toml +++ b/crates/wasmtime/Cargo.toml @@ -20,6 +20,7 @@ wasmtime-cache = { path = "../cache", version = "=0.40.0", optional = true } wasmtime-fiber = { path = "../fiber", version = "=0.40.0", optional = true } wasmtime-cranelift = { path = "../cranelift", version = "=0.40.0", optional = true } wasmtime-component-macro = { path = "../component-macro", version = "=0.40.0", optional = true } +wasmtime-component-util = { path = "../component-util", version = "=0.40.0", optional = true } target-lexicon = { version = "0.12.0", default-features = false } wasmparser = "0.87.0" anyhow = "1.0.19" @@ -115,4 +116,5 @@ component-model = [ "wasmtime-cranelift?/component-model", "wasmtime-runtime/component-model", "dep:wasmtime-component-macro", + "dep:wasmtime-component-util", ] diff --git a/crates/wasmtime/src/component/func.rs b/crates/wasmtime/src/component/func.rs index a5fe8de4b0b6..39de8b05d170 100644 --- a/crates/wasmtime/src/component/func.rs +++ b/crates/wasmtime/src/component/func.rs @@ -1,8 +1,10 @@ use crate::component::instance::{Instance, InstanceData}; +use crate::component::types::{SizeAndAlignment, Type}; +use crate::component::values::Val; use crate::store::{StoreOpaque, Stored}; -use crate::{AsContext, ValRaw}; -use anyhow::{Context, Result}; -use std::mem::MaybeUninit; +use crate::{AsContext, AsContextMut, StoreContextMut, ValRaw}; +use anyhow::{bail, Context, Result}; +use std::mem::{self, MaybeUninit}; use std::ptr::NonNull; use std::sync::Arc; use wasmtime_environ::component::{ @@ -72,6 +74,12 @@ pub use self::host::*; pub use self::options::*; pub use self::typed::*; +#[repr(C)] +union ParamsAndResults { + params: Params, + ret: Return, +} + /// A WebAssembly component function. // // FIXME: write more docs here @@ -241,4 +249,346 @@ impl Func { Ok(()) } + + /// Get the parameter types for this function. + pub fn params(&self, store: impl AsContext) -> Box<[Type]> { + let data = &store.as_context()[self.0]; + data.types[data.ty] + .params + .iter() + .map(|(_, ty)| Type::from(ty, &data.types)) + .collect() + } + + /// Invokes this function with the `params` given and returns the result. + /// + /// The `params` here must match the type signature of this `Func`, or this will return an error. If a trap + /// occurs while executing this function, then an error will also be returned. + // TODO: say more -- most of the docs for `TypedFunc::call` apply here, too + pub fn call(&self, mut store: impl AsContextMut, args: &[Val]) -> Result { + let store = &mut store.as_context_mut(); + + let params; + let result; + + { + let data = &store[self.0]; + let ty = &data.types[data.ty]; + + if ty.params.len() != args.len() { + bail!( + "expected {} argument(s), got {}", + ty.params.len(), + args.len() + ); + } + + params = ty + .params + .iter() + .zip(args) + .map(|((_, ty), arg)| { + let ty = Type::from(ty, &data.types); + + ty.check(arg).context("type mismatch with parameters")?; + + Ok(ty) + }) + .collect::>>()?; + + result = Type::from(&ty.result, &data.types); + } + + let param_count = params.iter().map(|ty| ty.flatten_count()).sum::(); + let result_count = result.flatten_count(); + + self.call_raw( + store, + args, + |store, options, args, dst: &mut MaybeUninit<[ValRaw; MAX_STACK_PARAMS]>| { + if param_count > MAX_STACK_PARAMS { + self.store_args(store, &options, ¶ms, args, dst) + } else { + dst.write([ValRaw::u64(0); MAX_STACK_PARAMS]); + let dst = unsafe { + mem::transmute::<_, &mut [MaybeUninit; MAX_STACK_PARAMS]>(dst) + }; + args.iter() + .try_for_each(|arg| arg.lower(store, &options, &mut dst.iter_mut())) + } + }, + |store, options, src: &[ValRaw; MAX_STACK_RESULTS]| { + if result_count > MAX_STACK_RESULTS { + Self::load_result(&Memory::new(store, &options), &result, &mut src.iter()) + } else { + Val::lift(&result, store, &options, &mut src.iter()) + } + }, + ) + } + + /// Invokes the underlying wasm function, lowering arguments and lifting the + /// result. + /// + /// The `lower` function and `lift` function provided here are what actually + /// do the lowering and lifting. The `LowerParams` and `LowerReturn` types + /// are what will be allocated on the stack for this function call. They + /// should be appropriately sized for the lowering/lifting operation + /// happening. + fn call_raw( + &self, + store: &mut StoreContextMut<'_, T>, + params: &Params, + lower: impl FnOnce( + &mut StoreContextMut<'_, T>, + &Options, + &Params, + &mut MaybeUninit, + ) -> Result<()>, + lift: impl FnOnce(&StoreOpaque, &Options, &LowerReturn) -> Result, + ) -> Result + where + LowerParams: Copy, + LowerReturn: Copy, + { + let FuncData { + trampoline, + export, + options, + instance, + component_instance, + .. + } = store.0[self.0]; + + let space = &mut MaybeUninit::>::uninit(); + + // Double-check the size/alignemnt of `space`, just in case. + // + // Note that this alone is not enough to guarantee the validity of the + // `unsafe` block below, but it's definitely required. In any case LLVM + // should be able to trivially see through these assertions and remove + // them in release mode. + let val_size = mem::size_of::(); + let val_align = mem::align_of::(); + assert!(mem::size_of_val(space) % val_size == 0); + assert!(mem::size_of_val(map_maybe_uninit!(space.params)) % val_size == 0); + assert!(mem::size_of_val(map_maybe_uninit!(space.ret)) % val_size == 0); + assert!(mem::align_of_val(space) == val_align); + assert!(mem::align_of_val(map_maybe_uninit!(space.params)) == val_align); + assert!(mem::align_of_val(map_maybe_uninit!(space.ret)) == val_align); + + let instance = store.0[instance.0].as_ref().unwrap().instance(); + let flags = instance.flags(component_instance); + + unsafe { + // Test the "may enter" flag which is a "lock" on this instance. + // This is immediately set to `false` afterwards and note that + // there's no on-cleanup setting this flag back to true. That's an + // intentional design aspect where if anything goes wrong internally + // from this point on the instance is considered "poisoned" and can + // never be entered again. The only time this flag is set to `true` + // again is after post-return logic has completed successfully. + if !(*flags).may_enter() { + bail!("cannot reenter component instance"); + } + (*flags).set_may_enter(false); + + debug_assert!((*flags).may_leave()); + (*flags).set_may_leave(false); + let result = lower(store, &options, params, map_maybe_uninit!(space.params)); + (*flags).set_may_leave(true); + result?; + + // This is unsafe as we are providing the guarantee that all the + // inputs are valid. The various pointers passed in for the function + // are all valid since they're coming from our store, and the + // `params_and_results` should have the correct layout for the core + // wasm function we're calling. Note that this latter point relies + // on the correctness of this module and `ComponentType` + // implementations, hence `ComponentType` being an `unsafe` trait. + crate::Func::call_unchecked_raw( + store, + export.anyfunc, + trampoline, + space.as_mut_ptr().cast(), + )?; + + // Note that `.assume_init_ref()` here is unsafe but we're relying + // on the correctness of the structure of `LowerReturn` and the + // type-checking performed to acquire the `TypedFunc` to make this + // safe. It should be the case that `LowerReturn` is the exact + // representation of the return value when interpreted as + // `[ValRaw]`, and additionally they should have the correct types + // for the function we just called (which filled in the return + // values). + let ret = map_maybe_uninit!(space.ret).assume_init_ref(); + + // Lift the result into the host while managing post-return state + // here as well. + // + // After a successful lift the return value of the function, which + // is currently required to be 0 or 1 values according to the + // canonical ABI, is saved within the `Store`'s `FuncData`. This'll + // later get used in post-return. + (*flags).set_needs_post_return(true); + let val = lift(store.0, &options, ret)?; + let ret_slice = cast_storage(ret); + let data = &mut store.0[self.0]; + assert!(data.post_return_arg.is_none()); + match ret_slice.len() { + 0 => data.post_return_arg = Some(ValRaw::i32(0)), + 1 => data.post_return_arg = Some(ret_slice[0]), + _ => unreachable!(), + } + return Ok(val); + } + + unsafe fn cast_storage(storage: &T) -> &[ValRaw] { + assert!(std::mem::size_of_val(storage) % std::mem::size_of::() == 0); + assert!(std::mem::align_of_val(storage) == std::mem::align_of::()); + + std::slice::from_raw_parts( + (storage as *const T).cast(), + mem::size_of_val(storage) / mem::size_of::(), + ) + } + } + + /// Invokes the `post-return` canonical ABI option, if specified, after a + /// [`Func::call`] has finished. + /// + /// For some more information on when to use this function see the + /// documentation for post-return in the [`Func::call`] method. + /// Otherwise though this function is a required method call after a + /// [`Func::call`] completes successfully. After the embedder has + /// finished processing the return value then this function must be invoked. + /// + /// # Errors + /// + /// This function will return an error in the case of a WebAssembly trap + /// happening during the execution of the `post-return` function, if + /// specified. + /// + /// # Panics + /// + /// This function will panic if it's not called under the correct + /// conditions. This can only be called after a previous invocation of + /// [`Func::call`] completes successfully, and this function can only + /// be called for the same [`Func`] that was `call`'d. + /// + /// If this function is called when [`Func::call`] was not previously + /// called, then it will panic. If a different [`Func`] for the same + /// component instance was invoked then this function will also panic + /// because the `post-return` needs to happen for the other function. + pub fn post_return(&self, mut store: impl AsContextMut) -> Result<()> { + let mut store = store.as_context_mut(); + let data = &mut store.0[self.0]; + let instance = data.instance; + let post_return = data.post_return; + let component_instance = data.component_instance; + let post_return_arg = data.post_return_arg.take(); + let instance = store.0[instance.0].as_ref().unwrap().instance(); + let flags = instance.flags(component_instance); + + unsafe { + // First assert that the instance is in a "needs post return" state. + // This will ensure that the previous action on the instance was a + // function call above. This flag is only set after a component + // function returns so this also can't be called (as expected) + // during a host import for example. + // + // Note, though, that this assert is not sufficient because it just + // means some function on this instance needs its post-return + // called. We need a precise post-return for a particular function + // which is the second assert here (the `.expect`). That will assert + // that this function itself needs to have its post-return called. + // + // The theory at least is that these two asserts ensure component + // model semantics are upheld where the host properly calls + // `post_return` on the right function despite the call being a + // separate step in the API. + assert!( + (*flags).needs_post_return(), + "post_return can only be called after a function has previously been called", + ); + let post_return_arg = post_return_arg.expect("calling post_return on wrong function"); + + // This is a sanity-check assert which shouldn't ever trip. + assert!(!(*flags).may_enter()); + + // Unset the "needs post return" flag now that post-return is being + // processed. This will cause future invocations of this method to + // panic, even if the function call below traps. + (*flags).set_needs_post_return(false); + + // If the function actually had a `post-return` configured in its + // canonical options that's executed here. + // + // Note that if this traps (returns an error) this function + // intentionally leaves the instance in a "poisoned" state where it + // can no longer be entered because `may_enter` is `false`. + if let Some((func, trampoline)) = post_return { + crate::Func::call_unchecked_raw( + &mut store, + func.anyfunc, + trampoline, + &post_return_arg as *const ValRaw as *mut ValRaw, + )?; + } + + // And finally if everything completed successfully then the "may + // enter" flag is set to `true` again here which enables further use + // of the component. + (*flags).set_may_enter(true); + } + Ok(()) + } + + fn store_args( + &self, + store: &mut StoreContextMut<'_, T>, + options: &Options, + params: &[Type], + args: &[Val], + dst: &mut MaybeUninit<[ValRaw; MAX_STACK_PARAMS]>, + ) -> Result<()> { + let mut size = 0; + let mut alignment = 1; + for ty in params { + alignment = alignment.max(ty.size_and_alignment().alignment); + ty.next_field(&mut size); + } + + let mut memory = MemoryMut::new(store.as_context_mut(), options); + let ptr = memory.realloc(0, 0, alignment, size)?; + let mut offset = ptr; + for (ty, arg) in params.iter().zip(args) { + arg.store(&mut memory, ty.next_field(&mut offset))?; + } + + map_maybe_uninit!(dst[0]).write(ValRaw::i64(ptr as i64)); + + Ok(()) + } + + fn load_result<'a>( + mem: &Memory, + ty: &Type, + src: &mut std::slice::Iter<'_, ValRaw>, + ) -> Result { + let SizeAndAlignment { size, alignment } = ty.size_and_alignment(); + // FIXME: needs to read an i64 for memory64 + let ptr = usize::try_from(src.next().unwrap().get_u32())?; + if ptr % usize::try_from(alignment)? != 0 { + bail!("return pointer not aligned"); + } + + let bytes = mem + .as_slice() + .get(ptr..) + .and_then(|b| b.get(..size)) + .ok_or_else(|| anyhow::anyhow!("pointer out of bounds of memory"))?; + + Val::load(ty, mem, bytes) + } } diff --git a/crates/wasmtime/src/component/func/options.rs b/crates/wasmtime/src/component/func/options.rs index fbc202fd826c..f2df2bba5a33 100644 --- a/crates/wasmtime/src/component/func/options.rs +++ b/crates/wasmtime/src/component/func/options.rs @@ -213,7 +213,7 @@ impl<'a, T> MemoryMut<'a, T> { /// Like `MemoryMut` but for a read-only version that's used during lifting. pub struct Memory<'a> { - store: &'a StoreOpaque, + pub(crate) store: &'a StoreOpaque, options: &'a Options, } diff --git a/crates/wasmtime/src/component/func/typed.rs b/crates/wasmtime/src/component/func/typed.rs index 4e79d69f1969..098e02800396 100644 --- a/crates/wasmtime/src/component/func/typed.rs +++ b/crates/wasmtime/src/component/func/typed.rs @@ -158,14 +158,14 @@ where // count) if Params::flatten_count() <= MAX_STACK_PARAMS { if Return::flatten_count() <= MAX_STACK_RESULTS { - self.call_raw( + self.func.call_raw( store, ¶ms, Self::lower_stack_args, Self::lift_stack_result, ) } else { - self.call_raw( + self.func.call_raw( store, ¶ms, Self::lower_stack_args, @@ -174,14 +174,14 @@ where } } else { if Return::flatten_count() <= MAX_STACK_RESULTS { - self.call_raw( + self.func.call_raw( store, ¶ms, Self::lower_heap_args, Self::lift_stack_result, ) } else { - self.call_raw( + self.func.call_raw( store, ¶ms, Self::lower_heap_args, @@ -280,230 +280,12 @@ where Return::load(&memory, bytes) } - /// Invokes the underlying wasm function, lowering arguments and lifting the - /// result. - /// - /// The `lower` function and `lift` function provided here are what actually - /// do the lowering and lifting. The `LowerParams` and `LowerReturn` types - /// are what will be allocated on the stack for this function call. They - /// should be appropriately sized for the lowering/lifting operation - /// happening. - fn call_raw( - &self, - store: &mut StoreContextMut<'_, T>, - params: &Params, - lower: impl FnOnce( - &mut StoreContextMut<'_, T>, - &Options, - &Params, - &mut MaybeUninit, - ) -> Result<()>, - lift: impl FnOnce(&StoreOpaque, &Options, &LowerReturn) -> Result, - ) -> Result - where - LowerParams: Copy, - LowerReturn: Copy, - { - let super::FuncData { - trampoline, - export, - options, - instance, - component_instance, - .. - } = store.0[self.func.0]; - - let space = &mut MaybeUninit::>::uninit(); - - // Double-check the size/alignemnt of `space`, just in case. - // - // Note that this alone is not enough to guarantee the validity of the - // `unsafe` block below, but it's definitely required. In any case LLVM - // should be able to trivially see through these assertions and remove - // them in release mode. - let val_size = mem::size_of::(); - let val_align = mem::align_of::(); - assert!(mem::size_of_val(space) % val_size == 0); - assert!(mem::size_of_val(map_maybe_uninit!(space.params)) % val_size == 0); - assert!(mem::size_of_val(map_maybe_uninit!(space.ret)) % val_size == 0); - assert!(mem::align_of_val(space) == val_align); - assert!(mem::align_of_val(map_maybe_uninit!(space.params)) == val_align); - assert!(mem::align_of_val(map_maybe_uninit!(space.ret)) == val_align); - - let instance = store.0[instance.0].as_ref().unwrap().instance(); - let flags = instance.flags(component_instance); - - unsafe { - // Test the "may enter" flag which is a "lock" on this instance. - // This is immediately set to `false` afterwards and note that - // there's no on-cleanup setting this flag back to true. That's an - // intentional design aspect where if anything goes wrong internally - // from this point on the instance is considered "poisoned" and can - // never be entered again. The only time this flag is set to `true` - // again is after post-return logic has completed successfully. - if !(*flags).may_enter() { - bail!("cannot reenter component instance"); - } - (*flags).set_may_enter(false); - - debug_assert!((*flags).may_leave()); - (*flags).set_may_leave(false); - let result = lower(store, &options, params, map_maybe_uninit!(space.params)); - (*flags).set_may_leave(true); - result?; - - // This is unsafe as we are providing the guarantee that all the - // inputs are valid. The various pointers passed in for the function - // are all valid since they're coming from our store, and the - // `params_and_results` should have the correct layout for the core - // wasm function we're calling. Note that this latter point relies - // on the correctness of this module and `ComponentType` - // implementations, hence `ComponentType` being an `unsafe` trait. - crate::Func::call_unchecked_raw( - store, - export.anyfunc, - trampoline, - space.as_mut_ptr().cast(), - )?; - - // Note that `.assume_init_ref()` here is unsafe but we're relying - // on the correctness of the structure of `LowerReturn` and the - // type-checking performed to acquire the `TypedFunc` to make this - // safe. It should be the case that `LowerReturn` is the exact - // representation of the return value when interpreted as - // `[ValRaw]`, and additionally they should have the correct types - // for the function we just called (which filled in the return - // values). - let ret = map_maybe_uninit!(space.ret).assume_init_ref(); - - // Lift the result into the host while managing post-return state - // here as well. - // - // After a successful lift the return value of the function, which - // is currently required to be 0 or 1 values according to the - // canonical ABI, is saved within the `Store`'s `FuncData`. This'll - // later get used in post-return. - (*flags).set_needs_post_return(true); - let val = lift(store.0, &options, ret)?; - let ret_slice = cast_storage(ret); - let data = &mut store.0[self.func.0]; - assert!(data.post_return_arg.is_none()); - match ret_slice.len() { - 0 => data.post_return_arg = Some(ValRaw::i32(0)), - 1 => data.post_return_arg = Some(ret_slice[0]), - _ => unreachable!(), - } - return Ok(val); - } - - unsafe fn cast_storage(storage: &T) -> &[ValRaw] { - assert!(std::mem::size_of_val(storage) % std::mem::size_of::() == 0); - assert!(std::mem::align_of_val(storage) == std::mem::align_of::()); - - std::slice::from_raw_parts( - (storage as *const T).cast(), - mem::size_of_val(storage) / mem::size_of::(), - ) - } - } - - /// Invokes the `post-return` canonical ABI option, if specified, after a - /// [`TypedFunc::call`] has finished. - /// - /// For some more information on when to use this function see the - /// documentation for post-return in the [`TypedFunc::call`] method. - /// Otherwise though this function is a required method call after a - /// [`TypedFunc::call`] completes successfully. After the embedder has - /// finished processing the return value then this function must be invoked. - /// - /// # Errors - /// - /// This function will return an error in the case of a WebAssembly trap - /// happening during the execution of the `post-return` function, if - /// specified. - /// - /// # Panics - /// - /// This function will panic if it's not called under the correct - /// conditions. This can only be called after a previous invocation of - /// [`TypedFunc::call`] completes successfully, and this function can only - /// be called for the same [`TypedFunc`] that was `call`'d. - /// - /// If this function is called when [`TypedFunc::call`] was not previously - /// called, then it will panic. If a different [`TypedFunc`] for the same - /// component instance was invoked then this function will also panic - /// because the `post-return` needs to happen for the other function. - pub fn post_return(&self, mut store: impl AsContextMut) -> Result<()> { - let mut store = store.as_context_mut(); - let data = &mut store.0[self.func.0]; - let instance = data.instance; - let post_return = data.post_return; - let component_instance = data.component_instance; - let post_return_arg = data.post_return_arg.take(); - let instance = store.0[instance.0].as_ref().unwrap().instance(); - let flags = instance.flags(component_instance); - - unsafe { - // First assert that the instance is in a "needs post return" state. - // This will ensure that the previous action on the instance was a - // function call above. This flag is only set after a component - // function returns so this also can't be called (as expected) - // during a host import for example. - // - // Note, though, that this assert is not sufficient because it just - // means some function on this instance needs its post-return - // called. We need a precise post-return for a particular function - // which is the second assert here (the `.expect`). That will assert - // that this function itself needs to have its post-return called. - // - // The theory at least is that these two asserts ensure component - // model semantics are upheld where the host properly calls - // `post_return` on the right function despite the call being a - // separate step in the API. - assert!( - (*flags).needs_post_return(), - "post_return can only be called after a function has previously been called", - ); - let post_return_arg = post_return_arg.expect("calling post_return on wrong function"); - - // This is a sanity-check assert which shouldn't ever trip. - assert!(!(*flags).may_enter()); - - // Unset the "needs post return" flag now that post-return is being - // processed. This will cause future invocations of this method to - // panic, even if the function call below traps. - (*flags).set_needs_post_return(false); - - // If the function actually had a `post-return` configured in its - // canonical options that's executed here. - // - // Note that if this traps (returns an error) this function - // intentionally leaves the instance in a "poisoned" state where it - // can no longer be entered because `may_enter` is `false`. - if let Some((func, trampoline)) = post_return { - crate::Func::call_unchecked_raw( - &mut store, - func.anyfunc, - trampoline, - &post_return_arg as *const ValRaw as *mut ValRaw, - )?; - } - - // And finally if everything completed successfully then the "may - // enter" flag is set to `true` again here which enables further use - // of the component. - (*flags).set_may_enter(true); - } - Ok(()) + /// See [`Func::post_return`] + pub fn post_return(&self, store: impl AsContextMut) -> Result<()> { + self.func.post_return(store) } } -#[repr(C)] -union ParamsAndResults { - params: Params, - ret: Return, -} - /// A trait representing a static list of parameters that can be passed to a /// [`TypedFunc`]. /// @@ -567,11 +349,9 @@ pub unsafe trait ComponentParams: ComponentType { // though, that correctness bugs in this trait implementation are highly likely // to lead to security bugs, which again leads to the `unsafe` in the trait. // -// Also note that this trait specifically is not sealed because we'll -// eventually have a proc macro that generates implementations of this trait -// for external types in a `#[derive]`-like fashion. -// -// FIXME: need to write a #[derive(ComponentType)] +// Also note that this trait specifically is not sealed because we have a proc +// macro that generates implementations of this trait for external types in a +// `#[derive]`-like fashion. pub unsafe trait ComponentType { /// Representation of the "lowered" form of this component value. /// @@ -690,7 +470,7 @@ pub unsafe trait Lift: Sized + ComponentType { // another type, used for wrappers in Rust like `&T`, `Box`, etc. Note that // these wrappers only implement lowering because lifting native Rust types // cannot be done. -macro_rules! forward_impls { +macro_rules! forward_type_impls { ($(($($generics:tt)*) $a:ty => $b:ty,)*) => ($( unsafe impl <$($generics)*> ComponentType for $a { type Lower = <$b as ComponentType>::Lower; @@ -703,7 +483,20 @@ macro_rules! forward_impls { <$b as ComponentType>::typecheck(ty, types) } } + )*) +} + +forward_type_impls! { + (T: ComponentType + ?Sized) &'_ T => T, + (T: ComponentType + ?Sized) Box => T, + (T: ComponentType + ?Sized) std::rc::Rc => T, + (T: ComponentType + ?Sized) std::sync::Arc => T, + () String => str, + (T: ComponentType) Vec => [T], +} +macro_rules! forward_lowers { + ($(($($generics:tt)*) $a:ty => $b:ty,)*) => ($( unsafe impl <$($generics)*> Lower for $a { fn lower( &self, @@ -721,7 +514,7 @@ macro_rules! forward_impls { )*) } -forward_impls! { +forward_lowers! { (T: Lower + ?Sized) &'_ T => T, (T: Lower + ?Sized) Box => T, (T: Lower + ?Sized) std::rc::Rc => T, @@ -730,6 +523,50 @@ forward_impls! { (T: Lower) Vec => [T], } +macro_rules! forward_string_lifts { + ($($a:ty,)*) => ($( + unsafe impl Lift for $a { + fn lift(store: &StoreOpaque, options: &Options, src: &Self::Lower) -> Result { + Ok(::lift(store, options, src)?.to_str_from_store(store)?.into()) + } + + fn load(memory: &Memory<'_>, bytes: &[u8]) -> Result { + Ok(::load(memory, bytes)?.to_str_from_store(&memory.store)?.into()) + } + } + )*) +} + +forward_string_lifts! { + Box, + std::rc::Rc, + std::sync::Arc, + String, +} + +macro_rules! forward_list_lifts { + ($($a:ty,)*) => ($( + unsafe impl Lift for $a { + fn lift(store: &StoreOpaque, options: &Options, src: &Self::Lower) -> Result { + let list = as Lift>::lift(store, options, src)?; + (0..list.len).map(|index| list.get_from_store(store, index).unwrap()).collect() + } + + fn load(memory: &Memory<'_>, bytes: &[u8]) -> Result { + let list = as Lift>::load(memory, bytes)?; + (0..list.len).map(|index| list.get_from_store(&memory.store, index).unwrap()).collect() + } + } + )*) +} + +forward_list_lifts! { + Box<[T]>, + std::rc::Rc<[T]>, + std::sync::Arc<[T]>, + Vec, +} + // Macro to help generate `ComponentType` implementations for primitive types // such as integers, char, bool, etc. macro_rules! integers { @@ -1092,10 +929,10 @@ impl WasmStr { // method that returns `[u16]` after validating to avoid the utf16-to-utf8 // transcode. pub fn to_str<'a, T: 'a>(&self, store: impl Into>) -> Result> { - self._to_str(store.into().0) + self.to_str_from_store(store.into().0) } - fn _to_str<'a>(&self, store: &'a StoreOpaque) -> Result> { + fn to_str_from_store<'a>(&self, store: &'a StoreOpaque) -> Result> { match self.options.string_encoding() { StringEncoding::Utf8 => self.decode_utf8(store), StringEncoding::Utf16 => self.decode_utf16(store), @@ -1289,10 +1126,10 @@ impl WasmList { // should we even expose a random access iteration API? In theory all // consumers should be validating through the iterator. pub fn get(&self, store: impl AsContext, index: usize) -> Option> { - self._get(store.as_context().0, index) + self.get_from_store(store.as_context().0, index) } - fn _get(&self, store: &StoreOpaque, index: usize) -> Option> { + fn get_from_store(&self, store: &StoreOpaque, index: usize) -> Option> { if index >= self.len { return None; } @@ -1316,7 +1153,7 @@ impl WasmList { store: impl Into>, ) -> impl ExactSizeIterator> + 'a { let store = store.into().0; - (0..self.len).map(move |i| self._get(store, i).unwrap()) + (0..self.len).map(move |i| self.get_from_store(store, i).unwrap()) } } diff --git a/crates/wasmtime/src/component/mod.rs b/crates/wasmtime/src/component/mod.rs index 9f85b65ee05e..527409b91ec4 100644 --- a/crates/wasmtime/src/component/mod.rs +++ b/crates/wasmtime/src/component/mod.rs @@ -9,6 +9,8 @@ mod instance; mod linker; mod matching; mod store; +pub mod types; +mod values; pub use self::component::Component; pub use self::func::{ ComponentParams, ComponentType, Func, IntoComponentFunc, Lift, Lower, TypedFunc, WasmList, @@ -16,6 +18,8 @@ pub use self::func::{ }; pub use self::instance::{ExportInstance, Exports, Instance, InstancePre}; pub use self::linker::{Linker, LinkerInstance}; +pub use self::types::Type; +pub use self::values::Val; pub use wasmtime_component_macro::{flags, ComponentType, Lift, Lower}; // These items are expected to be used by an eventual diff --git a/crates/wasmtime/src/component/types.rs b/crates/wasmtime/src/component/types.rs new file mode 100644 index 000000000000..adf992acc3fc --- /dev/null +++ b/crates/wasmtime/src/component/types.rs @@ -0,0 +1,664 @@ +//! This module defines the `Type` type, representing the dynamic form of a component interface type. + +use crate::component::func; +use crate::component::values::{self, Val}; +use anyhow::{anyhow, Result}; +use std::fmt; +use std::mem; +use std::ops::Deref; +use std::sync::Arc; +use wasmtime_component_util::{DiscriminantSize, FlagsSize}; +use wasmtime_environ::component::{ + ComponentTypes, InterfaceType, TypeEnumIndex, TypeExpectedIndex, TypeFlagsIndex, + TypeInterfaceIndex, TypeRecordIndex, TypeTupleIndex, TypeUnionIndex, TypeVariantIndex, +}; + +#[derive(Clone)] +struct Handle { + index: T, + types: Arc, +} + +impl fmt::Debug for Handle { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Handle") + .field("index", &self.index) + .finish() + } +} + +impl PartialEq for Handle { + fn eq(&self, other: &Self) -> bool { + // FIXME: This is an overly-restrictive definition of equality in that it doesn't consider types to be + // equal unless they refer to the same declaration in the same component. It's a good shortcut for the + // common case, but we should also do a recursive structural equality test if the shortcut test fails. + self.index == other.index && Arc::ptr_eq(&self.types, &other.types) + } +} + +impl Eq for Handle {} + +/// A `list` interface type +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct List(Handle); + +impl List { + /// Instantiate this type with the specified `values`. + pub fn new_val(&self, values: Box<[Val]>) -> Result { + Ok(Val::List(values::List::new(self, values)?)) + } + + /// Retreive the element type of this `list`. + pub fn ty(&self) -> Type { + Type::from(&self.0.types[self.0.index], &self.0.types) + } +} + +/// A field declaration belonging to a `record` +pub struct Field<'a> { + /// The name of the field + pub name: &'a str, + /// The type of the field + pub ty: Type, +} + +/// A `record` interface type +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct Record(Handle); + +impl Record { + /// Instantiate this type with the specified `values`. + pub fn new_val<'a>(&self, values: impl IntoIterator) -> Result { + Ok(Val::Record(values::Record::new(self, values)?)) + } + + /// Retrieve the fields of this `record` in declaration order. + pub fn fields(&self) -> impl ExactSizeIterator { + self.0.types[self.0.index].fields.iter().map(|field| Field { + name: &field.name, + ty: Type::from(&field.ty, &self.0.types), + }) + } +} + +/// A `tuple` interface type +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct Tuple(Handle); + +impl Tuple { + /// Instantiate this type ith the specified `values`. + pub fn new_val(&self, values: Box<[Val]>) -> Result { + Ok(Val::Tuple(values::Tuple::new(self, values)?)) + } + + /// Retrieve the types of the fields of this `tuple` in declaration order. + pub fn types(&self) -> impl ExactSizeIterator + '_ { + self.0.types[self.0.index] + .types + .iter() + .map(|ty| Type::from(ty, &self.0.types)) + } +} + +/// A case declaration belonging to a `variant` +pub struct Case<'a> { + /// The name of the case + pub name: &'a str, + /// The type of the case + pub ty: Type, +} + +/// A `variant` interface type +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct Variant(Handle); + +impl Variant { + /// Instantiate this type with the specified case `name` and `value`. + pub fn new_val(&self, name: &str, value: Val) -> Result { + Ok(Val::Variant(values::Variant::new(self, name, value)?)) + } + + /// Retrieve the cases of this `variant` in declaration order. + pub fn cases(&self) -> impl ExactSizeIterator { + self.0.types[self.0.index].cases.iter().map(|case| Case { + name: &case.name, + ty: Type::from(&case.ty, &self.0.types), + }) + } +} + +/// An `enum` interface type +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct Enum(Handle); + +impl Enum { + /// Instantiate this type with the specified case `name`. + pub fn new_val(&self, name: &str) -> Result { + Ok(Val::Enum(values::Enum::new(self, name)?)) + } + + /// Retrieve the names of the cases of this `enum` in declaration order. + pub fn names(&self) -> impl ExactSizeIterator { + self.0.types[self.0.index] + .names + .iter() + .map(|name| name.deref()) + } +} + +/// A `union` interface type +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct Union(Handle); + +impl Union { + /// Instantiate this type with the specified `discriminant` and `value`. + pub fn new_val(&self, discriminant: u32, value: Val) -> Result { + Ok(Val::Union(values::Union::new(self, discriminant, value)?)) + } + + /// Retrieve the types of the cases of this `union` in declaration order. + pub fn types(&self) -> impl ExactSizeIterator + '_ { + self.0.types[self.0.index] + .types + .iter() + .map(|ty| Type::from(ty, &self.0.types)) + } +} + +/// An `option` interface type +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct Option(Handle); + +impl Option { + /// Instantiate this type with the specified `value`. + pub fn new_val(&self, value: std::option::Option) -> Result { + Ok(Val::Option(values::Option::new(self, value)?)) + } + + /// Retrieve the type parameter for this `option`. + pub fn ty(&self) -> Type { + Type::from(&self.0.types[self.0.index], &self.0.types) + } +} + +/// An `expected` interface type +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct Expected(Handle); + +impl Expected { + /// Instantiate this type with the specified `value`. + pub fn new_val(&self, value: Result) -> Result { + Ok(Val::Expected(values::Expected::new(self, value)?)) + } + + /// Retrieve the `ok` type parameter for this `option`. + pub fn ok(&self) -> Type { + Type::from(&self.0.types[self.0.index].ok, &self.0.types) + } + + /// Retrieve the `err` type parameter for this `option`. + pub fn err(&self) -> Type { + Type::from(&self.0.types[self.0.index].err, &self.0.types) + } +} + +/// A `flags` interface type +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct Flags(Handle); + +impl Flags { + /// Instantiate this type with the specified flag `names`. + pub fn new_val(&self, names: &[&str]) -> Result { + Ok(Val::Flags(values::Flags::new(self, names)?)) + } + + /// Retrieve the names of the flags of this `flags` type in declaration order. + pub fn names(&self) -> impl ExactSizeIterator { + self.0.types[self.0.index] + .names + .iter() + .map(|name| name.deref()) + } +} + +/// Represents the size and alignment requirements of the heap-serialized form of a type +pub(crate) struct SizeAndAlignment { + pub(crate) size: usize, + pub(crate) alignment: u32, +} + +/// Represents a component model interface type +#[derive(Clone, PartialEq, Eq, Debug)] +pub enum Type { + /// Unit + Unit, + /// Boolean + Bool, + /// Signed 8-bit integer + S8, + /// Unsigned 8-bit integer + U8, + /// Signed 16-bit integer + S16, + /// Unsigned 16-bit integer + U16, + /// Signed 32-bit integer + S32, + /// Unsigned 32-bit integer + U32, + /// Signed 64-bit integer + S64, + /// Unsigned 64-bit integer + U64, + /// 64-bit floating point value + Float32, + /// 64-bit floating point value + Float64, + /// 32-bit character + Char, + /// Character string + String, + /// List of values + List(List), + /// Record + Record(Record), + /// Tuple + Tuple(Tuple), + /// Variant + Variant(Variant), + /// Enum + Enum(Enum), + /// Union + Union(Union), + /// Option + Option(Option), + /// Expected + Expected(Expected), + /// Bit flags + Flags(Flags), +} + +impl Type { + /// Retrieve the inner [`List`] of a [`Type::List`]. + /// + /// # Panics + /// + /// This will panic if `self` is not a [`Type::List`]. + pub fn unwrap_list(&self) -> &List { + if let Type::List(handle) = self { + &handle + } else { + panic!("attempted to unwrap a {} as a list", self.desc()) + } + } + + /// Retrieve the inner [`Record`] of a [`Type::Record`]. + /// + /// # Panics + /// + /// This will panic if `self` is not a [`Type::Record`]. + pub fn unwrap_record(&self) -> &Record { + if let Type::Record(handle) = self { + &handle + } else { + panic!("attempted to unwrap a {} as a record", self.desc()) + } + } + + /// Retrieve the inner [`Tuple`] of a [`Type::Tuple`]. + /// + /// # Panics + /// + /// This will panic if `self` is not a [`Type::Tuple`]. + pub fn unwrap_tuple(&self) -> &Tuple { + if let Type::Tuple(handle) = self { + &handle + } else { + panic!("attempted to unwrap a {} as a tuple", self.desc()) + } + } + + /// Retrieve the inner [`Variant`] of a [`Type::Variant`]. + /// + /// # Panics + /// + /// This will panic if `self` is not a [`Type::Variant`]. + pub fn unwrap_variant(&self) -> &Variant { + if let Type::Variant(handle) = self { + &handle + } else { + panic!("attempted to unwrap a {} as a variant", self.desc()) + } + } + + /// Retrieve the inner [`Enum`] of a [`Type::Enum`]. + /// + /// # Panics + /// + /// This will panic if `self` is not a [`Type::Enum`]. + pub fn unwrap_enum(&self) -> &Enum { + if let Type::Enum(handle) = self { + &handle + } else { + panic!("attempted to unwrap a {} as a enum", self.desc()) + } + } + + /// Retrieve the inner [`Union`] of a [`Type::Union`]. + /// + /// # Panics + /// + /// This will panic if `self` is not a [`Type::Union`]. + pub fn unwrap_union(&self) -> &Union { + if let Type::Union(handle) = self { + &handle + } else { + panic!("attempted to unwrap a {} as a union", self.desc()) + } + } + + /// Retrieve the inner [`Option`] of a [`Type::Option`]. + /// + /// # Panics + /// + /// This will panic if `self` is not a [`Type::Option`]. + pub fn unwrap_option(&self) -> &Option { + if let Type::Option(handle) = self { + &handle + } else { + panic!("attempted to unwrap a {} as a option", self.desc()) + } + } + + /// Retrieve the inner [`Expected`] of a [`Type::Expected`]. + /// + /// # Panics + /// + /// This will panic if `self` is not a [`Type::Expected`]. + pub fn unwrap_expected(&self) -> &Expected { + if let Type::Expected(handle) = self { + &handle + } else { + panic!("attempted to unwrap a {} as a expected", self.desc()) + } + } + + /// Retrieve the inner [`Flags`] of a [`Type::Flags`]. + /// + /// # Panics + /// + /// This will panic if `self` is not a [`Type::Flags`]. + pub fn unwrap_flags(&self) -> &Flags { + if let Type::Flags(handle) = self { + &handle + } else { + panic!("attempted to unwrap a {} as a flags", self.desc()) + } + } + + pub(crate) fn check(&self, value: &Val) -> Result<()> { + let other = &value.ty(); + if self == other { + Ok(()) + } else if mem::discriminant(self) != mem::discriminant(other) { + Err(anyhow!( + "type mismatch: expected {}, got {}", + self.desc(), + other.desc() + )) + } else { + Err(anyhow!( + "type mismatch for {}, possibly due to mixing distinct composite types", + self.desc() + )) + } + } + + /// Convert the specified `InterfaceType` to a `Type`. + pub(crate) fn from(ty: &InterfaceType, types: &Arc) -> Self { + match ty { + InterfaceType::Unit => Type::Unit, + InterfaceType::Bool => Type::Bool, + InterfaceType::S8 => Type::S8, + InterfaceType::U8 => Type::U8, + InterfaceType::S16 => Type::S16, + InterfaceType::U16 => Type::U16, + InterfaceType::S32 => Type::S32, + InterfaceType::U32 => Type::U32, + InterfaceType::S64 => Type::S64, + InterfaceType::U64 => Type::U64, + InterfaceType::Float32 => Type::Float32, + InterfaceType::Float64 => Type::Float64, + InterfaceType::Char => Type::Char, + InterfaceType::String => Type::String, + InterfaceType::List(index) => Type::List(List(Handle { + index: *index, + types: types.clone(), + })), + InterfaceType::Record(index) => Type::Record(Record(Handle { + index: *index, + types: types.clone(), + })), + InterfaceType::Tuple(index) => Type::Tuple(Tuple(Handle { + index: *index, + types: types.clone(), + })), + InterfaceType::Variant(index) => Type::Variant(Variant(Handle { + index: *index, + types: types.clone(), + })), + InterfaceType::Enum(index) => Type::Enum(Enum(Handle { + index: *index, + types: types.clone(), + })), + InterfaceType::Union(index) => Type::Union(Union(Handle { + index: *index, + types: types.clone(), + })), + InterfaceType::Option(index) => Type::Option(Option(Handle { + index: *index, + types: types.clone(), + })), + InterfaceType::Expected(index) => Type::Expected(Expected(Handle { + index: *index, + types: types.clone(), + })), + InterfaceType::Flags(index) => Type::Flags(Flags(Handle { + index: *index, + types: types.clone(), + })), + } + } + + /// Return the number of stack slots needed to store values of this type in lowered form. + pub(crate) fn flatten_count(&self) -> usize { + match self { + Type::Unit => 0, + + Type::Bool + | Type::S8 + | Type::U8 + | Type::S16 + | Type::U16 + | Type::S32 + | Type::U32 + | Type::S64 + | Type::U64 + | Type::Float32 + | Type::Float64 + | Type::Char + | Type::Enum(_) => 1, + + Type::String | Type::List(_) => 2, + + Type::Record(handle) => handle.fields().map(|field| field.ty.flatten_count()).sum(), + + Type::Tuple(handle) => handle.types().map(|ty| ty.flatten_count()).sum(), + + Type::Variant(handle) => { + 1 + handle + .cases() + .map(|case| case.ty.flatten_count()) + .max() + .unwrap_or(0) + } + + Type::Union(handle) => { + 1 + handle + .types() + .map(|ty| ty.flatten_count()) + .max() + .unwrap_or(0) + } + + Type::Option(handle) => 1 + handle.ty().flatten_count(), + + Type::Expected(handle) => { + 1 + handle + .ok() + .flatten_count() + .max(handle.err().flatten_count()) + } + + Type::Flags(handle) => values::u32_count_for_flag_count(handle.names().len()), + } + } + + fn desc(&self) -> &'static str { + match self { + Type::Unit => "unit", + Type::Bool => "bool", + Type::S8 => "s8", + Type::U8 => "u8", + Type::S16 => "s16", + Type::U16 => "u16", + Type::S32 => "s32", + Type::U32 => "u32", + Type::S64 => "s64", + Type::U64 => "u64", + Type::Float32 => "float32", + Type::Float64 => "float64", + Type::Char => "char", + Type::String => "string", + Type::List(_) => "list", + Type::Record(_) => "record", + Type::Tuple(_) => "tuple", + Type::Variant(_) => "variant", + Type::Enum(_) => "enum", + Type::Union(_) => "union", + Type::Option(_) => "option", + Type::Expected(_) => "expected", + Type::Flags(_) => "flags", + } + } + + /// Calculate the size and alignment requirements for the specified type. + pub(crate) fn size_and_alignment(&self) -> SizeAndAlignment { + match self { + Type::Unit => SizeAndAlignment { + size: 0, + alignment: 1, + }, + + Type::Bool | Type::S8 | Type::U8 => SizeAndAlignment { + size: 1, + alignment: 1, + }, + + Type::S16 | Type::U16 => SizeAndAlignment { + size: 2, + alignment: 2, + }, + + Type::S32 | Type::U32 | Type::Char | Type::Float32 => SizeAndAlignment { + size: 4, + alignment: 4, + }, + + Type::S64 | Type::U64 | Type::Float64 => SizeAndAlignment { + size: 8, + alignment: 8, + }, + + Type::String | Type::List(_) => SizeAndAlignment { + size: 8, + alignment: 4, + }, + + Type::Record(handle) => { + record_size_and_alignment(handle.fields().map(|field| field.ty)) + } + + Type::Tuple(handle) => record_size_and_alignment(handle.types()), + + Type::Variant(handle) => variant_size_and_alignment(handle.cases().map(|case| case.ty)), + + Type::Enum(handle) => variant_size_and_alignment(handle.names().map(|_| Type::Unit)), + + Type::Union(handle) => variant_size_and_alignment(handle.types()), + + Type::Option(handle) => { + variant_size_and_alignment([Type::Unit, handle.ty()].into_iter()) + } + + Type::Expected(handle) => { + variant_size_and_alignment([handle.ok(), handle.err()].into_iter()) + } + + Type::Flags(handle) => match FlagsSize::from_count(handle.names().len()) { + FlagsSize::Size1 => SizeAndAlignment { + size: 1, + alignment: 1, + }, + FlagsSize::Size2 => SizeAndAlignment { + size: 2, + alignment: 2, + }, + FlagsSize::Size4Plus(n) => SizeAndAlignment { + size: n * 4, + alignment: 4, + }, + }, + } + } + + /// Calculate the aligned offset of a field of this type, updating `offset` to point to just after that field. + pub(crate) fn next_field(&self, offset: &mut usize) -> usize { + let SizeAndAlignment { size, alignment } = self.size_and_alignment(); + *offset = func::align_to(*offset, alignment); + let result = *offset; + *offset += size; + result + } +} + +fn record_size_and_alignment(types: impl Iterator) -> SizeAndAlignment { + let mut offset = 0; + let mut align = 1; + for ty in types { + let SizeAndAlignment { size, alignment } = ty.size_and_alignment(); + offset = func::align_to(offset, alignment) + size; + align = align.max(alignment); + } + + SizeAndAlignment { + size: func::align_to(offset, align), + alignment: align, + } +} + +fn variant_size_and_alignment(types: impl ExactSizeIterator) -> SizeAndAlignment { + let discriminant_size = DiscriminantSize::from_count(types.len()).unwrap(); + let mut alignment = u32::from(discriminant_size); + let mut size = 0; + for ty in types { + let size_and_alignment = ty.size_and_alignment(); + alignment = alignment.max(size_and_alignment.alignment); + size = size.max(size_and_alignment.size); + } + + SizeAndAlignment { + size: func::align_to(usize::from(discriminant_size), alignment) + size, + alignment, + } +} diff --git a/crates/wasmtime/src/component/values.rs b/crates/wasmtime/src/component/values.rs new file mode 100644 index 000000000000..066174d257c7 --- /dev/null +++ b/crates/wasmtime/src/component/values.rs @@ -0,0 +1,908 @@ +use crate::component::func::{self, Lift, Lower, Memory, MemoryMut, Options}; +use crate::component::types::{self, SizeAndAlignment, Type}; +use crate::store::StoreOpaque; +use crate::{AsContextMut, StoreContextMut, ValRaw}; +use anyhow::{anyhow, bail, Context, Error, Result}; +use std::collections::HashMap; +use std::iter; +use std::mem::MaybeUninit; +use std::ops::Deref; +use wasmtime_component_util::{DiscriminantSize, FlagsSize}; + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct List { + ty: types::List, + values: Box<[Val]>, +} + +impl List { + /// Instantiate the specified type with the specified `values`. + pub fn new(ty: &types::List, values: Box<[Val]>) -> Result { + let element_type = ty.ty(); + for (index, value) in values.iter().enumerate() { + element_type + .check(value) + .with_context(|| format!("type mismatch for element {index} of list"))?; + } + + Ok(Self { + ty: ty.clone(), + values, + }) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Record { + ty: types::Record, + values: Box<[Val]>, +} + +impl Record { + /// Instantiate the specified type with the specified `values`. + pub fn new<'a>( + ty: &types::Record, + values: impl IntoIterator, + ) -> Result { + let mut fields = ty.fields(); + let expected_len = fields.len(); + let mut iter = values.into_iter(); + let mut values = Vec::with_capacity(expected_len); + loop { + match (fields.next(), iter.next()) { + (Some(field), Some((name, value))) => { + if name == field.name { + field + .ty + .check(&value) + .with_context(|| format!("type mismatch for field {name} of record"))?; + + values.push(value); + } else { + bail!("field name mismatch: expected {}; got {name}", field.name) + } + } + (None, Some((_, value))) => values.push(value), + _ => break, + } + } + + if values.len() != expected_len { + bail!("expected {} value(s); got {}", expected_len, values.len()); + } + + Ok(Self { + ty: ty.clone(), + values: values.into(), + }) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Tuple { + ty: types::Tuple, + values: Box<[Val]>, +} + +impl Tuple { + /// Instantiate the specified type ith the specified `values`. + pub fn new(ty: &types::Tuple, values: Box<[Val]>) -> Result { + if values.len() != ty.types().len() { + bail!( + "expected {} value(s); got {}", + ty.types().len(), + values.len() + ); + } + + for (index, (value, ty)) in values.iter().zip(ty.types()).enumerate() { + ty.check(value) + .with_context(|| format!("type mismatch for field {index} of tuple"))?; + } + + Ok(Self { + ty: ty.clone(), + values, + }) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Variant { + ty: types::Variant, + discriminant: u32, + value: Box, +} + +impl Variant { + /// Instantiate the specified type with the specified case `name` and `value`. + pub fn new(ty: &types::Variant, name: &str, value: Val) -> Result { + let (discriminant, case_type) = ty + .cases() + .enumerate() + .find_map(|(index, case)| { + if case.name == name { + Some((index, case.ty)) + } else { + None + } + }) + .ok_or_else(|| anyhow!("unknown variant case: {name}"))?; + + case_type + .check(&value) + .with_context(|| format!("type mismatch for case {name} of variant"))?; + + Ok(Self { + ty: ty.clone(), + discriminant: u32::try_from(discriminant)?, + value: Box::new(value), + }) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Enum { + ty: types::Enum, + discriminant: u32, +} + +impl Enum { + /// Instantiate the specified type with the specified case `name`. + pub fn new(ty: &types::Enum, name: &str) -> Result { + let discriminant = u32::try_from( + ty.names() + .position(|n| n == name) + .ok_or_else(|| anyhow!("unknown enum case: {name}"))?, + )?; + + Ok(Self { + ty: ty.clone(), + discriminant, + }) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Union { + ty: types::Union, + discriminant: u32, + value: Box, +} + +impl Union { + /// Instantiate the specified type with the specified `discriminant` and `value`. + pub fn new(ty: &types::Union, discriminant: u32, value: Val) -> Result { + if let Some(case_ty) = ty.types().nth(usize::try_from(discriminant)?) { + case_ty + .check(&value) + .with_context(|| format!("type mismatch for case {discriminant} of union"))?; + + Ok(Self { + ty: ty.clone(), + discriminant, + value: Box::new(value), + }) + } else { + Err(anyhow!( + "discriminant {discriminant} out of range: [0,{})", + ty.types().len() + )) + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Option { + ty: types::Option, + discriminant: u32, + value: Box, +} + +impl Option { + /// Instantiate the specified type with the specified `value`. + pub fn new(ty: &types::Option, value: std::option::Option) -> Result { + let value = value + .map(|value| { + ty.ty().check(&value).context("type mismatch for option")?; + + Ok::<_, Error>(value) + }) + .transpose()?; + + Ok(Self { + ty: ty.clone(), + discriminant: if value.is_none() { 0 } else { 1 }, + value: Box::new(value.unwrap_or(Val::Unit)), + }) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Expected { + ty: types::Expected, + discriminant: u32, + value: Box, +} + +impl Expected { + /// Instantiate the specified type with the specified `value`. + pub fn new(ty: &types::Expected, value: Result) -> Result { + Ok(Self { + ty: ty.clone(), + discriminant: if value.is_ok() { 0 } else { 1 }, + value: Box::new(match value { + Ok(value) => { + ty.ok() + .check(&value) + .context("type mismatch for ok case of expected")?; + value + } + Err(value) => { + ty.err() + .check(&value) + .context("type mismatch for err case of expected")?; + value + } + }), + }) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Flags { + ty: types::Flags, + count: u32, + value: Box<[u32]>, +} + +impl Flags { + /// Instantiate the specified type with the specified flag `names`. + pub fn new(ty: &types::Flags, names: &[&str]) -> Result { + let map = ty + .names() + .enumerate() + .map(|(index, name)| (name, index)) + .collect::>(); + + let mut values = vec![0_u32; u32_count_for_flag_count(ty.names().len())]; + + for name in names { + let index = map + .get(name) + .ok_or_else(|| anyhow!("unknown flag: {name}"))?; + values[index / 32] |= 1 << (index % 32); + } + + Ok(Self { + ty: ty.clone(), + count: u32::try_from(map.len())?, + value: values.into(), + }) + } +} + +/// Represents possible runtime values which a component function can either consume or produce +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum Val { + /// Unit + Unit, + /// Boolean + Bool(bool), + /// Signed 8-bit integer + S8(i8), + /// Unsigned 8-bit integer + U8(u8), + /// Signed 16-bit integer + S16(i16), + /// Unsigned 16-bit integer + U16(u16), + /// Signed 32-bit integer + S32(i32), + /// Unsigned 32-bit integer + U32(u32), + /// Signed 64-bit integer + S64(i64), + /// Unsigned 64-bit integer + U64(u64), + /// 32-bit floating point value + Float32(u32), + /// 64-bit floating point value + Float64(u64), + /// 32-bit character + Char(char), + /// Character string + String(Box), + /// List of values + List(List), + /// Record + Record(Record), + /// Tuple + Tuple(Tuple), + /// Variant + Variant(Variant), + /// Enum + Enum(Enum), + /// Union + Union(Union), + /// Option + Option(Option), + /// Expected + Expected(Expected), + /// Bit flags + Flags(Flags), +} + +impl Val { + /// Retrieve the [`Type`] of this value. + pub fn ty(&self) -> Type { + match self { + Val::Unit => Type::Unit, + Val::Bool(_) => Type::Bool, + Val::S8(_) => Type::S8, + Val::U8(_) => Type::U8, + Val::S16(_) => Type::S16, + Val::U16(_) => Type::U16, + Val::S32(_) => Type::S32, + Val::U32(_) => Type::U32, + Val::S64(_) => Type::S64, + Val::U64(_) => Type::U64, + Val::Float32(_) => Type::Float32, + Val::Float64(_) => Type::Float64, + Val::Char(_) => Type::Char, + Val::String(_) => Type::String, + Val::List(List { ty, .. }) => Type::List(ty.clone()), + Val::Record(Record { ty, .. }) => Type::Record(ty.clone()), + Val::Tuple(Tuple { ty, .. }) => Type::Tuple(ty.clone()), + Val::Variant(Variant { ty, .. }) => Type::Variant(ty.clone()), + Val::Enum(Enum { ty, .. }) => Type::Enum(ty.clone()), + Val::Union(Union { ty, .. }) => Type::Union(ty.clone()), + Val::Option(Option { ty, .. }) => Type::Option(ty.clone()), + Val::Expected(Expected { ty, .. }) => Type::Expected(ty.clone()), + Val::Flags(Flags { ty, .. }) => Type::Flags(ty.clone()), + } + } + + /// Deserialize a value of this type from core Wasm stack values. + pub(crate) fn lift<'a>( + ty: &Type, + store: &StoreOpaque, + options: &Options, + src: &mut std::slice::Iter<'_, ValRaw>, + ) -> Result { + Ok(match ty { + Type::Unit => Val::Unit, + Type::Bool => Val::Bool(bool::lift(store, options, next(src))?), + Type::S8 => Val::S8(i8::lift(store, options, next(src))?), + Type::U8 => Val::U8(u8::lift(store, options, next(src))?), + Type::S16 => Val::S16(i16::lift(store, options, next(src))?), + Type::U16 => Val::U16(u16::lift(store, options, next(src))?), + Type::S32 => Val::S32(i32::lift(store, options, next(src))?), + Type::U32 => Val::U32(u32::lift(store, options, next(src))?), + Type::S64 => Val::S64(i64::lift(store, options, next(src))?), + Type::U64 => Val::U64(u64::lift(store, options, next(src))?), + Type::Float32 => Val::Float32(u32::lift(store, options, next(src))?), + Type::Float64 => Val::Float64(u64::lift(store, options, next(src))?), + Type::Char => Val::Char(char::lift(store, options, next(src))?), + Type::String => { + Val::String(Box::::lift(store, options, &[*next(src), *next(src)])?) + } + Type::List(handle) => { + // FIXME: needs memory64 treatment + let ptr = u32::lift(store, options, next(src))? as usize; + let len = u32::lift(store, options, next(src))? as usize; + load_list(handle, &Memory::new(store, options), ptr, len)? + } + Type::Record(handle) => Val::Record(Record { + ty: handle.clone(), + values: handle + .fields() + .map(|field| Self::lift(&field.ty, store, options, src)) + .collect::>()?, + }), + Type::Tuple(handle) => Val::Tuple(Tuple { + ty: handle.clone(), + values: handle + .types() + .map(|ty| Self::lift(&ty, store, options, src)) + .collect::>()?, + }), + Type::Variant(handle) => { + let (discriminant, value) = lift_variant( + ty.flatten_count(), + handle.cases().map(|case| case.ty), + store, + options, + src, + )?; + + Val::Variant(Variant { + ty: handle.clone(), + discriminant, + value: Box::new(value), + }) + } + Type::Enum(handle) => { + let (discriminant, _) = lift_variant( + ty.flatten_count(), + handle.names().map(|_| Type::Unit), + store, + options, + src, + )?; + + Val::Enum(Enum { + ty: handle.clone(), + discriminant, + }) + } + Type::Union(handle) => { + let (discriminant, value) = + lift_variant(ty.flatten_count(), handle.types(), store, options, src)?; + + Val::Union(Union { + ty: handle.clone(), + discriminant, + value: Box::new(value), + }) + } + Type::Option(handle) => { + let (discriminant, value) = lift_variant( + ty.flatten_count(), + [Type::Unit, handle.ty()].into_iter(), + store, + options, + src, + )?; + + Val::Option(Option { + ty: handle.clone(), + discriminant, + value: Box::new(value), + }) + } + Type::Expected(handle) => { + let (discriminant, value) = lift_variant( + ty.flatten_count(), + [handle.ok(), handle.err()].into_iter(), + store, + options, + src, + )?; + + Val::Expected(Expected { + ty: handle.clone(), + discriminant, + value: Box::new(value), + }) + } + Type::Flags(handle) => { + let count = u32::try_from(handle.names().len()).unwrap(); + assert!(count <= 32); + let value = iter::once(u32::lift(store, options, next(src))?).collect(); + + Val::Flags(Flags { + ty: handle.clone(), + count, + value, + }) + } + }) + } + + /// Deserialize a value of this type from the heap. + pub(crate) fn load(ty: &Type, mem: &Memory, bytes: &[u8]) -> Result { + Ok(match ty { + Type::Unit => Val::Unit, + Type::Bool => Val::Bool(bool::load(mem, bytes)?), + Type::S8 => Val::S8(i8::load(mem, bytes)?), + Type::U8 => Val::U8(u8::load(mem, bytes)?), + Type::S16 => Val::S16(i16::load(mem, bytes)?), + Type::U16 => Val::U16(u16::load(mem, bytes)?), + Type::S32 => Val::S32(i32::load(mem, bytes)?), + Type::U32 => Val::U32(u32::load(mem, bytes)?), + Type::S64 => Val::S64(i64::load(mem, bytes)?), + Type::U64 => Val::U64(u64::load(mem, bytes)?), + Type::Float32 => Val::Float32(u32::load(mem, bytes)?), + Type::Float64 => Val::Float64(u64::load(mem, bytes)?), + Type::Char => Val::Char(char::load(mem, bytes)?), + Type::String => Val::String(Box::::load(mem, bytes)?), + Type::List(handle) => { + // FIXME: needs memory64 treatment + let ptr = u32::from_le_bytes(bytes[..4].try_into().unwrap()) as usize; + let len = u32::from_le_bytes(bytes[4..].try_into().unwrap()) as usize; + load_list(handle, mem, ptr, len)? + } + Type::Record(handle) => Val::Record(Record { + ty: handle.clone(), + values: load_record(handle.fields().map(|field| field.ty), mem, bytes)?, + }), + Type::Tuple(handle) => Val::Tuple(Tuple { + ty: handle.clone(), + values: load_record(handle.types(), mem, bytes)?, + }), + Type::Variant(handle) => { + let (discriminant, value) = + load_variant(ty, handle.cases().map(|case| case.ty), mem, bytes)?; + + Val::Variant(Variant { + ty: handle.clone(), + discriminant, + value: Box::new(value), + }) + } + Type::Enum(handle) => { + let (discriminant, _) = + load_variant(ty, handle.names().map(|_| Type::Unit), mem, bytes)?; + + Val::Enum(Enum { + ty: handle.clone(), + discriminant, + }) + } + Type::Union(handle) => { + let (discriminant, value) = load_variant(ty, handle.types(), mem, bytes)?; + + Val::Union(Union { + ty: handle.clone(), + discriminant, + value: Box::new(value), + }) + } + Type::Option(handle) => { + let (discriminant, value) = + load_variant(ty, [Type::Unit, handle.ty()].into_iter(), mem, bytes)?; + + Val::Option(Option { + ty: handle.clone(), + discriminant, + value: Box::new(value), + }) + } + Type::Expected(handle) => { + let (discriminant, value) = + load_variant(ty, [handle.ok(), handle.err()].into_iter(), mem, bytes)?; + + Val::Expected(Expected { + ty: handle.clone(), + discriminant, + value: Box::new(value), + }) + } + Type::Flags(handle) => Val::Flags(Flags { + ty: handle.clone(), + count: u32::try_from(handle.names().len())?, + value: match FlagsSize::from_count(handle.names().len()) { + FlagsSize::Size1 => iter::once(u8::load(mem, bytes)? as u32).collect(), + FlagsSize::Size2 => iter::once(u16::load(mem, bytes)? as u32).collect(), + FlagsSize::Size4Plus(n) => (0..n) + .map(|index| u32::load(mem, &bytes[index * 4..][..4])) + .collect::>()?, + }, + }), + }) + } + + /// Serialize this value as core Wasm stack values. + pub(crate) fn lower( + &self, + store: &mut StoreContextMut, + options: &Options, + dst: &mut std::slice::IterMut<'_, MaybeUninit>, + ) -> Result<()> { + match self { + Val::Unit => (), + Val::Bool(value) => value.lower(store, options, next_mut(dst))?, + Val::S8(value) => value.lower(store, options, next_mut(dst))?, + Val::U8(value) => value.lower(store, options, next_mut(dst))?, + Val::S16(value) => value.lower(store, options, next_mut(dst))?, + Val::U16(value) => value.lower(store, options, next_mut(dst))?, + Val::S32(value) => value.lower(store, options, next_mut(dst))?, + Val::U32(value) => value.lower(store, options, next_mut(dst))?, + Val::S64(value) => value.lower(store, options, next_mut(dst))?, + Val::U64(value) => value.lower(store, options, next_mut(dst))?, + Val::Float32(value) => value.lower(store, options, next_mut(dst))?, + Val::Float64(value) => value.lower(store, options, next_mut(dst))?, + Val::Char(value) => value.lower(store, options, next_mut(dst))?, + Val::String(value) => { + let my_dst = &mut MaybeUninit::<[ValRaw; 2]>::uninit(); + value.lower(store, options, my_dst)?; + let my_dst = unsafe { my_dst.assume_init() }; + next_mut(dst).write(my_dst[0]); + next_mut(dst).write(my_dst[1]); + } + Val::List(List { values, ty }) => { + let (ptr, len) = lower_list( + &ty.ty(), + &mut MemoryMut::new(store.as_context_mut(), options), + values, + )?; + next_mut(dst).write(ValRaw::i64(ptr as i64)); + next_mut(dst).write(ValRaw::i64(len as i64)); + } + Val::Record(Record { values, .. }) | Val::Tuple(Tuple { values, .. }) => { + for value in values.deref() { + value.lower(store, options, dst)?; + } + } + Val::Variant(Variant { + discriminant, + value, + .. + }) + | Val::Union(Union { + discriminant, + value, + .. + }) + | Val::Option(Option { + discriminant, + value, + .. + }) + | Val::Expected(Expected { + discriminant, + value, + .. + }) => { + next_mut(dst).write(ValRaw::u32(*discriminant)); + value.lower(store, options, dst)?; + for _ in (1 + value.ty().flatten_count())..self.ty().flatten_count() { + next_mut(dst).write(ValRaw::u32(0)); + } + } + Val::Enum(Enum { discriminant, .. }) => { + next_mut(dst).write(ValRaw::u32(*discriminant)); + } + Val::Flags(Flags { value, .. }) => { + for value in value.deref() { + next_mut(dst).write(ValRaw::u32(*value)); + } + } + } + + Ok(()) + } + + /// Serialize this value to the heap at the specified memory location. + pub(crate) fn store(&self, mem: &mut MemoryMut<'_, T>, offset: usize) -> Result<()> { + debug_assert!(offset % usize::try_from(self.ty().size_and_alignment().alignment)? == 0); + + match self { + Val::Unit => (), + Val::Bool(value) => value.store(mem, offset)?, + Val::S8(value) => value.store(mem, offset)?, + Val::U8(value) => value.store(mem, offset)?, + Val::S16(value) => value.store(mem, offset)?, + Val::U16(value) => value.store(mem, offset)?, + Val::S32(value) => value.store(mem, offset)?, + Val::U32(value) => value.store(mem, offset)?, + Val::S64(value) => value.store(mem, offset)?, + Val::U64(value) => value.store(mem, offset)?, + Val::Float32(value) => value.store(mem, offset)?, + Val::Float64(value) => value.store(mem, offset)?, + Val::Char(value) => value.store(mem, offset)?, + Val::String(value) => value.store(mem, offset)?, + Val::List(List { values, ty }) => { + let (ptr, len) = lower_list(&ty.ty(), mem, values)?; + // FIXME: needs memory64 handling + *mem.get(offset + 0) = (ptr as i32).to_le_bytes(); + *mem.get(offset + 4) = (len as i32).to_le_bytes(); + } + Val::Record(Record { values, .. }) | Val::Tuple(Tuple { values, .. }) => { + let mut offset = offset; + for value in values.deref() { + value.store(mem, value.ty().next_field(&mut offset))?; + } + } + Val::Variant(Variant { + discriminant, + value, + ty, + }) => self.store_variant(*discriminant, value, ty.cases().len(), mem, offset)?, + + Val::Enum(Enum { discriminant, ty }) => { + self.store_variant(*discriminant, &Val::Unit, ty.names().len(), mem, offset)? + } + + Val::Union(Union { + discriminant, + value, + ty, + }) => self.store_variant(*discriminant, value, ty.types().len(), mem, offset)?, + + Val::Option(Option { + discriminant, + value, + .. + }) + | Val::Expected(Expected { + discriminant, + value, + .. + }) => self.store_variant(*discriminant, value, 2, mem, offset)?, + + Val::Flags(Flags { count, value, .. }) => { + match FlagsSize::from_count(*count as usize) { + FlagsSize::Size1 => u8::try_from(value[0]).unwrap().store(mem, offset)?, + FlagsSize::Size2 => u16::try_from(value[0]).unwrap().store(mem, offset)?, + FlagsSize::Size4Plus(_) => { + let mut offset = offset; + for value in value.deref() { + value.store(mem, offset)?; + offset += 4; + } + } + } + } + } + + Ok(()) + } + + fn store_variant( + &self, + discriminant: u32, + value: &Val, + case_count: usize, + mem: &mut MemoryMut<'_, T>, + offset: usize, + ) -> Result<()> { + let discriminant_size = DiscriminantSize::from_count(case_count).unwrap(); + match discriminant_size { + DiscriminantSize::Size1 => u8::try_from(discriminant).unwrap().store(mem, offset)?, + DiscriminantSize::Size2 => u16::try_from(discriminant).unwrap().store(mem, offset)?, + DiscriminantSize::Size4 => (discriminant).store(mem, offset)?, + } + + value.store( + mem, + offset + + func::align_to( + discriminant_size.into(), + self.ty().size_and_alignment().alignment, + ), + ) + } +} + +fn load_list(handle: &types::List, mem: &Memory, ptr: usize, len: usize) -> Result { + let element_type = handle.ty(); + let SizeAndAlignment { + size: element_size, + alignment: element_alignment, + } = element_type.size_and_alignment(); + + match len + .checked_mul(element_size) + .and_then(|len| ptr.checked_add(len)) + { + Some(n) if n <= mem.as_slice().len() => {} + _ => bail!("list pointer/length out of bounds of memory"), + } + if ptr % usize::try_from(element_alignment)? != 0 { + bail!("list pointer is not aligned") + } + + Ok(Val::List(List { + ty: handle.clone(), + values: (0..len) + .map(|index| { + Val::load( + &element_type, + mem, + &mem.as_slice()[ptr + (index * element_size)..][..element_size], + ) + }) + .collect::>()?, + })) +} + +fn load_record( + types: impl Iterator, + mem: &Memory, + bytes: &[u8], +) -> Result> { + let mut offset = 0; + types + .map(|ty| { + Val::load( + &ty, + mem, + &bytes[ty.next_field(&mut offset)..][..ty.size_and_alignment().size], + ) + }) + .collect() +} + +fn load_variant( + ty: &Type, + mut types: impl ExactSizeIterator, + mem: &Memory, + bytes: &[u8], +) -> Result<(u32, Val)> { + let discriminant_size = DiscriminantSize::from_count(types.len()).unwrap(); + let discriminant = match discriminant_size { + DiscriminantSize::Size1 => u8::load(mem, &bytes[..1])? as u32, + DiscriminantSize::Size2 => u16::load(mem, &bytes[..2])? as u32, + DiscriminantSize::Size4 => u32::load(mem, &bytes[..4])?, + }; + let case_ty = types.nth(discriminant as usize).ok_or_else(|| { + anyhow!( + "discriminant {} out of range [0..{})", + discriminant, + types.len() + ) + })?; + let value = Val::load( + &case_ty, + mem, + &bytes[func::align_to( + usize::from(discriminant_size), + ty.size_and_alignment().alignment, + )..][..case_ty.size_and_alignment().size], + )?; + Ok((discriminant, value)) +} + +fn lift_variant<'a>( + flatten_count: usize, + mut types: impl ExactSizeIterator, + store: &StoreOpaque, + options: &Options, + src: &mut std::slice::Iter<'_, ValRaw>, +) -> Result<(u32, Val)> { + let len = types.len(); + let discriminant = next(src).get_u32(); + let ty = types + .nth(discriminant as usize) + .ok_or_else(|| anyhow!("discriminant {} out of range [0..{})", discriminant, len))?; + let value = Val::lift(&ty, store, options, src)?; + for _ in (1 + ty.flatten_count())..flatten_count { + next(src); + } + Ok((discriminant, value)) +} + +/// Lower a list with the specified element type and values. +fn lower_list( + element_type: &Type, + mem: &mut MemoryMut<'_, T>, + items: &[Val], +) -> Result<(usize, usize)> { + let SizeAndAlignment { + size: element_size, + alignment: element_alignment, + } = element_type.size_and_alignment(); + let size = items + .len() + .checked_mul(element_size) + .ok_or_else(|| anyhow::anyhow!("size overflow copying a list"))?; + let ptr = mem.realloc(0, 0, element_alignment, size)?; + let mut element_ptr = ptr; + for item in items { + item.store(mem, element_ptr)?; + element_ptr += element_size; + } + Ok((ptr, items.len())) +} + +/// Calculate the size of a u32 array needed to represent the specified number of bit flags. +/// +/// Note that this will always return at least 1, even if the `count` parameter is zero. +pub(crate) fn u32_count_for_flag_count(count: usize) -> usize { + match FlagsSize::from_count(count) { + FlagsSize::Size1 | FlagsSize::Size2 => 1, + FlagsSize::Size4Plus(n) => n, + } +} + +fn next<'a>(src: &mut std::slice::Iter<'a, ValRaw>) -> &'a ValRaw { + src.next().unwrap() +} + +fn next_mut<'a>( + dst: &mut std::slice::IterMut<'a, MaybeUninit>, +) -> &'a mut MaybeUninit { + dst.next().unwrap() +} diff --git a/scripts/publish.rs b/scripts/publish.rs index a2b1c5b81243..30e699b3c79b 100644 --- a/scripts/publish.rs +++ b/scripts/publish.rs @@ -41,6 +41,7 @@ const CRATES_TO_PUBLISH: &[&str] = &[ "wiggle-macro", // wasmtime "wasmtime-asm-macros", + "wasmtime-component-util", "wasmtime-component-macro", "wasmtime-jit-debug", "wasmtime-fiber", diff --git a/tests/all/component_model.rs b/tests/all/component_model.rs index 4fee2a480fa3..6c73ddfd35d4 100644 --- a/tests/all/component_model.rs +++ b/tests/all/component_model.rs @@ -1,7 +1,10 @@ use anyhow::Result; +use std::fmt::Write; +use std::iter; use wasmtime::component::{Component, ComponentParams, Lift, Lower, TypedFunc}; use wasmtime::{AsContextMut, Config, Engine}; +mod dynamic; mod func; mod import; mod instance; @@ -148,3 +151,128 @@ fn components_importing_modules() -> Result<()> { Ok(()) } + +#[derive(Copy, Clone, PartialEq, Eq)] +enum Type { + S8, + U8, + S16, + U16, + I32, + I64, + F32, + F64, +} + +impl Type { + fn store(&self) -> &'static str { + match self { + Self::S8 | Self::U8 => "store8", + Self::S16 | Self::U16 => "store16", + Self::I32 | Self::F32 | Self::I64 | Self::F64 => "store", + } + } + + fn primitive(&self) -> &'static str { + match self { + Self::S8 | Self::U8 | Self::S16 | Self::U16 | Self::I32 => "i32", + Self::I64 => "i64", + Self::F32 => "f32", + Self::F64 => "f64", + } + } +} + +#[derive(Copy, Clone, PartialEq, Eq)] +struct Param(Type, Option); + +fn make_echo_component(type_definition: &str, type_size: u32) -> String { + let mut offset = 0; + make_echo_component_with_params( + type_definition, + &iter::repeat(Type::I32) + .map(|ty| { + let param = Param(ty, Some(offset)); + offset += 4; + param + }) + .take(usize::try_from(type_size).unwrap() / 4) + .collect::>(), + ) +} + +fn make_echo_component_with_params(type_definition: &str, params: &[Param]) -> String { + let func = if params.len() == 1 || params.len() > 16 { + let primitive = if params.len() == 1 { + params[0].0.primitive() + } else { + "i32" + }; + + format!( + r#" + (func (export "echo") (param {primitive}) (result {primitive}) + local.get 0 + )"#, + ) + } else { + let mut param_string = String::new(); + let mut store = String::new(); + let mut size = 8; + + for (index, Param(ty, offset)) in params.iter().enumerate() { + let primitive = ty.primitive(); + + write!(&mut param_string, " {primitive}").unwrap(); + if let Some(offset) = offset { + write!( + &mut store, + "({primitive}.{} offset={offset} (local.get $base) (local.get {index}))", + ty.store(), + ) + .unwrap(); + + size = size.max(offset + 8); + } + } + + format!( + r#" + (func (export "echo") (param{param_string}) (result i32) + (local $base i32) + (local.set $base + (call $realloc + (i32.const 0) + (i32.const 0) + (i32.const 4) + (i32.const {size}))) + {store} + local.get $base + )"# + ) + }; + + format!( + r#" + (component + (core module $m + {func} + + (memory (export "memory") 1) + {REALLOC_AND_FREE} + ) + + (core instance $i (instantiate $m)) + + (type $Foo {type_definition}) + + (func (export "echo") (param $Foo) (result $Foo) + (canon lift + (core func $i "echo") + (memory $i "memory") + (realloc (func $i "realloc")) + ) + ) + )"# + ) +} diff --git a/tests/all/component_model/dynamic.rs b/tests/all/component_model/dynamic.rs new file mode 100644 index 000000000000..66ab8c0fe6cf --- /dev/null +++ b/tests/all/component_model/dynamic.rs @@ -0,0 +1,511 @@ +use super::{make_echo_component, make_echo_component_with_params, Param, Type}; +use anyhow::Result; +use wasmtime::component::{self, Component, Func, Linker, Val}; +use wasmtime::{AsContextMut, Store}; + +trait FuncExt { + fn call_and_post_return(&self, store: impl AsContextMut, args: &[Val]) -> Result; +} + +impl FuncExt for Func { + fn call_and_post_return(&self, mut store: impl AsContextMut, args: &[Val]) -> Result { + let result = self.call(&mut store, args)?; + self.post_return(&mut store)?; + Ok(result) + } +} + +#[test] +fn primitives() -> Result<()> { + let engine = super::engine(); + let mut store = Store::new(&engine, ()); + + for (input, ty, param) in [ + (Val::Bool(true), "bool", Param(Type::U8, Some(0))), + (Val::S8(-42), "s8", Param(Type::S8, Some(0))), + (Val::U8(42), "u8", Param(Type::U8, Some(0))), + (Val::S16(-4242), "s16", Param(Type::S16, Some(0))), + (Val::U16(4242), "u16", Param(Type::U16, Some(0))), + (Val::S32(-314159265), "s32", Param(Type::I32, Some(0))), + (Val::U32(314159265), "u32", Param(Type::I32, Some(0))), + (Val::S64(-31415926535897), "s64", Param(Type::I64, Some(0))), + (Val::U64(31415926535897), "u64", Param(Type::I64, Some(0))), + ( + Val::Float32(3.14159265_f32.to_bits()), + "float32", + Param(Type::F32, Some(0)), + ), + ( + Val::Float64(3.14159265_f64.to_bits()), + "float64", + Param(Type::F64, Some(0)), + ), + (Val::Char('🦀'), "char", Param(Type::I32, Some(0))), + ] { + let component = Component::new(&engine, make_echo_component_with_params(ty, &[param]))?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let func = instance.get_func(&mut store, "echo").unwrap(); + let output = func.call_and_post_return(&mut store, &[input.clone()])?; + + assert_eq!(input, output); + } + + // Sad path: type mismatch + + let component = Component::new( + &engine, + make_echo_component_with_params("float64", &[Param(Type::F64, Some(0))]), + )?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let func = instance.get_func(&mut store, "echo").unwrap(); + let err = func + .call_and_post_return(&mut store, &[Val::U64(42)]) + .unwrap_err(); + + assert!(err.to_string().contains("type mismatch"), "{err}"); + + // Sad path: arity mismatch (too many) + + let err = func + .call_and_post_return( + &mut store, + &[ + Val::Float64(3.14159265_f64.to_bits()), + Val::Float64(3.14159265_f64.to_bits()), + ], + ) + .unwrap_err(); + + assert!( + err.to_string().contains("expected 1 argument(s), got 2"), + "{err}" + ); + + // Sad path: arity mismatch (too few) + + let err = func.call_and_post_return(&mut store, &[]).unwrap_err(); + + assert!( + err.to_string().contains("expected 1 argument(s), got 0"), + "{err}" + ); + + Ok(()) +} + +#[test] +fn strings() -> Result<()> { + let engine = super::engine(); + let mut store = Store::new(&engine, ()); + + let component = Component::new(&engine, make_echo_component("string", 8))?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let func = instance.get_func(&mut store, "echo").unwrap(); + let input = Val::String(Box::from("hello, component!")); + let output = func.call_and_post_return(&mut store, &[input.clone()])?; + + assert_eq!(input, output); + + Ok(()) +} + +#[test] +fn lists() -> Result<()> { + let engine = super::engine(); + let mut store = Store::new(&engine, ()); + + let component = Component::new(&engine, make_echo_component("(list u32)", 8))?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let func = instance.get_func(&mut store, "echo").unwrap(); + let ty = &func.params(&store)[0]; + let input = ty.unwrap_list().new_val(Box::new([ + Val::U32(32343), + Val::U32(79023439), + Val::U32(2084037802), + ]))?; + let output = func.call_and_post_return(&mut store, &[input.clone()])?; + + assert_eq!(input, output); + + // Sad path: type mismatch + + let err = ty + .unwrap_list() + .new_val(Box::new([ + Val::U32(32343), + Val::U32(79023439), + Val::Float32(3.14159265_f32.to_bits()), + ])) + .unwrap_err(); + + assert!(err.to_string().contains("type mismatch"), "{err}"); + + Ok(()) +} + +#[test] +fn records() -> Result<()> { + let engine = super::engine(); + let mut store = Store::new(&engine, ()); + + let component = Component::new( + &engine, + make_echo_component_with_params( + r#"(record (field "A" u32) (field "B" float64) (field "C" (record (field "D" bool) (field "E" u32))))"#, + &[ + Param(Type::I32, Some(0)), + Param(Type::F64, Some(8)), + Param(Type::U8, Some(16)), + Param(Type::I32, Some(20)), + ], + ), + )?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let func = instance.get_func(&mut store, "echo").unwrap(); + let ty = &func.params(&store)[0]; + let inner_type = &ty.unwrap_record().fields().nth(2).unwrap().ty; + let input = ty.unwrap_record().new_val([ + ("A", Val::U32(32343)), + ("B", Val::Float64(3.14159265_f64.to_bits())), + ( + "C", + inner_type + .unwrap_record() + .new_val([("D", Val::Bool(false)), ("E", Val::U32(2084037802))])?, + ), + ])?; + let output = func.call_and_post_return(&mut store, &[input.clone()])?; + + assert_eq!(input, output); + + // Sad path: type mismatch + + let err = ty + .unwrap_record() + .new_val([ + ("A", Val::S32(32343)), + ("B", Val::Float64(3.14159265_f64.to_bits())), + ( + "C", + inner_type + .unwrap_record() + .new_val([("D", Val::Bool(false)), ("E", Val::U32(2084037802))])?, + ), + ]) + .unwrap_err(); + + assert!(err.to_string().contains("type mismatch"), "{err}"); + + // Sad path: too many fields + + let err = ty + .unwrap_record() + .new_val([ + ("A", Val::U32(32343)), + ("B", Val::Float64(3.14159265_f64.to_bits())), + ( + "C", + inner_type + .unwrap_record() + .new_val([("D", Val::Bool(false)), ("E", Val::U32(2084037802))])?, + ), + ("F", Val::Unit), + ]) + .unwrap_err(); + + assert!( + err.to_string().contains("expected 3 value(s); got 4"), + "{err}" + ); + + // Sad path: too few fields + + let err = ty + .unwrap_record() + .new_val([ + ("A", Val::U32(32343)), + ("B", Val::Float64(3.14159265_f64.to_bits())), + ]) + .unwrap_err(); + + assert!( + err.to_string().contains("expected 3 value(s); got 2"), + "{err}" + ); + + Ok(()) +} + +#[test] +fn variants() -> Result<()> { + let engine = super::engine(); + let mut store = Store::new(&engine, ()); + + let component = Component::new( + &engine, + make_echo_component_with_params( + r#"(variant (case "A" u32) (case "B" float64) (case "C" (record (field "D" bool) (field "E" u32))))"#, + &[ + Param(Type::U8, Some(0)), + Param(Type::I64, Some(8)), + Param(Type::I32, None), + ], + ), + )?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let func = instance.get_func(&mut store, "echo").unwrap(); + let ty = &func.params(&store)[0]; + let input = ty + .unwrap_variant() + .new_val("B", Val::Float64(3.14159265_f64.to_bits()))?; + let output = func.call_and_post_return(&mut store, &[input.clone()])?; + + assert_eq!(input, output); + + // Do it again, this time using case "C" + + let component = Component::new( + &engine, + dbg!(make_echo_component_with_params( + r#"(variant (case "A" u32) (case "B" float64) (case "C" (record (field "D" bool) (field "E" u32))))"#, + &[ + Param(Type::U8, Some(0)), + Param(Type::I64, Some(8)), + Param(Type::I32, Some(12)), + ], + )), + )?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let func = instance.get_func(&mut store, "echo").unwrap(); + let ty = &func.params(&store)[0]; + let c_type = &ty.unwrap_variant().cases().nth(2).unwrap().ty; + let input = ty.unwrap_variant().new_val( + "C", + c_type + .unwrap_record() + .new_val([("D", Val::Bool(true)), ("E", Val::U32(314159265))])?, + )?; + let output = func.call_and_post_return(&mut store, &[input.clone()])?; + + assert_eq!(input, output); + + // Sad path: type mismatch + + let err = ty + .unwrap_variant() + .new_val("B", Val::U64(314159265)) + .unwrap_err(); + + assert!(err.to_string().contains("type mismatch"), "{err}"); + + // Sad path: unknown case + + let err = ty + .unwrap_variant() + .new_val("D", Val::U64(314159265)) + .unwrap_err(); + + assert!(err.to_string().contains("unknown variant case"), "{err}"); + + // Make sure we lift variants which have cases of different sizes with the correct alignment + + let component = Component::new( + &engine, + make_echo_component_with_params( + r#" + (record + (field "A" (variant + (case "A" u32) + (case "B" float64) + (case "C" (record (field "D" bool) (field "E" u32))))) + (field "B" u32))"#, + &[ + Param(Type::U8, Some(0)), + Param(Type::I64, Some(8)), + Param(Type::I32, None), + Param(Type::I32, Some(16)), + ], + ), + )?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let func = instance.get_func(&mut store, "echo").unwrap(); + let ty = &func.params(&store)[0]; + let a_type = &ty.unwrap_record().fields().nth(0).unwrap().ty; + let input = ty.unwrap_record().new_val([ + ( + "A", + a_type.unwrap_variant().new_val("A", Val::U32(314159265))?, + ), + ("B", Val::U32(628318530)), + ])?; + let output = func.call_and_post_return(&mut store, &[input.clone()])?; + + assert_eq!(input, output); + + Ok(()) +} + +#[test] +fn flags() -> Result<()> { + let engine = super::engine(); + let mut store = Store::new(&engine, ()); + + let component = Component::new( + &engine, + make_echo_component_with_params( + r#"(flags "A" "B" "C" "D" "E")"#, + &[Param(Type::U8, Some(0))], + ), + )?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let func = instance.get_func(&mut store, "echo").unwrap(); + let ty = &func.params(&store)[0]; + let input = ty.unwrap_flags().new_val(&["B", "D"])?; + let output = func.call_and_post_return(&mut store, &[input.clone()])?; + + assert_eq!(input, output); + + // Sad path: unknown flags + + let err = ty.unwrap_flags().new_val(&["B", "D", "F"]).unwrap_err(); + + assert!(err.to_string().contains("unknown flag"), "{err}"); + + Ok(()) +} + +#[test] +fn everything() -> Result<()> { + // This serves to test both nested types and storing parameters on the heap (i.e. exceeding `MAX_STACK_PARAMS`) + + let engine = super::engine(); + let mut store = Store::new(&engine, ()); + + let component = Component::new( + &engine, + make_echo_component_with_params( + r#" + (record + (field "A" u32) + (field "B" (enum "1" "2")) + (field "C" (record (field "D" bool) (field "E" u32))) + (field "F" (list (flags "G" "H" "I"))) + (field "J" (variant + (case "K" u32) + (case "L" float64) + (case "M" (record (field "N" bool) (field "O" u32))))) + (field "P" s8) + (field "Q" s16) + (field "R" s32) + (field "S" s64) + (field "T" float32) + (field "U" float64) + (field "V" string) + (field "W" char) + (field "X" unit) + (field "Y" (tuple u32 u32)) + (field "Z" (union u32 float64)) + (field "AA" (option u32)) + (field "BB" (expected string string)) + )"#, + &[ + Param(Type::I32, Some(0)), + Param(Type::U8, Some(4)), + Param(Type::U8, Some(5)), + Param(Type::I32, Some(8)), + Param(Type::I32, Some(12)), + Param(Type::I32, Some(16)), + Param(Type::U8, Some(20)), + Param(Type::I64, Some(28)), + Param(Type::I32, Some(32)), + Param(Type::S8, Some(36)), + Param(Type::S16, Some(38)), + Param(Type::I32, Some(40)), + Param(Type::I64, Some(48)), + Param(Type::F32, Some(56)), + Param(Type::F64, Some(64)), + Param(Type::I32, Some(72)), + Param(Type::I32, Some(76)), + Param(Type::I32, Some(80)), + Param(Type::I32, Some(84)), + Param(Type::I32, Some(88)), + Param(Type::I64, Some(96)), + Param(Type::U8, Some(104)), + Param(Type::I32, Some(108)), + Param(Type::U8, Some(112)), + Param(Type::I32, Some(116)), + Param(Type::I32, Some(120)), + ], + ), + )?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let func = instance.get_func(&mut store, "echo").unwrap(); + let ty = &func.params(&store)[0]; + let types = ty + .unwrap_record() + .fields() + .map(|field| field.ty) + .collect::>(); + let (b_type, c_type, f_type, j_type, y_type, z_type, aa_type, bb_type) = ( + &types[1], &types[2], &types[3], &types[4], &types[14], &types[15], &types[16], &types[17], + ); + let f_element_type = &f_type.unwrap_list().ty(); + let input = ty.unwrap_record().new_val([ + ("A", Val::U32(32343)), + ("B", b_type.unwrap_enum().new_val("2")?), + ( + "C", + c_type + .unwrap_record() + .new_val([("D", Val::Bool(false)), ("E", Val::U32(2084037802))])?, + ), + ( + "F", + f_type.unwrap_list().new_val(Box::new([f_element_type + .unwrap_flags() + .new_val(&["G", "I"])?]))?, + ), + ( + "J", + j_type + .unwrap_variant() + .new_val("L", Val::Float64(3.14159265_f64.to_bits()))?, + ), + ("P", Val::S8(42)), + ("Q", Val::S16(4242)), + ("R", Val::S32(42424242)), + ("S", Val::S64(424242424242424242)), + ("T", Val::Float32(3.14159265_f32.to_bits())), + ("U", Val::Float64(3.14159265_f64.to_bits())), + ("V", Val::String(Box::from("wow, nice types"))), + ("W", Val::Char('🦀')), + ("X", Val::Unit), + ( + "Y", + y_type + .unwrap_tuple() + .new_val(Box::new([Val::U32(42), Val::U32(24)]))?, + ), + ( + "Z", + z_type + .unwrap_union() + .new_val(1, Val::Float64(3.14159265_f64.to_bits()))?, + ), + ( + "AA", + aa_type.unwrap_option().new_val(Some(Val::U32(314159265)))?, + ), + ( + "BB", + bb_type + .unwrap_expected() + .new_val(Ok(Val::String(Box::from("no problem"))))?, + ), + ])?; + let output = func.call_and_post_return(&mut store, &[input.clone()])?; + + assert_eq!(input, output); + + Ok(()) +} diff --git a/tests/all/component_model/macros.rs b/tests/all/component_model/macros.rs index 73b18f2a6a78..e8dd38baeee2 100644 --- a/tests/all/component_model/macros.rs +++ b/tests/all/component_model/macros.rs @@ -1,75 +1,9 @@ -use super::TypedFuncExt; +use super::{make_echo_component, TypedFuncExt}; use anyhow::Result; use component_macro_test::{add_variants, flags_test}; -use std::fmt::Write; use wasmtime::component::{Component, ComponentType, Lift, Linker, Lower}; use wasmtime::Store; -fn make_echo_component(type_definition: &str, type_size: u32) -> String { - if type_size <= 4 { - format!( - r#" - (component - (core module $m - (func (export "echo") (param i32) (result i32) - local.get 0 - ) - - (memory (export "memory") 1) - ) - - (core instance $i (instantiate $m)) - - {} - - (func (export "echo") (param $Foo) (result $Foo) - (canon lift (core func $i "echo") (memory $i "memory")) - ) - )"#, - type_definition - ) - } else { - let mut params = String::new(); - let mut store = String::new(); - - for index in 0..(type_size / 4) { - params.push_str(" i32"); - write!( - &mut store, - "(i32.store offset={} (local.get $base) (local.get {}))", - index * 4, - index, - ) - .unwrap(); - } - - format!( - r#" - (component - (core module $m - (func (export "echo") (param{}) (result i32) - (local $base i32) - (local.set $base (i32.const 0)) - {} - local.get $base - ) - - (memory (export "memory") 1) - ) - - (core instance $i (instantiate $m)) - - {} - - (func (export "echo") (param $Foo) (result $Foo) - (canon lift (core func $i "echo") (memory $i "memory")) - ) - )"#, - params, store, type_definition - ) - } -} - #[test] fn record_derive() -> Result<()> { #[derive(ComponentType, Lift, Lower, PartialEq, Eq, Debug, Copy, Clone)] @@ -87,10 +21,7 @@ fn record_derive() -> Result<()> { let component = Component::new( &engine, - make_echo_component( - r#"(type $Foo (record (field "foo-bar-baz" s32) (field "b" u32)))"#, - 8, - ), + make_echo_component(r#"(record (field "foo-bar-baz" s32) (field "b" u32))"#, 8), )?; let instance = Linker::new(&engine).instantiate(&mut store, &component)?; @@ -105,7 +36,7 @@ fn record_derive() -> Result<()> { let component = Component::new( &engine, - make_echo_component(r#"(type $Foo (record (field "foo-bar-baz" s32)))"#, 4), + make_echo_component(r#"(record (field "foo-bar-baz" s32))"#, 4), )?; let instance = Linker::new(&engine).instantiate(&mut store, &component)?; @@ -118,7 +49,7 @@ fn record_derive() -> Result<()> { let component = Component::new( &engine, make_echo_component( - r#"(type $Foo (record (field "foo-bar-baz" s32) (field "b" u32) (field "c" u32)))"#, + r#"(record (field "foo-bar-baz" s32) (field "b" u32) (field "c" u32))"#, 12, ), )?; @@ -132,7 +63,7 @@ fn record_derive() -> Result<()> { let component = Component::new( &engine, - make_echo_component(r#"(type $Foo (record (field "a" s32) (field "b" u32)))"#, 8), + make_echo_component(r#"(record (field "a" s32) (field "b" u32))"#, 8), )?; let instance = Linker::new(&engine).instantiate(&mut store, &component)?; @@ -144,10 +75,7 @@ fn record_derive() -> Result<()> { let component = Component::new( &engine, - make_echo_component( - r#"(type $Foo (record (field "foo-bar-baz" s32) (field "b" s32)))"#, - 8, - ), + make_echo_component(r#"(record (field "foo-bar-baz" s32) (field "b" s32))"#, 8), )?; let instance = Linker::new(&engine).instantiate(&mut store, &component)?; @@ -172,10 +100,7 @@ fn record_derive() -> Result<()> { let component = Component::new( &engine, - make_echo_component( - r#"(type $Foo (record (field "foo-bar-baz" s32) (field "b" u32)))"#, - 8, - ), + make_echo_component(r#"(record (field "foo-bar-baz" s32) (field "b" u32))"#, 8), )?; let instance = Linker::new(&engine).instantiate(&mut store, &component)?; @@ -203,10 +128,7 @@ fn union_derive() -> Result<()> { // Happy path: component type matches case count and types - let component = Component::new( - &engine, - make_echo_component(r#"(type $Foo (union s32 u32 s32))"#, 8), - )?; + let component = Component::new(&engine, make_echo_component("(union s32 u32 s32)", 8))?; let instance = Linker::new(&engine).instantiate(&mut store, &component)?; let func = instance.get_typed_func::<(Foo,), Foo, _>(&mut store, "echo")?; @@ -218,10 +140,7 @@ fn union_derive() -> Result<()> { // Sad path: case count mismatch (too few) - let component = Component::new( - &engine, - make_echo_component(r#"(type $Foo (union s32 u32))"#, 8), - )?; + let component = Component::new(&engine, make_echo_component("(union s32 u32)", 8))?; let instance = Linker::new(&engine).instantiate(&mut store, &component)?; assert!(instance @@ -232,7 +151,7 @@ fn union_derive() -> Result<()> { let component = Component::new( &engine, - make_echo_component(r#"(type $Foo (union s32 u32 s32 s32))"#, 8), + make_echo_component(r#"(union s32 u32 s32 s32)"#, 8), )?; let instance = Linker::new(&engine).instantiate(&mut store, &component)?; @@ -246,10 +165,7 @@ fn union_derive() -> Result<()> { // Sad path: case type mismatch - let component = Component::new( - &engine, - make_echo_component(r#"(type $Foo (union s32 s32 s32))"#, 8), - )?; + let component = Component::new(&engine, make_echo_component("(union s32 s32 s32)", 8))?; let instance = Linker::new(&engine).instantiate(&mut store, &component)?; assert!(instance @@ -266,10 +182,7 @@ fn union_derive() -> Result<()> { C(C), } - let component = Component::new( - &engine, - make_echo_component(r#"(type $Foo (union s32 u32 s32))"#, 8), - )?; + let component = Component::new(&engine, make_echo_component("(union s32 u32 s32)", 8))?; let instance = Linker::new(&engine).instantiate(&mut store, &component)?; let func = instance.get_typed_func::<(Generic,), Generic, _>( &mut store, "echo", @@ -307,7 +220,7 @@ fn variant_derive() -> Result<()> { let component = Component::new( &engine, make_echo_component( - r#"(type $Foo (variant (case "foo-bar-baz" s32) (case "B" u32) (case "C" unit)))"#, + r#"(variant (case "foo-bar-baz" s32) (case "B" u32) (case "C" unit))"#, 8, ), )?; @@ -324,10 +237,7 @@ fn variant_derive() -> Result<()> { let component = Component::new( &engine, - make_echo_component( - r#"(type $Foo (variant (case "foo-bar-baz" s32) (case "B" u32)))"#, - 8, - ), + make_echo_component(r#"(variant (case "foo-bar-baz" s32) (case "B" u32))"#, 8), )?; let instance = Linker::new(&engine).instantiate(&mut store, &component)?; @@ -340,7 +250,7 @@ fn variant_derive() -> Result<()> { let component = Component::new( &engine, make_echo_component( - r#"(type $Foo (variant (case "foo-bar-baz" s32) (case "B" u32) (case "C" unit) (case "D" u32)))"#, + r#"(variant (case "foo-bar-baz" s32) (case "B" u32) (case "C" unit) (case "D" u32))"#, 8, ), )?; @@ -355,7 +265,7 @@ fn variant_derive() -> Result<()> { let component = Component::new( &engine, make_echo_component( - r#"(type $Foo (variant (case "A" s32) (case "B" u32) (case "C" unit)))"#, + r#"(variant (case "A" s32) (case "B" u32) (case "C" unit))"#, 8, ), )?; @@ -370,7 +280,7 @@ fn variant_derive() -> Result<()> { let component = Component::new( &engine, make_echo_component( - r#"(type $Foo (variant (case "foo-bar-baz" s32) (case "B" s32) (case "C" unit)))"#, + r#"(variant (case "foo-bar-baz" s32) (case "B" s32) (case "C" unit))"#, 8, ), )?; @@ -394,7 +304,7 @@ fn variant_derive() -> Result<()> { let component = Component::new( &engine, make_echo_component( - r#"(type $Foo (variant (case "foo-bar-baz" s32) (case "B" u32) (case "C" unit)))"#, + r#"(variant (case "foo-bar-baz" s32) (case "B" u32) (case "C" unit))"#, 8, ), )?; @@ -429,7 +339,7 @@ fn enum_derive() -> Result<()> { let component = Component::new( &engine, - make_echo_component(r#"(type $Foo (enum "foo-bar-baz" "B" "C"))"#, 4), + make_echo_component(r#"(enum "foo-bar-baz" "B" "C")"#, 4), )?; let instance = Linker::new(&engine).instantiate(&mut store, &component)?; let func = instance.get_typed_func::<(Foo,), Foo, _>(&mut store, "echo")?; @@ -444,7 +354,7 @@ fn enum_derive() -> Result<()> { let component = Component::new( &engine, - make_echo_component(r#"(type $Foo (enum "foo-bar-baz" "B"))"#, 4), + make_echo_component(r#"(enum "foo-bar-baz" "B")"#, 4), )?; let instance = Linker::new(&engine).instantiate(&mut store, &component)?; @@ -456,7 +366,7 @@ fn enum_derive() -> Result<()> { let component = Component::new( &engine, - make_echo_component(r#"(type $Foo (enum "foo-bar-baz" "B" "C" "D"))"#, 4), + make_echo_component(r#"(enum "foo-bar-baz" "B" "C" "D")"#, 4), )?; let instance = Linker::new(&engine).instantiate(&mut store, &component)?; @@ -466,10 +376,7 @@ fn enum_derive() -> Result<()> { // Sad path: case name mismatch - let component = Component::new( - &engine, - make_echo_component(r#"(type $Foo (enum "A" "B" "C"))"#, 4), - )?; + let component = Component::new(&engine, make_echo_component(r#"(enum "A" "B" "C")"#, 4))?; let instance = Linker::new(&engine).instantiate(&mut store, &component)?; assert!(instance @@ -487,7 +394,7 @@ fn enum_derive() -> Result<()> { &engine, make_echo_component( &format!( - r#"(type $Foo (enum {}))"#, + "(enum {})", (0..257) .map(|index| format!(r#""V{}""#, index)) .collect::>() @@ -542,7 +449,7 @@ fn flags() -> Result<()> { let component = Component::new( &engine, - make_echo_component(r#"(type $Foo (flags "foo-bar-baz" "B" "C"))"#, 4), + make_echo_component(r#"(flags "foo-bar-baz" "B" "C")"#, 4), )?; let instance = Linker::new(&engine).instantiate(&mut store, &component)?; let func = instance.get_typed_func::<(Foo,), Foo, _>(&mut store, "echo")?; @@ -568,7 +475,7 @@ fn flags() -> Result<()> { let component = Component::new( &engine, - make_echo_component(r#"(type $Foo (flags "foo-bar-baz" "B"))"#, 4), + make_echo_component(r#"(flags "foo-bar-baz" "B")"#, 4), )?; let instance = Linker::new(&engine).instantiate(&mut store, &component)?; @@ -580,7 +487,7 @@ fn flags() -> Result<()> { let component = Component::new( &engine, - make_echo_component(r#"(type $Foo (flags "foo-bar-baz" "B" "C" "D"))"#, 4), + make_echo_component(r#"(flags "foo-bar-baz" "B" "C" "D")"#, 4), )?; let instance = Linker::new(&engine).instantiate(&mut store, &component)?; @@ -590,10 +497,7 @@ fn flags() -> Result<()> { // Sad path: flag name mismatch - let component = Component::new( - &engine, - make_echo_component(r#"(type $Foo (flags "A" "B" "C"))"#, 4), - )?; + let component = Component::new(&engine, make_echo_component(r#"(flags "A" "B" "C")"#, 4))?; let instance = Linker::new(&engine).instantiate(&mut store, &component)?; assert!(instance @@ -633,7 +537,7 @@ fn flags() -> Result<()> { &engine, make_echo_component( &format!( - r#"(type $Foo (flags {}))"#, + r#"(flags {})"#, (0..8) .map(|index| format!(r#""F{}""#, index)) .collect::>() @@ -682,7 +586,7 @@ fn flags() -> Result<()> { &engine, make_echo_component( &format!( - r#"(type $Foo (flags {}))"#, + "(flags {})", (0..9) .map(|index| format!(r#""F{}""#, index)) .collect::>() @@ -730,7 +634,7 @@ fn flags() -> Result<()> { &engine, make_echo_component( &format!( - r#"(type $Foo (flags {}))"#, + r#"(flags {})"#, (0..16) .map(|index| format!(r#""F{}""#, index)) .collect::>() @@ -769,7 +673,7 @@ fn flags() -> Result<()> { &engine, make_echo_component( &format!( - r#"(type $Foo (flags {}))"#, + "(flags {})", (0..17) .map(|index| format!(r#""F{}""#, index)) .collect::>() @@ -817,7 +721,7 @@ fn flags() -> Result<()> { &engine, make_echo_component( &format!( - r#"(type $Foo (flags {}))"#, + r#"(flags {})"#, (0..32) .map(|index| format!(r#""F{}""#, index)) .collect::>() @@ -856,7 +760,7 @@ fn flags() -> Result<()> { &engine, make_echo_component( &format!( - r#"(type $Foo (flags {}))"#, + "(flags {})", (0..33) .map(|index| format!(r#""F{}""#, index)) .collect::>() @@ -889,7 +793,7 @@ fn flags() -> Result<()> { &engine, make_echo_component( &format!( - r#"(type $Foo (flags {}))"#, + "(flags {})", (0..65) .map(|index| format!(r#""F{}""#, index)) .collect::>()