Skip to content

Commit a0c46a7

Browse files
authored
Probe for post-return instead of asserting existence (#1159)
* Probe for post-return instead of asserting existence This commit updates `wit-component` to probe for the existence of post-return functions rather than asserting their existence. This enables moves some more ABI bits to `wit-bindgen-core` since they otherwise don't quite make sense here. This additionally fixes a longstanding issue with `wit-component`, although it's not one that's actually come up in practice yet. * Remove no-longer-relevant comment * Fix rebase conflict * Review comments
1 parent 4678a61 commit a0c46a7

File tree

6 files changed

+86
-97
lines changed

6 files changed

+86
-97
lines changed

crates/wit-component/src/dummy.rs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,6 @@ pub fn dummy_module(resolve: &Resolve, world: WorldId) -> Vec<u8> {
113113
push_tys(wat, "param", &sig.params);
114114
push_tys(wat, "result", &sig.results);
115115
wat.push_str(" unreachable)\n");
116-
117-
if resolve.guest_export_needs_post_return(func) {
118-
wat.push_str(&format!("(func (export \"cabi_post_{name}\")"));
119-
push_tys(wat, "param", &sig.results);
120-
wat.push_str(")\n");
121-
}
122116
}
123117

124118
fn push_tys(dst: &mut String, desc: &str, params: &[WasmType]) {

crates/wit-component/src/encoding.rs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@
7373
7474
use crate::encoding::world::WorldAdapter;
7575
use crate::metadata::{self, Bindgen, ModuleMetadata};
76-
use crate::validation::{ValidatedModule, BARE_FUNC_MODULE_NAME, MAIN_MODULE_IMPORT_NAME};
76+
use crate::validation::{
77+
ValidatedModule, BARE_FUNC_MODULE_NAME, MAIN_MODULE_IMPORT_NAME, POST_RETURN_PREFIX,
78+
};
7779
use crate::StringEncoding;
7880
use anyhow::{anyhow, bail, Context, Result};
7981
use indexmap::{IndexMap, IndexSet};
@@ -1106,6 +1108,10 @@ impl<'a> EncodingState<'a> {
11061108
CustomModule::Main => &self.info.encoder.metadata.metadata,
11071109
CustomModule::Adapter(name) => &self.info.encoder.adapters[name].metadata,
11081110
};
1111+
let post_returns = match module {
1112+
CustomModule::Main => &self.info.info.post_returns,
1113+
CustomModule::Adapter(name) => &self.info.adapters[name].info.post_returns,
1114+
};
11091115
let instance_index = match module {
11101116
CustomModule::Main => self.instance_index.expect("instantiated by now"),
11111117
CustomModule::Adapter(name) => self.adapter_instances[name],
@@ -1128,14 +1134,11 @@ impl<'a> EncodingState<'a> {
11281134
.into_iter(encoding, self.memory_index, realloc_index)?
11291135
.collect::<Vec<_>>();
11301136

1131-
// TODO: This should probe for the existence of
1132-
// `cabi_post_{name}` but not require its existence.
1133-
if resolve.guest_export_needs_post_return(func) {
1134-
let post_return = self.component.core_alias_export(
1135-
instance_index,
1136-
&format!("cabi_post_{core_name}"),
1137-
ExportKind::Func,
1138-
);
1137+
let post_return = format!("{POST_RETURN_PREFIX}{core_name}");
1138+
if post_returns.contains(&post_return[..]) {
1139+
let post_return =
1140+
self.component
1141+
.core_alias_export(instance_index, &post_return, ExportKind::Func);
11391142
options.push(CanonicalOption::PostReturn(post_return));
11401143
}
11411144
let func_index = self.component.lift_func(core_func_index, ty, options);

crates/wit-component/src/encoding/world.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,29 +150,32 @@ impl<'a> ComponentWorld<'a> {
150150
/// Returns the set of functions required to be exported from an adapter,
151151
/// either because they're exported from the adapter's world or because
152152
/// they're required as an import to the main module.
153-
fn required_adapter_exports(
153+
fn required_adapter_exports<'r>(
154154
&self,
155-
resolve: &Resolve,
155+
resolve: &'r Resolve,
156156
world: WorldId,
157157
required_exports: &IndexSet<WorldKey>,
158158
required_by_import: Option<&IndexMap<&str, FuncType>>,
159-
) -> IndexMap<String, FuncType> {
159+
) -> IndexMap<String, (FuncType, Option<&'r Function>)> {
160160
use wasmparser::ValType;
161161

162162
let mut required = IndexMap::new();
163163
if let Some(imports) = required_by_import {
164164
for (name, ty) in imports {
165-
required.insert(name.to_string(), ty.clone());
165+
required.insert(name.to_string(), (ty.clone(), None));
166166
}
167167
}
168-
let mut add_func = |func: &Function, name: Option<&str>| {
168+
let mut add_func = |func: &'r Function, name: Option<&str>| {
169169
let name = func.core_export_name(name);
170170
let ty = resolve.wasm_signature(AbiVariant::GuestExport, func);
171171
let prev = required.insert(
172172
name.into_owned(),
173-
wasmparser::FuncType::new(
174-
ty.params.iter().map(to_valty),
175-
ty.results.iter().map(to_valty),
173+
(
174+
wasmparser::FuncType::new(
175+
ty.params.iter().map(to_valty),
176+
ty.results.iter().map(to_valty),
177+
),
178+
Some(func),
176179
),
177180
);
178181
assert!(prev.is_none());

crates/wit-component/src/gc.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ const PAGE_SIZE: i32 = 64 * 1024;
1717
///
1818
/// This internally performs a "gc" pass after removing exports to ensure that
1919
/// the resulting module imports the minimal set of functions necessary.
20-
pub fn run(
20+
pub fn run<T>(
2121
wasm: &[u8],
22-
required: &IndexMap<String, FuncType>,
22+
required: &IndexMap<String, T>,
2323
main_module_realloc: Option<&str>,
2424
) -> Result<Vec<u8>> {
2525
assert!(!required.is_empty());

crates/wit-component/src/validation.rs

Lines changed: 61 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::metadata::{Bindgen, ModuleMetadata};
22
use anyhow::{bail, Context, Result};
33
use indexmap::{map::Entry, IndexMap, IndexSet};
4+
use std::mem;
45
use wasmparser::names::{KebabName, KebabNameKind};
56
use wasmparser::{
67
types::Types, ComponentExternName, Encoding, ExternalKind, FuncType, Parser, Payload, TypeRef,
@@ -44,6 +45,8 @@ pub const RESOURCE_DROP: &str = "[resource-drop]";
4445
pub const RESOURCE_REP: &str = "[resource-rep]";
4546
pub const RESOURCE_NEW: &str = "[resource-new]";
4647

48+
pub const POST_RETURN_PREFIX: &str = "cabi_post_";
49+
4750
/// Metadata about a validated module and what was found internally.
4851
///
4952
/// All imports to the module are described by the union of `required_imports`
@@ -90,6 +93,10 @@ pub struct ValidatedModule<'a> {
9093

9194
/// The original metadata specified for this module.
9295
pub metadata: &'a ModuleMetadata,
96+
97+
/// Post-return functions annotated with `cabi_post_*` in their function
98+
/// name.
99+
pub post_returns: IndexSet<String>,
93100
}
94101

95102
#[derive(Default)]
@@ -135,6 +142,7 @@ pub fn validate_module<'a>(
135142
adapter_realloc: None,
136143
metadata: &metadata.metadata,
137144
required_resource_funcs: Default::default(),
145+
post_returns: Default::default(),
138146
};
139147

140148
for payload in Parser::new(0).parse_all(bytes) {
@@ -181,7 +189,6 @@ pub fn validate_module<'a>(
181189
if export.name == "cabi_realloc_adapter" {
182190
ret.adapter_realloc = Some(export.name);
183191
}
184-
continue;
185192
}
186193

187194
assert!(export_funcs.insert(export.name, export.index).is_none())
@@ -338,6 +345,10 @@ pub struct ValidatedAdapter<'a> {
338345

339346
/// Metadata about the original adapter module.
340347
pub metadata: &'a ModuleMetadata,
348+
349+
/// Post-return functions annotated with `cabi_post_*` in their function
350+
/// name.
351+
pub post_returns: IndexSet<String>,
341352
}
342353

343354
/// This function will validate the `bytes` provided as a wasm adapter module.
@@ -362,7 +373,7 @@ pub fn validate_adapter_module<'a>(
362373
resolve: &'a Resolve,
363374
world: WorldId,
364375
metadata: &'a ModuleMetadata,
365-
required: &IndexMap<String, FuncType>,
376+
required: &IndexMap<String, (FuncType, Option<&Function>)>,
366377
is_library: bool,
367378
) -> Result<ValidatedAdapter<'a>> {
368379
let mut validator = Validator::new();
@@ -377,6 +388,7 @@ pub fn validate_adapter_module<'a>(
377388
import_realloc: None,
378389
export_realloc: None,
379390
metadata,
391+
post_returns: Default::default(),
380392
};
381393

382394
for payload in Parser::new(0).parse_all(bytes) {
@@ -490,26 +502,25 @@ pub fn validate_adapter_module<'a>(
490502
}
491503
}
492504

493-
for (name, ty) in required {
505+
for (name, (ty, func)) in required {
494506
let idx = match export_funcs.get(name.as_str()) {
495507
Some(idx) => *idx,
496508
None => bail!("adapter module did not export `{name}`"),
497509
};
498510
let id = types.function_at(idx);
499511
let actual = types[id].unwrap_func();
500-
if ty == actual {
501-
continue;
512+
validate_func_sig(name, ty, actual)?;
513+
514+
if let Some(func) = func {
515+
let post_return = format!("{POST_RETURN_PREFIX}{name}");
516+
if let Some(idx) = export_funcs.get(post_return.as_str()) {
517+
let id = types.function_at(*idx);
518+
let actual = types[id].unwrap_func();
519+
validate_post_return(resolve, &actual, func)?;
520+
let ok = ret.post_returns.insert(post_return);
521+
assert!(ok);
522+
}
502523
}
503-
bail!(
504-
"adapter module export `{name}` does not match the expected signature:\n\
505-
expected: {:?} -> {:?}\n\
506-
actual: {:?} -> {:?}\n\
507-
",
508-
ty.params(),
509-
ty.results(),
510-
actual.params(),
511-
actual.results(),
512-
);
513524
}
514525

515526
Ok(ret)
@@ -673,6 +684,25 @@ fn validate_func(
673684
)
674685
}
675686

687+
fn validate_post_return(
688+
resolve: &Resolve,
689+
ty: &wasmparser::FuncType,
690+
func: &Function,
691+
) -> Result<()> {
692+
// The expected signature of a post-return function is to take all the
693+
// parameters that are returned by the guest function and then return no
694+
// results. Model this by calculating the signature of `func` and then
695+
// moving its results into the parameters list while emptying out the
696+
// results.
697+
let mut sig = resolve.wasm_signature(AbiVariant::GuestExport, func);
698+
sig.params = mem::take(&mut sig.results);
699+
validate_func_sig(
700+
&format!("{} post-return", func.name),
701+
&wasm_sig_to_func_type(sig),
702+
ty,
703+
)
704+
}
705+
676706
fn validate_func_sig(name: &str, expected: &FuncType, ty: &wasmparser::FuncType) -> Result<()> {
677707
if ty != expected {
678708
bail!(
@@ -696,19 +726,28 @@ fn validate_exported_item<'a>(
696726
types: &Types,
697727
info: &mut ValidatedModule<'a>,
698728
) -> Result<()> {
699-
let validate = |func: &Function, name: Option<&str>| {
729+
let mut validate = |func: &Function, name: Option<&str>| {
700730
let expected_export_name = func.core_export_name(name);
701-
match exports.get(expected_export_name.as_ref()) {
702-
Some(func_index) => {
703-
let id = types.function_at(*func_index);
704-
let ty = types[id].unwrap_func();
705-
validate_func(resolve, ty, func, AbiVariant::GuestExport)
706-
}
731+
let func_index = match exports.get(expected_export_name.as_ref()) {
732+
Some(func_index) => func_index,
707733
None => bail!(
708734
"module does not export required function `{}`",
709735
expected_export_name
710736
),
737+
};
738+
let id = types.function_at(*func_index);
739+
let ty = types[id].unwrap_func();
740+
validate_func(resolve, ty, func, AbiVariant::GuestExport)?;
741+
742+
let post_return = format!("{POST_RETURN_PREFIX}{expected_export_name}");
743+
if let Some(index) = exports.get(&post_return[..]) {
744+
let ok = info.post_returns.insert(post_return);
745+
assert!(ok);
746+
let id = types.function_at(*index);
747+
let ty = types[id].unwrap_func();
748+
validate_post_return(resolve, ty, func)?;
711749
}
750+
Ok(())
712751
};
713752
match item {
714753
WorldItem::Function(func) => validate(func, None)?,

crates/wit-parser/src/abi.rs

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -243,54 +243,4 @@ impl Resolve {
243243
}
244244
}
245245
}
246-
247-
/// Returns whether the `Function` specified needs a post-return function to
248-
/// be generated in guest code.
249-
///
250-
/// This is used when the return value contains a memory allocation such as
251-
/// a list or a string primarily.
252-
pub fn guest_export_needs_post_return(&self, func: &Function) -> bool {
253-
func.results.iter_types().any(|t| self.needs_post_return(t))
254-
}
255-
256-
fn needs_post_return(&self, ty: &Type) -> bool {
257-
match ty {
258-
Type::String => true,
259-
Type::Id(id) => match &self.types[*id].kind {
260-
TypeDefKind::List(_) => true,
261-
TypeDefKind::Type(t) => self.needs_post_return(t),
262-
TypeDefKind::Handle(_) => false,
263-
TypeDefKind::Resource => false,
264-
TypeDefKind::Record(r) => r.fields.iter().any(|f| self.needs_post_return(&f.ty)),
265-
TypeDefKind::Tuple(t) => t.types.iter().any(|t| self.needs_post_return(t)),
266-
TypeDefKind::Union(t) => t.cases.iter().any(|t| self.needs_post_return(&t.ty)),
267-
TypeDefKind::Variant(t) => t
268-
.cases
269-
.iter()
270-
.filter_map(|t| t.ty.as_ref())
271-
.any(|t| self.needs_post_return(t)),
272-
TypeDefKind::Option(t) => self.needs_post_return(t),
273-
TypeDefKind::Result(t) => [&t.ok, &t.err]
274-
.iter()
275-
.filter_map(|t| t.as_ref())
276-
.any(|t| self.needs_post_return(t)),
277-
TypeDefKind::Flags(_) | TypeDefKind::Enum(_) => false,
278-
TypeDefKind::Future(_) | TypeDefKind::Stream(_) => unimplemented!(),
279-
TypeDefKind::Unknown => unreachable!(),
280-
},
281-
282-
Type::Bool
283-
| Type::U8
284-
| Type::S8
285-
| Type::U16
286-
| Type::S16
287-
| Type::U32
288-
| Type::S32
289-
| Type::U64
290-
| Type::S64
291-
| Type::Float32
292-
| Type::Float64
293-
| Type::Char => false,
294-
}
295-
}
296246
}

0 commit comments

Comments
 (0)