Skip to content

Commit

Permalink
Do proper type checking for type handles.
Browse files Browse the repository at this point in the history
Instead of relying purely on the assumption that type handles can be compared
cheaply by pointer equality, fallback to a more expensive walk of the
type tree that recursively compares types structurally.

This allows different components to call into each other as long as
their types are structurally equivalent.
  • Loading branch information
rylev committed Sep 19, 2023
1 parent 5359a8c commit 9a7a581
Show file tree
Hide file tree
Showing 2 changed files with 326 additions and 20 deletions.
285 changes: 265 additions & 20 deletions crates/wasmtime/src/component/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
use crate::component::matching::InstanceType;
use crate::component::values::{self, Val};
use anyhow::{anyhow, Result};
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::mem;
use std::ops::Deref;
use std::sync::Arc;
use wasmtime_environ::component::{
CanonicalAbiInfo, ComponentTypes, InterfaceType, ResourceIndex, TypeEnumIndex, TypeFlagsIndex,
TypeListIndex, TypeOptionIndex, TypeRecordIndex, TypeResultIndex, TypeTupleIndex,
TypeVariantIndex,
TypeListIndex, TypeOptionIndex, TypeRecordIndex, TypeResourceTableIndex, TypeResultIndex,
TypeTupleIndex, TypeVariantIndex,
};
use wasmtime_environ::PrimaryMap;

Expand Down Expand Up @@ -56,6 +57,29 @@ impl<T> Handle<T> {
resources: &self.resources,
}
}

fn equivalent<'a>(
&'a self,
other: &'a Self,
type_check: fn(&TypeChecker<'a>, T, T) -> bool,
) -> bool
where
T: PartialEq + Copy,
{
(self.index == other.index
&& Arc::ptr_eq(&self.types, &other.types)
&& Arc::ptr_eq(&self.resources, &other.resources))
|| type_check(
&TypeChecker {
a_types: &self.types,
b_types: &other.types,
a_resource: &self.resources,
b_resource: &self.resources,
},
self.index,
other.index,
)
}
}

impl<T: fmt::Debug> fmt::Debug for Handle<T> {
Expand All @@ -66,23 +90,188 @@ impl<T: fmt::Debug> fmt::Debug for Handle<T> {
}
}

impl<T: PartialEq> PartialEq for Handle<T> {
fn eq(&self, other: &Self) -> bool {
// FIXME: This is an overly-restrictive definition of equality in that it doesn't consider types to be
// equal unless they refer to the same declaration in the same component. It's a good shortcut for the
// common case, but we should also do a recursive structural equality test if the shortcut test fails.
self.index == other.index
&& Arc::ptr_eq(&self.types, &other.types)
&& Arc::ptr_eq(&self.resources, &other.resources)
}
/// Type checker for two
struct TypeChecker<'a> {
a_types: &'a ComponentTypes,
a_resource: &'a PrimaryMap<ResourceIndex, ResourceType>,
b_types: &'a ComponentTypes,
b_resource: &'a PrimaryMap<ResourceIndex, ResourceType>,
}

impl<T: Eq> Eq for Handle<T> {}
impl TypeChecker<'_> {
fn interface_types_equal(&self, a: InterfaceType, b: InterfaceType) -> bool {
match (a, b) {
(InterfaceType::Own(o1), InterfaceType::Own(o2)) => self.resources_equal(o1, o2),
(InterfaceType::Own(_), _) => false,
(InterfaceType::Borrow(b1), InterfaceType::Borrow(b2)) => self.resources_equal(b1, b2),
(InterfaceType::Borrow(_), _) => false,
(InterfaceType::List(l1), InterfaceType::List(l2)) => self.lists_equal(l1, l2),
(InterfaceType::List(_), _) => false,
(InterfaceType::Record(r1), InterfaceType::Record(r2)) => self.records_equal(r1, r2),
(InterfaceType::Record(_), _) => false,
(InterfaceType::Variant(v1), InterfaceType::Variant(v2)) => self.variants_equal(v1, v2),
(InterfaceType::Variant(_), _) => false,
(InterfaceType::Result(r1), InterfaceType::Result(r2)) => self.results_equal(r1, r2),
(InterfaceType::Result(_), _) => false,
(InterfaceType::Option(o1), InterfaceType::Option(o2)) => self.options_equal(o1, o2),
(InterfaceType::Option(_), _) => false,
(InterfaceType::Enum(e1), InterfaceType::Enum(e2)) => self.enums_equal(e1, e2),
(InterfaceType::Enum(_), _) => false,
(InterfaceType::Tuple(t1), InterfaceType::Tuple(t2)) => self.tuples_equal(t1, t2),
(InterfaceType::Tuple(_), _) => false,
(InterfaceType::Flags(f1), InterfaceType::Flags(f2)) => self.flags_equal(f1, f2),
(InterfaceType::Flags(_), _) => false,
(InterfaceType::Bool, InterfaceType::Bool) => true,
(InterfaceType::Bool, _) => false,
(InterfaceType::U8, InterfaceType::U8) => true,
(InterfaceType::U8, _) => false,
(InterfaceType::U16, InterfaceType::U16) => true,
(InterfaceType::U16, _) => false,
(InterfaceType::U32, InterfaceType::U32) => true,
(InterfaceType::U32, _) => false,
(InterfaceType::U64, InterfaceType::U64) => true,
(InterfaceType::U64, _) => false,
(InterfaceType::S8, InterfaceType::S8) => true,
(InterfaceType::S8, _) => false,
(InterfaceType::S16, InterfaceType::S16) => true,
(InterfaceType::S16, _) => false,
(InterfaceType::S32, InterfaceType::S32) => true,
(InterfaceType::S32, _) => false,
(InterfaceType::S64, InterfaceType::S64) => true,
(InterfaceType::S64, _) => false,
(InterfaceType::Float32, InterfaceType::Float32) => true,
(InterfaceType::Float32, _) => false,
(InterfaceType::Float64, InterfaceType::Float64) => true,
(InterfaceType::Float64, _) => false,
(InterfaceType::String, InterfaceType::String) => true,
(InterfaceType::String, _) => false,
(InterfaceType::Char, InterfaceType::Char) => true,
(InterfaceType::Char, _) => false,
}
}

fn lists_equal(&self, l1: TypeListIndex, l2: TypeListIndex) -> bool {
let a = &self.a_types[l1];
let b = &self.b_types[l2];
self.interface_types_equal(a.element, b.element)
}

fn resources_equal(&self, o1: TypeResourceTableIndex, o2: TypeResourceTableIndex) -> bool {
let a = &self.a_types[o1];
let b = &self.b_types[o2];
self.a_resource[a.ty] == self.b_resource[b.ty]
}

fn records_equal(&self, r1: TypeRecordIndex, r2: TypeRecordIndex) -> bool {
let a = &self.a_types[r1];
let b = &self.b_types[r2];
if a.fields.len() != b.fields.len() {
return false;
}
let b_fields = b
.fields
.iter()
.map(|f| (&f.name, &f.ty))
.collect::<HashMap<&String, &InterfaceType>>();
a.fields.iter().all(|a_field| {
let Some(&&b_field_ty) = b_fields.get(&a_field.name) else {
return false;
};
self.interface_types_equal(a_field.ty, b_field_ty)
})
}

fn variants_equal(&self, v1: TypeVariantIndex, v2: TypeVariantIndex) -> bool {
let a = &self.a_types[v1];
let b = &self.b_types[v2];
if a.cases.len() != b.cases.len() {
return false;
}
let b_cases = b
.cases
.iter()
.map(|f| (&f.name, &f.ty))
.collect::<HashMap<&String, &Option<InterfaceType>>>();
a.cases.iter().all(|a_case| {
let Some(&&b_case_ty) = b_cases.get(&a_case.name) else {
return false;
};
match (a_case.ty, b_case_ty) {
(Some(a_case_ty), Some(b_case_ty)) => {
self.interface_types_equal(a_case_ty, b_case_ty)
}
(None, None) => true,
_ => false,
}
})
}

fn results_equal(&self, r1: TypeResultIndex, r2: TypeResultIndex) -> bool {
let a = &self.a_types[r1];
let b = &self.b_types[r2];
let oks = match (a.ok, b.ok) {
(Some(ok1), Some(ok2)) => self.interface_types_equal(ok1, ok2),
(None, None) => true,
_ => false,
};
if !oks {
return false;
}
match (a.err, b.err) {
(Some(err1), Some(err2)) => self.interface_types_equal(err1, err2),
(None, None) => true,
_ => false,
}
}

fn options_equal(&self, o1: TypeOptionIndex, o2: TypeOptionIndex) -> bool {
let a = &self.a_types[o1];
let b = &self.b_types[o2];
self.interface_types_equal(a.ty, b.ty)
}

fn enums_equal(&self, e1: TypeEnumIndex, e2: TypeEnumIndex) -> bool {
let a = &self.a_types[e1];
let b = &self.b_types[e2];
if a.names.len() != b.names.len() {
return false;
}

let b_names = b.names.iter().collect::<HashSet<&String>>();
a.names.iter().all(|a_name| b_names.contains(a_name))
}

fn tuples_equal(&self, t1: TypeTupleIndex, t2: TypeTupleIndex) -> bool {
let a = &self.a_types[t1];
let b = &self.b_types[t2];
if a.types.len() != b.types.len() {
return false;
}
a.types
.iter()
.zip(b.types.iter())
.all(|(&a, &b)| self.interface_types_equal(a, b))
}

fn flags_equal(&self, f1: TypeFlagsIndex, f2: TypeFlagsIndex) -> bool {
let a = &self.a_types[f1];
let b = &self.b_types[f2];
a.names == b.names
}
}

/// A `list` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct List(Handle<TypeListIndex>);

impl PartialEq for List {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::lists_equal)
}
}

impl Eq for List {}

impl List {
/// Instantiate this type with the specified `values`.
pub fn new_val(&self, values: Box<[Val]>) -> Result<Val> {
Expand All @@ -108,7 +297,7 @@ pub struct Field<'a> {
}

/// A `record` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct Record(Handle<TypeRecordIndex>);

impl Record {
Expand All @@ -130,8 +319,16 @@ impl Record {
}
}

impl PartialEq for Record {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::records_equal)
}
}

impl Eq for Record {}

/// A `tuple` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct Tuple(Handle<TypeTupleIndex>);

impl Tuple {
Expand All @@ -153,6 +350,14 @@ impl Tuple {
}
}

impl PartialEq for Tuple {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::tuples_equal)
}
}

impl Eq for Tuple {}

/// A case declaration belonging to a `variant`
pub struct Case<'a> {
/// The name of the case
Expand All @@ -162,7 +367,7 @@ pub struct Case<'a> {
}

/// A `variant` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct Variant(Handle<TypeVariantIndex>);

impl Variant {
Expand All @@ -187,8 +392,16 @@ impl Variant {
}
}

impl PartialEq for Variant {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::variants_equal)
}
}

impl Eq for Variant {}

/// An `enum` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct Enum(Handle<TypeEnumIndex>);

impl Enum {
Expand All @@ -210,8 +423,16 @@ impl Enum {
}
}

impl PartialEq for Enum {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::enums_equal)
}
}

impl Eq for Enum {}

/// An `option` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct OptionType(Handle<TypeOptionIndex>);

impl OptionType {
Expand All @@ -230,8 +451,16 @@ impl OptionType {
}
}

impl PartialEq for OptionType {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::options_equal)
}
}

impl Eq for OptionType {}

/// An `expected` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct ResultType(Handle<TypeResultIndex>);

impl ResultType {
Expand Down Expand Up @@ -261,8 +490,16 @@ impl ResultType {
}
}

impl PartialEq for ResultType {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::results_equal)
}
}

impl Eq for ResultType {}

/// A `flags` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct Flags(Handle<TypeFlagsIndex>);

impl Flags {
Expand All @@ -288,6 +525,14 @@ impl Flags {
}
}

impl PartialEq for Flags {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::flags_equal)
}
}

impl Eq for Flags {}

/// Represents a component model interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[allow(missing_docs)]
Expand Down
Loading

0 comments on commit 9a7a581

Please sign in to comment.