diff --git a/crates/mun_abi/src/autogen_impl.rs b/crates/mun_abi/src/autogen_impl.rs index 1db879dba..e63b5b4ca 100644 --- a/crates/mun_abi/src/autogen_impl.rs +++ b/crates/mun_abi/src/autogen_impl.rs @@ -146,6 +146,22 @@ impl StructInfo { unsafe { slice::from_raw_parts(self.field_sizes, self.num_fields as usize) } } } + + /// Returns the index of the field matching the specified `field_name`. + pub fn find_field_index(struct_info: &StructInfo, field_name: &str) -> Result { + struct_info + .field_names() + .enumerate() + .find(|(_, name)| *name == field_name) + .map(|(idx, _)| idx) + .ok_or_else(|| { + format!( + "Struct `{}` does not contain field `{}`.", + struct_info.name(), + field_name + ) + }) + } } impl fmt::Display for StructInfo { diff --git a/crates/mun_runtime/src/macros.rs b/crates/mun_runtime/src/macros.rs index 724d95d83..d286dfa37 100644 --- a/crates/mun_runtime/src/macros.rs +++ b/crates/mun_runtime/src/macros.rs @@ -105,53 +105,39 @@ macro_rules! invoke_fn_impl { if arg_types.len() != num_args { return Err(format!( "Invalid number of arguments. Expected: {}. Found: {}.", - num_args, arg_types.len(), + num_args, )); } #[allow(unused_mut, unused_variables)] let mut idx = 0; $( - if arg_types[idx].guid != $Arg.type_guid() { - return Err(format!( - "Invalid argument type at index {}. Expected: {}. Found: {}.", - idx, - $Arg.type_name(), - arg_types[idx].name(), - )); - } + crate::reflection::equals_argument_type(&arg_types[idx], &$Arg) + .map_err(|(expected, found)| { + format!( + "Invalid argument type at index {}. Expected: {}. Found: {}.", + idx, + expected, + found, + ) + })?; idx += 1; )* if let Some(return_type) = function_info.signature.return_type() { - match return_type.group { - abi::TypeGroup::FundamentalTypes => { - if return_type.guid != Output::type_guid() { - return Err(format!( - "Invalid return type. Expected: {}. Found: {}", - Output::type_name(), - return_type.name(), - )); - } - } - abi::TypeGroup::StructTypes => { - if ::type_guid() != Output::type_guid() { - return Err(format!( - "Invalid return type. Expected: {}. Found: Struct", - Output::type_name(), - )); - } - } - } - + crate::reflection::equals_return_type::(return_type) } else if <() as ReturnTypeReflection>::type_guid() != Output::type_guid() { - return Err(format!( + Err((<() as ReturnTypeReflection>::type_name(), Output::type_name())) + } else { + Ok(()) + }.map_err(|(expected, found)| { + format!( "Invalid return type. Expected: {}. Found: {}", - Output::type_name(), - <() as ReturnTypeReflection>::type_name(), - )); - } + expected, + found, + ) + })?; Ok(function_info) }) { diff --git a/crates/mun_runtime/src/reflection.rs b/crates/mun_runtime/src/reflection.rs index 3df68c571..b5d1a8e25 100644 --- a/crates/mun_runtime/src/reflection.rs +++ b/crates/mun_runtime/src/reflection.rs @@ -1,7 +1,38 @@ -use crate::marshal::MarshalInto; -use abi::Guid; +use crate::{marshal::MarshalInto, Struct}; +use abi::{Guid, TypeInfo}; use md5; +/// Returns whether the specified argument type matches the `type_info`. +pub fn equals_argument_type<'e, 'f, T: ArgumentReflection>( + type_info: &'e TypeInfo, + arg: &'f T, +) -> Result<(), (&'e str, &'f str)> { + if type_info.guid != arg.type_guid() { + Err((type_info.name(), arg.type_name())) + } else { + Ok(()) + } +} + +/// Returns whether the specified return type matches the `type_info`. +pub fn equals_return_type( + type_info: &TypeInfo, +) -> Result<(), (&str, &str)> { + match type_info.group { + abi::TypeGroup::FundamentalTypes => { + if type_info.guid != T::type_guid() { + return Err((type_info.name(), T::type_name())); + } + } + abi::TypeGroup::StructTypes => { + if ::type_guid() != T::type_guid() { + return Err(("struct", T::type_name())); + } + } + } + Ok(()) +} + /// A type to emulate dynamic typing across compilation units for static types. pub trait ReturnTypeReflection: Sized + 'static { /// The resulting type after marshaling. @@ -19,9 +50,9 @@ pub trait ReturnTypeReflection: Sized + 'static { } /// A type to emulate dynamic typing across compilation units for statically typed values. -pub trait ArgumentReflection { +pub trait ArgumentReflection: Sized { /// The resulting type after dereferencing. - type Marshalled: Sized; + type Marshalled: MarshalInto; /// Retrieves the `Guid` of the value's type. fn type_guid(&self) -> Guid { diff --git a/crates/mun_runtime/src/struct.rs b/crates/mun_runtime/src/struct.rs index 1472754f0..33fafb6b2 100644 --- a/crates/mun_runtime/src/struct.rs +++ b/crates/mun_runtime/src/struct.rs @@ -1,6 +1,8 @@ use crate::{ marshal::MarshalInto, - reflection::{ArgumentReflection, ReturnTypeReflection}, + reflection::{ + equals_argument_type, equals_return_type, ArgumentReflection, ReturnTypeReflection, + }, }; use abi::{StructInfo, TypeInfo}; use std::mem; @@ -39,38 +41,33 @@ impl Struct { } /// Retrieves the value of the field corresponding to the specified `field_name`. - pub fn get(&self, field_name: &str) -> Result<&T, String> { - let field_idx = self - .info - .field_names() - .enumerate() - .find(|(_, name)| *name == field_name) - .map(|(idx, _)| idx) - .ok_or_else(|| { - format!( - "Struct `{}` does not contain field `{}`.", - self.info.name(), - field_name - ) - })?; - + pub fn get(&self, field_name: &str) -> Result { + let field_idx = StructInfo::find_field_index(&self.info, field_name)?; let field_type = unsafe { self.info.field_types().get_unchecked(field_idx) }; - if T::type_guid() != field_type.guid { - return Err(format!( + equals_return_type::(&field_type).map_err(|(expected, found)| { + format!( "Mismatched types for `{}::{}`. Expected: `{}`. Found: `{}`.", self.info.name(), field_name, - field_type.name(), - T::type_name() - )); - } + expected, + found, + ) + })?; - unsafe { + let field_value = unsafe { // If we found the `field_idx`, we are guaranteed to also have the `field_offset` let offset = *self.info.field_offsets().get_unchecked(field_idx); // self.ptr is never null - Ok(&*self.raw.0.add(offset as usize).cast::()) - } + // TODO: The unsafe `read` fn could be avoided by adding the `Clone` bound on + // `T::Marshalled`, but its only available on nightly: + // `ReturnTypeReflection` + self.raw + .0 + .add(offset as usize) + .cast::() + .read() + }; + Ok(field_value.marshal_into(Some(*field_type))) } /// Replaces the value of the field corresponding to the specified `field_name` and returns the @@ -78,75 +75,50 @@ impl Struct { pub fn replace( &mut self, field_name: &str, - mut value: T, + value: T, ) -> Result { - let field_idx = self - .info - .field_names() - .enumerate() - .find(|(_, name)| *name == field_name) - .map(|(idx, _)| idx) - .ok_or_else(|| { - format!( - "Struct `{}` does not contain field `{}`.", - self.info.name(), - field_name - ) - })?; - + let field_idx = StructInfo::find_field_index(&self.info, field_name)?; let field_type = unsafe { self.info.field_types().get_unchecked(field_idx) }; - if value.type_guid() != field_type.guid { - return Err(format!( + equals_argument_type(&field_type, &value).map_err(|(expected, found)| { + format!( "Mismatched types for `{}::{}`. Expected: `{}`. Found: `{}`.", self.info.name(), field_name, - field_type.name(), - value.type_name() - )); - } + expected, + found, + ) + })?; + let mut marshalled: T::Marshalled = value.marshal(); let ptr = unsafe { // If we found the `field_idx`, we are guaranteed to also have the `field_offset` let offset = *self.info.field_offsets().get_unchecked(field_idx); // self.ptr is never null - &mut *self.raw.0.add(offset as usize).cast::() + &mut *self.raw.0.add(offset as usize).cast::() }; - mem::swap(&mut value, ptr); - Ok(value) + mem::swap(&mut marshalled, ptr); + Ok(marshalled.marshal_into(Some(*field_type))) } /// Sets the value of the field corresponding to the specified `field_name`. pub fn set(&mut self, field_name: &str, value: T) -> Result<(), String> { - let field_idx = self - .info - .field_names() - .enumerate() - .find(|(_, name)| *name == field_name) - .map(|(idx, _)| idx) - .ok_or_else(|| { - format!( - "Struct `{}` does not contain field `{}`.", - self.info.name(), - field_name - ) - })?; - + let field_idx = StructInfo::find_field_index(&self.info, field_name)?; let field_type = unsafe { self.info.field_types().get_unchecked(field_idx) }; - if value.type_guid() != field_type.guid { - return Err(format!( + equals_argument_type(&field_type, &value).map_err(|(expected, found)| { + format!( "Mismatched types for `{}::{}`. Expected: `{}`. Found: `{}`.", self.info.name(), field_name, - field_type.name(), - value.type_name() - )); - } + expected, + found, + ) + })?; unsafe { // If we found the `field_idx`, we are guaranteed to also have the `field_offset` let offset = *self.info.field_offsets().get_unchecked(field_idx); // self.ptr is never null - *self.raw.0.add(offset as usize).cast::() = value; + *self.raw.0.add(offset as usize).cast::() = value.marshal(); } Ok(()) } diff --git a/crates/mun_runtime/src/test.rs b/crates/mun_runtime/src/test.rs index b6996a108..c1dc65446 100644 --- a/crates/mun_runtime/src/test.rs +++ b/crates/mun_runtime/src/test.rs @@ -387,10 +387,15 @@ fn marshal_struct() { let mut driver = TestDriver::new( r#" struct(gc) Foo { a: int, b: bool, c: float, }; + struct Bar(Foo); fn foo_new(a: int, b: bool, c: float): Foo { Foo { a, b, c, } } + fn bar_new(foo: Foo): Bar { + Bar(foo) + } + fn foo_a(foo: Foo):int { foo.a } fn foo_b(foo: Foo):bool { foo.b } fn foo_c(foo: Foo):float { foo.c } @@ -401,9 +406,9 @@ fn marshal_struct() { let b = true; let c = 1.23f64; let mut foo: Struct = invoke_fn!(driver.runtime, "foo_new", a, b, c).unwrap(); - assert_eq!(Ok(&a), foo.get::("a")); - assert_eq!(Ok(&b), foo.get::("b")); - assert_eq!(Ok(&c), foo.get::("c")); + assert_eq!(Ok(a), foo.get::("a")); + assert_eq!(Ok(b), foo.get::("b")); + assert_eq!(Ok(c), foo.get::("c")); let d = 6i64; let e = false; @@ -412,19 +417,45 @@ fn marshal_struct() { foo.set("b", e).unwrap(); foo.set("c", f).unwrap(); - assert_eq!(Ok(&d), foo.get::("a")); - assert_eq!(Ok(&e), foo.get::("b")); - assert_eq!(Ok(&f), foo.get::("c")); + assert_eq!(Ok(d), foo.get::("a")); + assert_eq!(Ok(e), foo.get::("b")); + assert_eq!(Ok(f), foo.get::("c")); assert_eq!(Ok(d), foo.replace("a", a)); assert_eq!(Ok(e), foo.replace("b", b)); assert_eq!(Ok(f), foo.replace("c", c)); - assert_eq!(Ok(&a), foo.get::("a")); - assert_eq!(Ok(&b), foo.get::("b")); - assert_eq!(Ok(&c), foo.get::("c")); + assert_eq!(Ok(a), foo.get::("a")); + assert_eq!(Ok(b), foo.get::("b")); + assert_eq!(Ok(c), foo.get::("c")); assert_invoke_eq!(i64, a, driver, "foo_a", foo.clone()); assert_invoke_eq!(bool, b, driver, "foo_b", foo.clone()); - assert_invoke_eq!(f64, c, driver, "foo_c", foo); + assert_invoke_eq!(f64, c, driver, "foo_c", foo.clone()); + + let mut bar: Struct = invoke_fn!(driver.runtime, "bar_new", foo.clone()).unwrap(); + let foo2 = bar.get::("0").unwrap(); + assert_eq!(Ok(a), foo2.get::("a")); + assert_eq!(foo2.get::("b"), foo.get::("b")); + assert_eq!(foo2.get::("c"), foo.get::("c")); + + // Specify invalid return type + let bar_err = bar.get::("0"); + assert!(bar_err.is_err()); + + // Specify invalid argument type + let bar_err = bar.replace("0", 1i64); + assert!(bar_err.is_err()); + + // Specify invalid argument type + let bar_err = bar.set("0", 1i64); + assert!(bar_err.is_err()); + + // Specify invalid return type + let bar_err: Result = invoke_fn!(driver.runtime, "bar_new", foo); + assert!(bar_err.is_err()); + + // Pass invalid struct type + let bar_err: Result = invoke_fn!(driver.runtime, "bar_new", bar); + assert!(bar_err.is_err()); } diff --git a/crates/mun_runtime_capi/src/lib.rs b/crates/mun_runtime_capi/src/lib.rs index 1ce1b9a05..7a6f6d09c 100644 --- a/crates/mun_runtime_capi/src/lib.rs +++ b/crates/mun_runtime_capi/src/lib.rs @@ -13,7 +13,7 @@ use std::os::raw::c_char; use crate::error::ErrorHandle; use crate::hub::HUB; use failure::err_msg; -use mun_abi::FunctionInfo; +use mun_abi::{FunctionInfo, StructInfo, TypeInfo}; use mun_runtime::{Runtime, RuntimeBuilder}; pub(crate) type Token = usize;