Skip to content

Commit

Permalink
Merge pull request #4 from golemcloud/wasmtime-resource-support
Browse files Browse the repository at this point in the history
GOL-207 Support handles in wasmtime encode/decode, and stub generator fixes
  • Loading branch information
vigoo authored Feb 23, 2024
2 parents 73a4fd4 + 8017fcf commit 1779de3
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 37 deletions.
35 changes: 27 additions & 8 deletions wasm-rpc-stubgen/src/rust.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ pub fn generate_stub_source(def: &StubDefinition) -> anyhow::Result<()> {
if interface.global {
None
} else {
Some(&interface.name)
match &interface.owner_interface {
Some(owner) => Some(format!("{owner}/{}", &interface.name)),
None => Some(interface.name.clone()),
}
},
if interface.is_resource() {
FunctionMode::Method
Expand All @@ -102,7 +105,10 @@ pub fn generate_stub_source(def: &StubDefinition) -> anyhow::Result<()> {
if interface.global {
None
} else {
Some(&interface.name)
match &interface.owner_interface {
Some(owner) => Some(format!("{owner}/{}", &interface.name)),
None => Some(interface.name.clone()),
}
},
FunctionMode::Static,
)?);
Expand All @@ -123,7 +129,11 @@ pub fn generate_stub_source(def: &StubDefinition) -> anyhow::Result<()> {
generate_function_stub_source(
def,
&constructor_stub,
Some(&interface.name),
Some(format!(
"{}/{}",
interface.owner_interface.clone().unwrap_or_default(),
&interface.name
)),
FunctionMode::Constructor,
)?
} else {
Expand All @@ -146,7 +156,15 @@ pub fn generate_stub_source(def: &StubDefinition) -> anyhow::Result<()> {
});

if interface.is_resource() {
let remote_function_name = get_remote_function_name(def, "drop", Some(&interface.name));
let remote_function_name = get_remote_function_name(
def,
"drop",
Some(&format!(
"{}/{}",
interface.owner_interface.clone().unwrap_or_default(),
&interface.name
)),
);
interface_impls.push(quote! {
impl Drop for #interface_name {
fn drop(&mut self) {
Expand Down Expand Up @@ -198,7 +216,7 @@ enum FunctionMode {
fn generate_function_stub_source(
def: &StubDefinition,
function: &FunctionStub,
interface_name: Option<&String>,
interface_name: Option<String>,
mode: FunctionMode,
) -> anyhow::Result<TokenStream> {
let function_name = Ident::new(&to_rust_ident(&function.name), Span::call_site());
Expand Down Expand Up @@ -270,7 +288,7 @@ fn generate_function_stub_source(
output_values.push(extract_from_wit_value(
typ,
&def.resolve,
quote! { result },
quote! { result.tuple_element(0).expect("tuple not found") },
)?);
}
FunctionResultStub::Multi(params) => {
Expand All @@ -285,7 +303,7 @@ fn generate_function_stub_source(
FunctionResultStub::SelfType if mode == FunctionMode::Constructor => {
output_values.push(quote! {
{
let (uri, id) = result.handle().expect("handle not found");
let (uri, id) = result.tuple_element(0).expect("tuple not found").handle().expect("handle not found");
Self {
rpc,
id,
Expand All @@ -301,7 +319,8 @@ fn generate_function_stub_source(
}
}

let remote_function_name = get_remote_function_name(def, &function.name, interface_name);
let remote_function_name =
get_remote_function_name(def, &function.name, interface_name.as_ref());

let rpc = match mode {
FunctionMode::Static => {
Expand Down
21 changes: 15 additions & 6 deletions wasm-rpc-stubgen/src/stub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ pub struct InterfaceStub {
pub static_functions: Vec<FunctionStub>,
pub imports: Vec<InterfaceStubImport>,
pub global: bool,
pub owner_interface: Option<String>,
}

impl InterfaceStub {
Expand Down Expand Up @@ -292,16 +293,19 @@ fn collect_stub_interfaces(resolve: &Resolve, world: &World) -> anyhow::Result<V
.filter(|f| f.kind == FunctionKind::Freestanding),
)?;
let imports = collect_stub_imports(interface.types.iter(), resolve)?;
let resource_interfaces =
collect_stub_resources(&name, interface.types.iter(), resolve)?;

interfaces.push(InterfaceStub {
name,
functions,
imports,
global: false,
constructor_params: None,
static_functions: vec![],
owner_interface: None,
});

let resource_interfaces = collect_stub_resources(interface.types.iter(), resolve)?;
interfaces.extend(resource_interfaces);
}
}
Expand All @@ -318,6 +322,7 @@ fn collect_stub_interfaces(resolve: &Resolve, world: &World) -> anyhow::Result<V
global: true,
constructor_params: None,
static_functions: vec![],
owner_interface: None,
});
}

Expand Down Expand Up @@ -361,6 +366,7 @@ fn collect_stub_functions<'a>(
}

fn collect_stub_resources<'a>(
owner_interface: &str,
types: impl Iterator<Item = (&'a String, &'a TypeId)>,
resolve: &'a Resolve,
) -> anyhow::Result<Vec<InterfaceStub>> {
Expand Down Expand Up @@ -422,17 +428,20 @@ fn collect_stub_resources<'a>(
.collect::<Vec<_>>()
});

let resource_name = typ
.name
.as_ref()
.ok_or(anyhow!("Resource type has no name"))?
.clone();

interfaces.push(InterfaceStub {
name: typ
.name
.as_ref()
.ok_or(anyhow!("Resource type has no name"))?
.clone(),
name: resource_name,
functions,
imports,
global: false,
constructor_params,
static_functions,
owner_interface: Some(owner_interface.to_string()),
});
}
TypeOwner::None => {}
Expand Down
104 changes: 81 additions & 23 deletions wasm-rpc/src/wasmtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::Value;
use crate::{Uri, Value};
use wasmtime::component::{
types, Enum, Flags, List, OptionVal, Record, ResultVal, Tuple, Type, Val, Variant,
types, Enum, Flags, List, OptionVal, Record, ResourceAny, ResultVal, Tuple, Type, Val, Variant,
};

pub enum EncodingError {
Expand All @@ -23,8 +23,19 @@ pub enum EncodingError {
Unknown { details: String },
}

pub trait ResourceStore {
fn self_uri(&self) -> Uri;
fn add(&mut self, resource: ResourceAny) -> u64;
fn borrow(&self, resource_id: u64) -> Option<ResourceAny>;
fn remove(&mut self, resource_id: u64) -> Option<ResourceAny>;
}

/// Converts a Value to a wasmtime Val based on the available type information.
pub fn decode_param(param: &Value, param_type: &Type) -> Result<Val, EncodingError> {
pub fn decode_param(
param: &Value,
param_type: &Type,
resource_store: &mut impl ResourceStore,
) -> Result<Val, EncodingError> {
match param_type {
Type::Bool => match param {
Value::Bool(bool) => Ok(Val::Bool(*bool)),
Expand Down Expand Up @@ -82,7 +93,7 @@ pub fn decode_param(param: &Value, param_type: &Type) -> Result<Val, EncodingErr
Value::List(values) => {
let decoded_values = values
.iter()
.map(|v| decode_param(v, &ty.ty()))
.map(|v| decode_param(v, &ty.ty(), resource_store))
.collect::<Result<Vec<Val>, EncodingError>>()?;
let list = List::new(ty, decoded_values.into_boxed_slice())
.expect("Type mismatch in decode_param");
Expand All @@ -96,7 +107,7 @@ pub fn decode_param(param: &Value, param_type: &Type) -> Result<Val, EncodingErr
.iter()
.zip(ty.fields())
.map(|(value, field)| {
let decoded_param = decode_param(value, &field.ty)?;
let decoded_param = decode_param(value, &field.ty, resource_store)?;
Ok((field.name, decoded_param))
})
.collect::<Result<Vec<(&str, Val)>, EncodingError>>()?;
Expand All @@ -110,7 +121,7 @@ pub fn decode_param(param: &Value, param_type: &Type) -> Result<Val, EncodingErr
let tuple_values: Vec<Val> = values
.iter()
.zip(ty.types())
.map(|(value, ty)| decode_param(value, &ty))
.map(|(value, ty)| decode_param(value, &ty, resource_store))
.collect::<Result<Vec<Val>, EncodingError>>()?;
let tuple = Tuple::new(ty, tuple_values.into_boxed_slice())
.expect("Type mismatch in decode_param");
Expand Down Expand Up @@ -138,7 +149,7 @@ pub fn decode_param(param: &Value, param_type: &Type) -> Result<Val, EncodingErr
}?;
let decoded_value = case_value
.as_ref()
.map(|v| decode_param(v, case_ty))
.map(|v| decode_param(v, case_ty, resource_store))
.transpose()?;
let variant =
Variant::new(ty, name, decoded_value).expect("Type mismatch in decode_param");
Expand Down Expand Up @@ -166,7 +177,7 @@ pub fn decode_param(param: &Value, param_type: &Type) -> Result<Val, EncodingErr
Type::Option(ty) => match param {
Value::Option(value) => match value {
Some(value) => {
let decoded_value = decode_param(value, &ty.ty())?;
let decoded_value = decode_param(value, &ty.ty(), resource_store)?;
let option = OptionVal::new(ty, Some(decoded_value))
.expect("Type mismatch in decode_param");
Ok(Val::Option(option))
Expand All @@ -186,7 +197,7 @@ pub fn decode_param(param: &Value, param_type: &Type) -> Result<Val, EncodingErr
})?;
let decoded_value = value
.as_ref()
.map(|v| decode_param(v, &ok_ty))
.map(|v| decode_param(v, &ok_ty, resource_store))
.transpose()?;
let result = ResultVal::new(ty, Ok(decoded_value))
.expect("Type mismatch in decode_param");
Expand All @@ -198,7 +209,7 @@ pub fn decode_param(param: &Value, param_type: &Type) -> Result<Val, EncodingErr
})?;
let decoded_value = value
.as_ref()
.map(|v| decode_param(v, &err_ty))
.map(|v| decode_param(v, &err_ty, resource_store))
.transpose()?;
let result = ResultVal::new(ty, Err(decoded_value))
.expect("Type mismatch in decode_param");
Expand All @@ -221,13 +232,50 @@ pub fn decode_param(param: &Value, param_type: &Type) -> Result<Val, EncodingErr
}
_ => Err(EncodingError::ParamTypeMismatch),
},
Type::Own(_) => Err(EncodingError::ParamTypeMismatch),
Type::Borrow(_) => Err(EncodingError::ParamTypeMismatch),
Type::Own(_) => match param {
Value::Handle { uri, resource_id } => {
if resource_store.self_uri() == *uri {
match resource_store.remove(*resource_id) {
Some(resource) => Ok(Val::Resource(resource)),
None => Err(EncodingError::ValueMismatch {
details: "resource not found".to_string(),
}),
}
} else {
Err(EncodingError::ValueMismatch {
details: "cannot resolve handle belonging to a different worker"
.to_string(),
})
}
}
_ => Err(EncodingError::ParamTypeMismatch),
},
Type::Borrow(_) => match param {
Value::Handle { uri, resource_id } => {
if resource_store.self_uri() == *uri {
match resource_store.borrow(*resource_id) {
Some(resource) => Ok(Val::Resource(resource)),
None => Err(EncodingError::ValueMismatch {
details: "resource not found".to_string(),
}),
}
} else {
Err(EncodingError::ValueMismatch {
details: "cannot resolve handle belonging to a different worker"
.to_string(),
})
}
}
_ => Err(EncodingError::ParamTypeMismatch),
},
}
}

/// Converts a wasmtime Val to a Golem protobuf Val
pub fn encode_output(value: &Val) -> Result<Value, EncodingError> {
pub fn encode_output(
value: &Val,
resource_store: &mut impl ResourceStore,
) -> Result<Value, EncodingError> {
match value {
Val::Bool(bool) => Ok(Value::Bool(*bool)),
Val::S8(i8) => Ok(Value::S8(*i8)),
Expand All @@ -245,22 +293,22 @@ pub fn encode_output(value: &Val) -> Result<Value, EncodingError> {
Val::List(list) => {
let mut encoded_values = Vec::new();
for value in (*list).iter() {
encoded_values.push(encode_output(value)?);
encoded_values.push(encode_output(value, resource_store)?);
}
Ok(Value::List(encoded_values))
}
Val::Record(record) => {
let encoded_values = record
.fields()
.map(|(_, value)| encode_output(value))
.map(|(_, value)| encode_output(value, resource_store))
.collect::<Result<Vec<Value>, EncodingError>>()?;
Ok(Value::Record(encoded_values))
}
Val::Tuple(tuple) => {
let encoded_values = tuple
.values()
.iter()
.map(encode_output)
.map(|v| encode_output(v, resource_store))
.collect::<Result<Vec<Value>, EncodingError>>()?;
Ok(Value::Tuple(encoded_values))
}
Expand All @@ -271,7 +319,9 @@ pub fn encode_output(value: &Val) -> Result<Value, EncodingError> {
discriminant,
value,
} = wasm_variant;
let encoded_output = value.map(|v| encode_output(&v)).transpose()?;
let encoded_output = value
.map(|v| encode_output(&v, resource_store))
.transpose()?;
Ok(Value::Variant {
case_idx: discriminant,
case_value: encoded_output.map(Box::new),
Expand All @@ -287,18 +337,22 @@ pub fn encode_output(value: &Val) -> Result<Value, EncodingError> {
}
Val::Option(option) => match option.value() {
Some(value) => {
let encoded_output = encode_output(value)?;
let encoded_output = encode_output(value, resource_store)?;
Ok(Value::Option(Some(Box::new(encoded_output))))
}
None => Ok(Value::Option(None)),
},
Val::Result(result) => match result.value() {
Ok(value) => {
let encoded_output = value.map(encode_output).transpose()?;
let encoded_output = value
.map(|v| encode_output(v, resource_store))
.transpose()?;
Ok(Value::Result(Ok(encoded_output.map(Box::new))))
}
Err(value) => {
let encoded_output = value.map(encode_output).transpose()?;
let encoded_output = value
.map(|v| encode_output(v, resource_store))
.transpose()?;
Ok(Value::Result(Err(encoded_output.map(Box::new))))
}
},
Expand All @@ -321,9 +375,13 @@ pub fn encode_output(value: &Val) -> Result<Value, EncodingError> {
}
Ok(Value::Flags(encoded_value))
}
Val::Resource(_) => Err(EncodingError::Unknown {
details: "resource values are not supported yet".to_string(),
}),
Val::Resource(resource) => {
let id = resource_store.add(*resource);
Ok(Value::Handle {
uri: resource_store.self_uri(),
resource_id: id,
})
}
}
}

Expand Down

0 comments on commit 1779de3

Please sign in to comment.