Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do proper type checking for type handles. #7065

Merged
merged 1 commit into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
269 changes: 249 additions & 20 deletions crates/wasmtime/src/component/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use std::ops::Deref;
use std::sync::Arc;
use wasmtime_environ::component::{
CanonicalAbiInfo, ComponentTypes, InterfaceType, ResourceIndex, TypeEnumIndex, TypeFlagsIndex,
TypeListIndex, TypeOptionIndex, TypeRecordIndex, TypeResultIndex, TypeTupleIndex,
TypeVariantIndex,
TypeListIndex, TypeOptionIndex, TypeRecordIndex, TypeResourceTableIndex, TypeResultIndex,
TypeTupleIndex, TypeVariantIndex,
};
use wasmtime_environ::PrimaryMap;

Expand Down Expand Up @@ -56,6 +56,29 @@ impl<T> Handle<T> {
resources: &self.resources,
}
}

fn equivalent<'a>(
&'a self,
other: &'a Self,
type_check: fn(&TypeChecker<'a>, T, T) -> bool,
) -> bool
where
T: PartialEq + Copy,
{
(self.index == other.index
&& Arc::ptr_eq(&self.types, &other.types)
&& Arc::ptr_eq(&self.resources, &other.resources))
|| type_check(
&TypeChecker {
a_types: &self.types,
b_types: &other.types,
a_resource: &self.resources,
b_resource: &other.resources,
},
self.index,
other.index,
)
}
}

impl<T: fmt::Debug> fmt::Debug for Handle<T> {
Expand All @@ -66,23 +89,173 @@ impl<T: fmt::Debug> fmt::Debug for Handle<T> {
}
}

impl<T: PartialEq> PartialEq for Handle<T> {
fn eq(&self, other: &Self) -> bool {
// FIXME: This is an overly-restrictive definition of equality in that it doesn't consider types to be
// equal unless they refer to the same declaration in the same component. It's a good shortcut for the
// common case, but we should also do a recursive structural equality test if the shortcut test fails.
self.index == other.index
&& Arc::ptr_eq(&self.types, &other.types)
&& Arc::ptr_eq(&self.resources, &other.resources)
}
/// Type checker between two `Handle`s
struct TypeChecker<'a> {
a_types: &'a ComponentTypes,
a_resource: &'a PrimaryMap<ResourceIndex, ResourceType>,
b_types: &'a ComponentTypes,
b_resource: &'a PrimaryMap<ResourceIndex, ResourceType>,
}

impl<T: Eq> Eq for Handle<T> {}
impl TypeChecker<'_> {
fn interface_types_equal(&self, a: InterfaceType, b: InterfaceType) -> bool {
match (a, b) {
(InterfaceType::Own(o1), InterfaceType::Own(o2)) => self.resources_equal(o1, o2),
(InterfaceType::Own(_), _) => false,
(InterfaceType::Borrow(b1), InterfaceType::Borrow(b2)) => self.resources_equal(b1, b2),
(InterfaceType::Borrow(_), _) => false,
(InterfaceType::List(l1), InterfaceType::List(l2)) => self.lists_equal(l1, l2),
(InterfaceType::List(_), _) => false,
(InterfaceType::Record(r1), InterfaceType::Record(r2)) => self.records_equal(r1, r2),
(InterfaceType::Record(_), _) => false,
(InterfaceType::Variant(v1), InterfaceType::Variant(v2)) => self.variants_equal(v1, v2),
(InterfaceType::Variant(_), _) => false,
(InterfaceType::Result(r1), InterfaceType::Result(r2)) => self.results_equal(r1, r2),
(InterfaceType::Result(_), _) => false,
(InterfaceType::Option(o1), InterfaceType::Option(o2)) => self.options_equal(o1, o2),
(InterfaceType::Option(_), _) => false,
(InterfaceType::Enum(e1), InterfaceType::Enum(e2)) => self.enums_equal(e1, e2),
(InterfaceType::Enum(_), _) => false,
(InterfaceType::Tuple(t1), InterfaceType::Tuple(t2)) => self.tuples_equal(t1, t2),
(InterfaceType::Tuple(_), _) => false,
(InterfaceType::Flags(f1), InterfaceType::Flags(f2)) => self.flags_equal(f1, f2),
(InterfaceType::Flags(_), _) => false,
(InterfaceType::Bool, InterfaceType::Bool) => true,
(InterfaceType::Bool, _) => false,
(InterfaceType::U8, InterfaceType::U8) => true,
(InterfaceType::U8, _) => false,
(InterfaceType::U16, InterfaceType::U16) => true,
(InterfaceType::U16, _) => false,
(InterfaceType::U32, InterfaceType::U32) => true,
(InterfaceType::U32, _) => false,
(InterfaceType::U64, InterfaceType::U64) => true,
(InterfaceType::U64, _) => false,
(InterfaceType::S8, InterfaceType::S8) => true,
(InterfaceType::S8, _) => false,
(InterfaceType::S16, InterfaceType::S16) => true,
(InterfaceType::S16, _) => false,
(InterfaceType::S32, InterfaceType::S32) => true,
(InterfaceType::S32, _) => false,
(InterfaceType::S64, InterfaceType::S64) => true,
(InterfaceType::S64, _) => false,
(InterfaceType::Float32, InterfaceType::Float32) => true,
(InterfaceType::Float32, _) => false,
(InterfaceType::Float64, InterfaceType::Float64) => true,
(InterfaceType::Float64, _) => false,
(InterfaceType::String, InterfaceType::String) => true,
(InterfaceType::String, _) => false,
(InterfaceType::Char, InterfaceType::Char) => true,
(InterfaceType::Char, _) => false,
}
}

fn lists_equal(&self, l1: TypeListIndex, l2: TypeListIndex) -> bool {
let a = &self.a_types[l1];
let b = &self.b_types[l2];
self.interface_types_equal(a.element, b.element)
}

fn resources_equal(&self, o1: TypeResourceTableIndex, o2: TypeResourceTableIndex) -> bool {
let a = &self.a_types[o1];
let b = &self.b_types[o2];
self.a_resource[a.ty] == self.b_resource[b.ty]
}

fn records_equal(&self, r1: TypeRecordIndex, r2: TypeRecordIndex) -> bool {
let a = &self.a_types[r1];
let b = &self.b_types[r2];
if a.fields.len() != b.fields.len() {
return false;
}
a.fields
.iter()
.zip(b.fields.iter())
.all(|(a_field, b_field)| {
a_field.name == b_field.name && self.interface_types_equal(a_field.ty, b_field.ty)
})
}

fn variants_equal(&self, v1: TypeVariantIndex, v2: TypeVariantIndex) -> bool {
let a = &self.a_types[v1];
let b = &self.b_types[v2];
if a.cases.len() != b.cases.len() {
return false;
}
a.cases.iter().zip(b.cases.iter()).all(|(a_case, b_case)| {
if a_case.name != b_case.name {
return false;
}
match (a_case.ty, b_case.ty) {
(Some(a_case_ty), Some(b_case_ty)) => {
self.interface_types_equal(a_case_ty, b_case_ty)
}
(None, None) => true,
_ => false,
}
})
}

fn results_equal(&self, r1: TypeResultIndex, r2: TypeResultIndex) -> bool {
let a = &self.a_types[r1];
let b = &self.b_types[r2];
let oks = match (a.ok, b.ok) {
(Some(ok1), Some(ok2)) => self.interface_types_equal(ok1, ok2),
(None, None) => true,
_ => false,
};
if !oks {
return false;
}
match (a.err, b.err) {
(Some(err1), Some(err2)) => self.interface_types_equal(err1, err2),
(None, None) => true,
_ => false,
}
}

fn options_equal(&self, o1: TypeOptionIndex, o2: TypeOptionIndex) -> bool {
let a = &self.a_types[o1];
let b = &self.b_types[o2];
self.interface_types_equal(a.ty, b.ty)
}

fn enums_equal(&self, e1: TypeEnumIndex, e2: TypeEnumIndex) -> bool {
let a = &self.a_types[e1];
let b = &self.b_types[e2];
a.names == b.names
}

fn tuples_equal(&self, t1: TypeTupleIndex, t2: TypeTupleIndex) -> bool {
let a = &self.a_types[t1];
let b = &self.b_types[t2];
if a.types.len() != b.types.len() {
return false;
}
a.types
.iter()
.zip(b.types.iter())
.all(|(&a, &b)| self.interface_types_equal(a, b))
}

fn flags_equal(&self, f1: TypeFlagsIndex, f2: TypeFlagsIndex) -> bool {
let a = &self.a_types[f1];
let b = &self.b_types[f2];
a.names == b.names
rylev marked this conversation as resolved.
Show resolved Hide resolved
}
}

/// A `list` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct List(Handle<TypeListIndex>);

impl PartialEq for List {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::lists_equal)
}
}

impl Eq for List {}

impl List {
/// Instantiate this type with the specified `values`.
pub fn new_val(&self, values: Box<[Val]>) -> Result<Val> {
Expand All @@ -108,7 +281,7 @@ pub struct Field<'a> {
}

/// A `record` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct Record(Handle<TypeRecordIndex>);

impl Record {
Expand All @@ -130,8 +303,16 @@ impl Record {
}
}

impl PartialEq for Record {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::records_equal)
}
}

impl Eq for Record {}

/// A `tuple` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct Tuple(Handle<TypeTupleIndex>);

impl Tuple {
Expand All @@ -153,6 +334,14 @@ impl Tuple {
}
}

impl PartialEq for Tuple {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::tuples_equal)
}
}

impl Eq for Tuple {}

/// A case declaration belonging to a `variant`
pub struct Case<'a> {
/// The name of the case
Expand All @@ -162,7 +351,7 @@ pub struct Case<'a> {
}

/// A `variant` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct Variant(Handle<TypeVariantIndex>);

impl Variant {
Expand All @@ -187,8 +376,16 @@ impl Variant {
}
}

impl PartialEq for Variant {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::variants_equal)
}
}

impl Eq for Variant {}

/// An `enum` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct Enum(Handle<TypeEnumIndex>);

impl Enum {
Expand All @@ -210,8 +407,16 @@ impl Enum {
}
}

impl PartialEq for Enum {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::enums_equal)
}
}

impl Eq for Enum {}

/// An `option` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct OptionType(Handle<TypeOptionIndex>);

impl OptionType {
Expand All @@ -230,8 +435,16 @@ impl OptionType {
}
}

impl PartialEq for OptionType {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::options_equal)
}
}

impl Eq for OptionType {}

/// An `expected` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct ResultType(Handle<TypeResultIndex>);

impl ResultType {
Expand Down Expand Up @@ -261,8 +474,16 @@ impl ResultType {
}
}

impl PartialEq for ResultType {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::results_equal)
}
}

impl Eq for ResultType {}

/// A `flags` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Debug)]
pub struct Flags(Handle<TypeFlagsIndex>);

impl Flags {
Expand All @@ -288,6 +509,14 @@ impl Flags {
}
}

impl PartialEq for Flags {
fn eq(&self, other: &Self) -> bool {
self.0.equivalent(&other.0, TypeChecker::flags_equal)
}
}

impl Eq for Flags {}

/// Represents a component model interface type
#[derive(Clone, PartialEq, Eq, Debug)]
#[allow(missing_docs)]
Expand Down
Loading