Skip to content

Commit

Permalink
Do proper type checking for type handles. (#7065)
Browse files Browse the repository at this point in the history
Instead of relying purely on the assumption that type handles can be compared
cheaply by pointer equality, fallback to a more expensive walk of the
type tree that recursively compares types structurally.

This allows different components to call into each other as long as
their types are structurally equivalent.

Signed-off-by: Ryan Levick <ryan.levick@fermyon.com>
  • Loading branch information
rylev authored Sep 21, 2023
1 parent b1511dc commit e69a7f7
Show file tree
Hide file tree
Showing 2 changed files with 310 additions and 20 deletions.
269 changes: 249 additions & 20 deletions crates/wasmtime/src/component/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use std::ops::Deref;
use std::sync::Arc;
use wasmtime_environ::component::{
CanonicalAbiInfo, ComponentTypes, InterfaceType, ResourceIndex, TypeEnumIndex, TypeFlagsIndex,
TypeListIndex, TypeOptionIndex, TypeRecordIndex, TypeResultIndex, TypeTupleIndex,
TypeVariantIndex,
TypeListIndex, TypeOptionIndex, TypeRecordIndex, TypeResourceTableIndex, TypeResultIndex,
TypeTupleIndex, TypeVariantIndex,
};
use wasmtime_environ::PrimaryMap;

Expand Down Expand Up @@ -56,6 +56,29 @@ impl<T> Handle<T> {
resources: &self.resources,
}
}

fn equivalent<'a>(
&'a self,
other: &'a Self,
type_check: fn(&TypeChecker<'a>, T, T) -> bool,
) -> bool
where
T: PartialEq + Copy,
{
(self.index == other.index
&& Arc::ptr_eq(&self.types, &other.types)
&& Arc::ptr_eq(&self.resources, &other.resources))
|| type_check(
&TypeChecker {
a_types: &self.types,
b_types: &other.types,
a_resource: &self.resources,
b_resource: &other.resources,
},
self.index,
other.index,
)
}
}

impl<T: fmt::Debug> fmt::Debug for Handle<T> {
Expand All @@ -66,23 +89,173 @@ impl<T: fmt::Debug> fmt::Debug for Handle<T> {
}
}

impl<T: PartialEq> PartialEq for Handle<T> {
fn eq(&self, other: &Self) -> bool {
// FIXME: This is an overly-restrictive definition of equality in that it doesn't consider types to be
// equal unless they refer to the same declaration in the same component. It's a good shortcut for the
// common case, but we should also do a recursive structural equality test if the shortcut test fails.
self.index == other.index
&& Arc::ptr_eq(&self.types, &other.types)
&& Arc::ptr_eq(&self.resources, &other.resources)
}
/// Type checker between two `Handle`s
struct TypeChecker<'a> {
a_types: &'a ComponentTypes,
a_resource: &'a PrimaryMap<ResourceIndex, ResourceType>,
b_types: &'a ComponentTypes,
b_resource: &'a PrimaryMap<ResourceIndex, ResourceType>,
}

impl<T: Eq> Eq for Handle<T> {}
impl TypeChecker<'_> {
fn interface_types_equal(&self, a: InterfaceType, b: InterfaceType) -> bool {
match (a, b) {
(InterfaceType::Own(o1), InterfaceType::Own(o2)) => self.resources_equal(o1, o2),
(InterfaceType::Own(_), _) => false,
(InterfaceType::Borrow(b1), InterfaceType::Borrow(b2)) => self.resources_equal(b1, b2),
(InterfaceType::Borrow(_), _) => false,
(InterfaceType::List(l1), InterfaceType::List(l2)) => self.lists_equal(l1, l2),
(InterfaceType::List(_), _) => false,
(InterfaceType::Record(r1), InterfaceType::Record(r2)) => self.records_equal(r1, r2),
(InterfaceType::Record(_), _) => false,
(InterfaceType::Variant(v1), InterfaceType::Variant(v2)) => self.variants_equal(v1, v2),
(InterfaceType::Variant(_), _) => false,
(InterfaceType::Result(r1), InterfaceType::Result(r2)) => self.results_equal(r1, r2),
(InterfaceType::Result(_), _) => false,
(InterfaceType::Option(o1), InterfaceType::Option(o2)) => self.options_equal(o1, o2),
(InterfaceType::Option(_), _) => false,
(InterfaceType::Enum(e1), InterfaceType::Enum(e2)) => self.enums_equal(e1, e2),
(InterfaceType::Enum(_), _) => false,
(InterfaceType::Tuple(t1), InterfaceType::Tuple(t2)) => self.tuples_equal(t1, t2),
(InterfaceType::Tuple(_), _) => false,
(InterfaceType::Flags(f1), InterfaceType::Flags(f2)) => self.flags_equal(f1, f2),
(InterfaceType::Flags(_), _) => false,
(InterfaceType::Bool, InterfaceType::Bool) => true,
(InterfaceType::Bool, _) => false,
(InterfaceType::U8, InterfaceType::U8) => true,
(InterfaceType::U8, _) => false,
(InterfaceType::U16, InterfaceType::U16) => true,
(InterfaceType::U16, _) => false,
(InterfaceType::U32, InterfaceType::U32) => true,
(InterfaceType::U32, _) => false,
(InterfaceType::U64, InterfaceType::U64) => true,
(InterfaceType::U64, _) => false,
(InterfaceType::S8, InterfaceType::S8) => true,
(InterfaceType::S8, _) => false,
(InterfaceType::S16, InterfaceType::S16) => true,
(InterfaceType::S16, _) => false,
(InterfaceType::S32, InterfaceType::S32) => true,
(InterfaceType::S32, _) => false,
(InterfaceType::S64, InterfaceType::S64) => true,
(InterfaceType::S64, _) => false,
(InterfaceType::Float32, InterfaceType::Float32) => true,
(InterfaceType::Float32, _) => false,
(InterfaceType::Float64, InterfaceType::Float64) => true,
(InterfaceType::Float64, _) => false,
(InterfaceType::String, InterfaceType::String) => true,
(InterfaceType::String, _) => false,
(InterfaceType::Char, InterfaceType::Char) => true,
(InterfaceType::Char, _) => false,
}
}

fn lists_equal(&self, l1: TypeListIndex, l2: TypeListIndex) -> bool {
let a = &self.a_types[l1];
let b = &self.b_types[l2];
self.interface_types_equal(a.element, b.element)
}

fn resources_equal(&self, o1: TypeResourceTableIndex, o2: TypeResourceTableIndex) -> bool {
let a = &self.a_types[o1];
let b = &self.b_types[o2];
self.a_resource[a.ty] == self.b_resource[b.ty]
}

fn records_equal(&self, r1: TypeRecordIndex, r2: TypeRecordIndex) -> bool {
let a = &self.a_types[r1];
let b = &self.b_types[r2];
if a.fields.len() != b.fields.len() {
return false;
}
a.fields
.iter()
.zip(b.fields.iter())
.all(|(a_field, b_field)| {
a_field.name == b_field.name && self.interface_types_equal(a_field.ty, b_field.ty)
})
}

fn variants_equal(&self, v1: TypeVariantIndex, v2: TypeVariantIndex) -> bool {
let a = &self.a_types[v1];
let b = &self.b_types[v2];
if a.cases.len() != b.cases.len() {
return false;
}
a.cases.iter().zip(b.cases.iter()).all(|(a_case, b_case)| {
if a_case.name != b_case.name {
return false;
}
match (a_case.ty, b_case.ty) {
(Some(a_case_ty), Some(b_case_ty)) => {
self.interface_types_equal(a_case_ty, b_case_ty)
}
(None, None) => true,
_ => false,
}
})
}

fn results_equal(&self, r1: TypeResultIndex, r2: TypeResultIndex) -> bool {
let a = &self.a_types[r1];
let b = &self.b_types[r2];
let oks = match (a.ok, b.ok) {
(Some(ok1), Some(ok2)) => self.interface_types_equal(ok1, ok2),
(None, None) => true,
_ => false,
};
if !oks {
return false;
}
match (a.err, b.err) {
(Some(err1), Some(err2)) => self.interface_types_equal(err1, err2),
(None, None) => true,
_ => false,
}
}

fn options_equal(&self, o1: TypeOptionIndex, o2: TypeOptionIndex) -> bool {
let a = &self.a_types[o1];
let b = &self.b_types[o2];
self.interface_types_equal(a.ty, b.ty)
}

fn enums_equal(&self, e1: TypeEnumIndex, e2: TypeEnumIndex) -> bool {
let a = &self.a_types[e1];
let b = &self.b_types[e2];
a.names == b.names
}

fn tuples_equal(&self, t1: TypeTupleIndex, t2: TypeTupleIndex) -> bool {
let a = &self.a_types[t1];
let b = &self.b_types[t2];
if a.types.len() != b.types.len() {
return false;
}
a.types
.iter()
.zip(b.types.iter())
.all(|(&a, &b)| self.interface_types_equal(a, b))
}

fn flags_equal(&self, f1: TypeFlagsIndex, f2: TypeFlagsIndex) -> bool {
let a = &self.a_types[f1];
let b = &self.b_types[f2];
a.names == b.names
}
}

/// A `list` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct List(Handle<TypeListIndex>);

impl PartialEq for List {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::lists_equal)
}
}

impl Eq for List {}

impl List {
/// Instantiate this type with the specified `values`.
pub fn new_val(&self, values: Box<[Val]>) -> Result<Val> {
Expand All @@ -108,7 +281,7 @@ pub struct Field<'a> {
}

/// A `record` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct Record(Handle<TypeRecordIndex>);

impl Record {
Expand All @@ -130,8 +303,16 @@ impl Record {
}
}

impl PartialEq for Record {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::records_equal)
}
}

impl Eq for Record {}

/// A `tuple` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct Tuple(Handle<TypeTupleIndex>);

impl Tuple {
Expand All @@ -153,6 +334,14 @@ impl Tuple {
}
}

impl PartialEq for Tuple {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::tuples_equal)
}
}

impl Eq for Tuple {}

/// A case declaration belonging to a `variant`
pub struct Case<'a> {
/// The name of the case
Expand All @@ -162,7 +351,7 @@ pub struct Case<'a> {
}

/// A `variant` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct Variant(Handle<TypeVariantIndex>);

impl Variant {
Expand All @@ -187,8 +376,16 @@ impl Variant {
}
}

impl PartialEq for Variant {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::variants_equal)
}
}

impl Eq for Variant {}

/// An `enum` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct Enum(Handle<TypeEnumIndex>);

impl Enum {
Expand All @@ -210,8 +407,16 @@ impl Enum {
}
}

impl PartialEq for Enum {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::enums_equal)
}
}

impl Eq for Enum {}

/// An `option` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct OptionType(Handle<TypeOptionIndex>);

impl OptionType {
Expand All @@ -230,8 +435,16 @@ impl OptionType {
}
}

impl PartialEq for OptionType {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::options_equal)
}
}

impl Eq for OptionType {}

/// An `expected` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct ResultType(Handle<TypeResultIndex>);

impl ResultType {
Expand Down Expand Up @@ -261,8 +474,16 @@ impl ResultType {
}
}

impl PartialEq for ResultType {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::results_equal)
}
}

impl Eq for ResultType {}

/// A `flags` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct Flags(Handle<TypeFlagsIndex>);

impl Flags {
Expand All @@ -288,6 +509,14 @@ impl Flags {
}
}

impl PartialEq for Flags {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::flags_equal)
}
}

impl Eq for Flags {}

/// Represents a component model interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[allow(missing_docs)]
Expand Down
Loading

0 comments on commit e69a7f7

Please sign in to comment.