Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions crates/wit-component/src/dummy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,6 @@ pub fn dummy_module(resolve: &Resolve, world: WorldId) -> Vec<u8> {
push_tys(wat, "param", &sig.params);
push_tys(wat, "result", &sig.results);
wat.push_str(" unreachable)\n");

if resolve.guest_export_needs_post_return(func) {
wat.push_str(&format!("(func (export \"cabi_post_{name}\")"));
push_tys(wat, "param", &sig.results);
wat.push_str(")\n");
}
}

fn push_tys(dst: &mut String, desc: &str, params: &[WasmType]) {
Expand Down
21 changes: 12 additions & 9 deletions crates/wit-component/src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@

use crate::encoding::world::WorldAdapter;
use crate::metadata::{self, Bindgen, ModuleMetadata};
use crate::validation::{ValidatedModule, BARE_FUNC_MODULE_NAME, MAIN_MODULE_IMPORT_NAME};
use crate::validation::{
ValidatedModule, BARE_FUNC_MODULE_NAME, MAIN_MODULE_IMPORT_NAME, POST_RETURN_PREFIX,
};
use crate::StringEncoding;
use anyhow::{anyhow, bail, Context, Result};
use indexmap::{IndexMap, IndexSet};
Expand Down Expand Up @@ -1106,6 +1108,10 @@ impl<'a> EncodingState<'a> {
CustomModule::Main => &self.info.encoder.metadata.metadata,
CustomModule::Adapter(name) => &self.info.encoder.adapters[name].metadata,
};
let post_returns = match module {
CustomModule::Main => &self.info.info.post_returns,
CustomModule::Adapter(name) => &self.info.adapters[name].info.post_returns,
};
let instance_index = match module {
CustomModule::Main => self.instance_index.expect("instantiated by now"),
CustomModule::Adapter(name) => self.adapter_instances[name],
Expand All @@ -1128,14 +1134,11 @@ impl<'a> EncodingState<'a> {
.into_iter(encoding, self.memory_index, realloc_index)?
.collect::<Vec<_>>();

// TODO: This should probe for the existence of
// `cabi_post_{name}` but not require its existence.
if resolve.guest_export_needs_post_return(func) {
let post_return = self.component.core_alias_export(
instance_index,
&format!("cabi_post_{core_name}"),
ExportKind::Func,
);
let post_return = format!("{POST_RETURN_PREFIX}{core_name}");
if post_returns.contains(&post_return[..]) {
let post_return =
self.component
.core_alias_export(instance_index, &post_return, ExportKind::Func);
options.push(CanonicalOption::PostReturn(post_return));
}
let func_index = self.component.lift_func(core_func_index, ty, options);
Expand Down
19 changes: 11 additions & 8 deletions crates/wit-component/src/encoding/world.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,29 +150,32 @@ impl<'a> ComponentWorld<'a> {
/// Returns the set of functions required to be exported from an adapter,
/// either because they're exported from the adapter's world or because
/// they're required as an import to the main module.
fn required_adapter_exports(
fn required_adapter_exports<'r>(
&self,
resolve: &Resolve,
resolve: &'r Resolve,
world: WorldId,
required_exports: &IndexSet<WorldKey>,
required_by_import: Option<&IndexMap<&str, FuncType>>,
) -> IndexMap<String, FuncType> {
) -> IndexMap<String, (FuncType, Option<&'r Function>)> {
use wasmparser::ValType;

let mut required = IndexMap::new();
if let Some(imports) = required_by_import {
for (name, ty) in imports {
required.insert(name.to_string(), ty.clone());
required.insert(name.to_string(), (ty.clone(), None));
}
}
let mut add_func = |func: &Function, name: Option<&str>| {
let mut add_func = |func: &'r Function, name: Option<&str>| {
let name = func.core_export_name(name);
let ty = resolve.wasm_signature(AbiVariant::GuestExport, func);
let prev = required.insert(
name.into_owned(),
wasmparser::FuncType::new(
ty.params.iter().map(to_valty),
ty.results.iter().map(to_valty),
(
wasmparser::FuncType::new(
ty.params.iter().map(to_valty),
ty.results.iter().map(to_valty),
),
Some(func),
),
);
assert!(prev.is_none());
Expand Down
4 changes: 2 additions & 2 deletions crates/wit-component/src/gc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ const PAGE_SIZE: i32 = 64 * 1024;
///
/// This internally performs a "gc" pass after removing exports to ensure that
/// the resulting module imports the minimal set of functions necessary.
pub fn run(
pub fn run<T>(
wasm: &[u8],
required: &IndexMap<String, FuncType>,
required: &IndexMap<String, T>,
main_module_realloc: Option<&str>,
) -> Result<Vec<u8>> {
assert!(!required.is_empty());
Expand Down
83 changes: 61 additions & 22 deletions crates/wit-component/src/validation.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::metadata::{Bindgen, ModuleMetadata};
use anyhow::{bail, Context, Result};
use indexmap::{map::Entry, IndexMap, IndexSet};
use std::mem;
use wasmparser::names::{KebabName, KebabNameKind};
use wasmparser::{
types::Types, ComponentExternName, Encoding, ExternalKind, FuncType, Parser, Payload, TypeRef,
Expand Down Expand Up @@ -44,6 +45,8 @@ pub const RESOURCE_DROP: &str = "[resource-drop]";
pub const RESOURCE_REP: &str = "[resource-rep]";
pub const RESOURCE_NEW: &str = "[resource-new]";

pub const POST_RETURN_PREFIX: &str = "cabi_post_";

/// Metadata about a validated module and what was found internally.
///
/// All imports to the module are described by the union of `required_imports`
Expand Down Expand Up @@ -90,6 +93,10 @@ pub struct ValidatedModule<'a> {

/// The original metadata specified for this module.
pub metadata: &'a ModuleMetadata,

/// Post-return functions annotated with `cabi_post_*` in their function
/// name.
pub post_returns: IndexSet<String>,
}

#[derive(Default)]
Expand Down Expand Up @@ -135,6 +142,7 @@ pub fn validate_module<'a>(
adapter_realloc: None,
metadata: &metadata.metadata,
required_resource_funcs: Default::default(),
post_returns: Default::default(),
};

for payload in Parser::new(0).parse_all(bytes) {
Expand Down Expand Up @@ -181,7 +189,6 @@ pub fn validate_module<'a>(
if export.name == "cabi_realloc_adapter" {
ret.adapter_realloc = Some(export.name);
}
continue;
}

assert!(export_funcs.insert(export.name, export.index).is_none())
Expand Down Expand Up @@ -338,6 +345,10 @@ pub struct ValidatedAdapter<'a> {

/// Metadata about the original adapter module.
pub metadata: &'a ModuleMetadata,

/// Post-return functions annotated with `cabi_post_*` in their function
/// name.
pub post_returns: IndexSet<String>,
}

/// This function will validate the `bytes` provided as a wasm adapter module.
Expand All @@ -362,7 +373,7 @@ pub fn validate_adapter_module<'a>(
resolve: &'a Resolve,
world: WorldId,
metadata: &'a ModuleMetadata,
required: &IndexMap<String, FuncType>,
required: &IndexMap<String, (FuncType, Option<&Function>)>,
is_library: bool,
) -> Result<ValidatedAdapter<'a>> {
let mut validator = Validator::new();
Expand All @@ -377,6 +388,7 @@ pub fn validate_adapter_module<'a>(
import_realloc: None,
export_realloc: None,
metadata,
post_returns: Default::default(),
};

for payload in Parser::new(0).parse_all(bytes) {
Expand Down Expand Up @@ -490,26 +502,25 @@ pub fn validate_adapter_module<'a>(
}
}

for (name, ty) in required {
for (name, (ty, func)) in required {
let idx = match export_funcs.get(name.as_str()) {
Some(idx) => *idx,
None => bail!("adapter module did not export `{name}`"),
};
let id = types.function_at(idx);
let actual = types[id].unwrap_func();
if ty == actual {
continue;
validate_func_sig(name, ty, actual)?;

if let Some(func) = func {
let post_return = format!("{POST_RETURN_PREFIX}{name}");
if let Some(idx) = export_funcs.get(post_return.as_str()) {
let id = types.function_at(*idx);
let actual = types[id].unwrap_func();
validate_post_return(resolve, &actual, func)?;
let ok = ret.post_returns.insert(post_return);
assert!(ok);
}
}
bail!(
"adapter module export `{name}` does not match the expected signature:\n\
expected: {:?} -> {:?}\n\
actual: {:?} -> {:?}\n\
",
ty.params(),
ty.results(),
actual.params(),
actual.results(),
);
}

Ok(ret)
Expand Down Expand Up @@ -673,6 +684,25 @@ fn validate_func(
)
}

fn validate_post_return(
resolve: &Resolve,
ty: &wasmparser::FuncType,
func: &Function,
) -> Result<()> {
// The expected signature of a post-return function is to take all the
// parameters that are returned by the guest function and then return no
// results. Model this by calculating the signature of `func` and then
// moving its results into the parameters list while emptying out the
// results.
let mut sig = resolve.wasm_signature(AbiVariant::GuestExport, func);
sig.params = mem::take(&mut sig.results);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A comment here might be helpful -- took me a minute to realize this was because the parameters to a post-return function should match the results of the primary function.

validate_func_sig(
&format!("{} post-return", func.name),
&wasm_sig_to_func_type(sig),
ty,
)
}

fn validate_func_sig(name: &str, expected: &FuncType, ty: &wasmparser::FuncType) -> Result<()> {
if ty != expected {
bail!(
Expand All @@ -696,19 +726,28 @@ fn validate_exported_item<'a>(
types: &Types,
info: &mut ValidatedModule<'a>,
) -> Result<()> {
let validate = |func: &Function, name: Option<&str>| {
let mut validate = |func: &Function, name: Option<&str>| {
let expected_export_name = func.core_export_name(name);
match exports.get(expected_export_name.as_ref()) {
Some(func_index) => {
let id = types.function_at(*func_index);
let ty = types[id].unwrap_func();
validate_func(resolve, ty, func, AbiVariant::GuestExport)
}
let func_index = match exports.get(expected_export_name.as_ref()) {
Some(func_index) => func_index,
None => bail!(
"module does not export required function `{}`",
expected_export_name
),
};
let id = types.function_at(*func_index);
let ty = types[id].unwrap_func();
validate_func(resolve, ty, func, AbiVariant::GuestExport)?;

let post_return = format!("{POST_RETURN_PREFIX}{expected_export_name}");
if let Some(index) = exports.get(&post_return[..]) {
let ok = info.post_returns.insert(post_return);
assert!(ok);
let id = types.function_at(*index);
let ty = types[id].unwrap_func();
validate_post_return(resolve, ty, func)?;
}
Ok(())
};
match item {
WorldItem::Function(func) => validate(func, None)?,
Expand Down
50 changes: 0 additions & 50 deletions crates/wit-parser/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,54 +243,4 @@ impl Resolve {
}
}
}

/// Returns whether the `Function` specified needs a post-return function to
/// be generated in guest code.
///
/// This is used when the return value contains a memory allocation such as
/// a list or a string primarily.
pub fn guest_export_needs_post_return(&self, func: &Function) -> bool {
func.results.iter_types().any(|t| self.needs_post_return(t))
}

fn needs_post_return(&self, ty: &Type) -> bool {
match ty {
Type::String => true,
Type::Id(id) => match &self.types[*id].kind {
TypeDefKind::List(_) => true,
TypeDefKind::Type(t) => self.needs_post_return(t),
TypeDefKind::Handle(_) => false,
TypeDefKind::Resource => false,
TypeDefKind::Record(r) => r.fields.iter().any(|f| self.needs_post_return(&f.ty)),
TypeDefKind::Tuple(t) => t.types.iter().any(|t| self.needs_post_return(t)),
TypeDefKind::Union(t) => t.cases.iter().any(|t| self.needs_post_return(&t.ty)),
TypeDefKind::Variant(t) => t
.cases
.iter()
.filter_map(|t| t.ty.as_ref())
.any(|t| self.needs_post_return(t)),
TypeDefKind::Option(t) => self.needs_post_return(t),
TypeDefKind::Result(t) => [&t.ok, &t.err]
.iter()
.filter_map(|t| t.as_ref())
.any(|t| self.needs_post_return(t)),
TypeDefKind::Flags(_) | TypeDefKind::Enum(_) => false,
TypeDefKind::Future(_) | TypeDefKind::Stream(_) => unimplemented!(),
TypeDefKind::Unknown => unreachable!(),
},

Type::Bool
| Type::U8
| Type::S8
| Type::U16
| Type::S16
| Type::U32
| Type::S32
| Type::U64
| Type::S64
| Type::Float32
| Type::Float64
| Type::Char => false,
}
}
}