Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rust SDK: no more reducer args structs #2036

Merged
merged 3 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 149 additions & 53 deletions crates/cli/src/subcommands/generate/rust.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ Requested namespace: {namespace}",
AlgebraicTypeDef::Product(product) => {
gen_and_print_imports(module, out, &product.elements, &[typ.ty]);
out.newline();
define_struct_for_product(module, out, &type_name, &product.elements);
define_struct_for_product(module, out, &type_name, &product.elements, "pub");
}
AlgebraicTypeDef::Sum(sum) => {
gen_and_print_imports(module, out, &sum.variants, &[typ.ty]);
Expand Down Expand Up @@ -130,7 +130,7 @@ Requested namespace: {namespace}",
let table_handle = table_name_pascalcase.clone() + "TableHandle";
let insert_callback_id = table_name_pascalcase.clone() + "InsertCallbackId";
let delete_callback_id = table_name_pascalcase.clone() + "DeleteCallbackId";
let accessor_trait = table_name_pascalcase.clone() + "TableAccess";
let accessor_trait = table_access_trait_name(&table.name);
let accessor_method = table_method_name(&table.name);

write!(
Expand Down Expand Up @@ -353,14 +353,31 @@ Requested namespace: {namespace}",

let reducer_name = reducer.name.deref();
let func_name = reducer_function_name(reducer);
let set_reducer_flags_trait = format!("set_flags_for_{func_name}");
let set_reducer_flags_trait = reducer_flags_trait_name(reducer);
let args_type = reducer_args_type_name(&reducer.name);

define_struct_for_product(module, out, &args_type, &reducer.params_for_generate.elements);
let enum_variant_name = reducer_variant_name(&reducer.name);

// Define an "args struct" for the reducer.
// This is not user-facing (note the `pub(super)` visibility);
// it is an internal helper for serialization and deserialization.
// We actually want to ser/de instances of `enum Reducer`, but:
// - `Reducer` will have struct-like variants, which SATS ser/de does not support.
// - The WS format does not contain a BSATN-serialized `Reducer` instance;
// it holds the reducer name or ID separately from the argument bytes.
// We could work up some magic with `DeserializeSeed`
// and/or custom `Serializer` and `Deserializer` types
// to account for this, but it's much easier to just use an intermediate struct per reducer.
define_struct_for_product(
module,
out,
&args_type,
&reducer.params_for_generate.elements,
"pub(super)",
);

out.newline();

let callback_id = args_type.clone() + "CallbackId";
let callback_id = reducer_callback_id_name(&reducer.name);

// The reducer arguments as `ident: ty, ident: ty, ident: ty,`,
// like an argument list.
Expand All @@ -373,10 +390,6 @@ Requested namespace: {namespace}",
// The reducer argument names as `ident, ident, ident`,
// for passing to function call and struct literal expressions.
let mut arg_names_list = String::new();
// The reducer argument names as `&args.ident, &args.ident, &args.ident`,
// for extracting from a structure named `args` by reference
// and passing to a function call.
let mut unboxed_arg_refs = String::new();
for (arg_ident, arg_ty) in &reducer.params_for_generate.elements[..] {
arg_types_ref_list += "&";
write_type(module, &mut arg_types_ref_list, arg_ty).unwrap();
Expand All @@ -385,12 +398,40 @@ Requested namespace: {namespace}",
let arg_name = arg_ident.deref().to_case(Case::Snake);
arg_names_list += &arg_name;
arg_names_list += ", ";

unboxed_arg_refs += "&args.";
unboxed_arg_refs += &arg_name;
unboxed_arg_refs += ", ";
}

write!(out, "impl From<{args_type}> for super::Reducer ");
out.delimited_block(
"{",
|out| {
write!(out, "fn from(args: {args_type}) -> Self ");
out.delimited_block(
"{",
|out| {
write!(out, "Self::{enum_variant_name}");
if !reducer.params_for_generate.elements.is_empty() {
// We generate "struct variants" for reducers with arguments,
// but "unit variants" for reducers of no arguments.
// These use different constructor syntax.
out.delimited_block(
" {",
|out| {
for (arg_ident, _ty) in &reducer.params_for_generate.elements[..] {
let arg_name = arg_ident.deref().to_case(Case::Snake);
writeln!(out, "{arg_name}: args.{arg_name},");
}
},
"}",
);
}
out.newline();
},
"}\n",
);
},
"}\n",
);

// TODO: check for lifecycle reducers and do not generate the invoke method.

writeln!(
Expand Down Expand Up @@ -437,13 +478,24 @@ impl {func_name} for super::RemoteReducers {{
&self,
mut callback: impl FnMut(&super::EventContext, {arg_types_ref_list}) + Send + 'static,
) -> {callback_id} {{
{callback_id}(self.imp.on_reducer::<{args_type}>(
{callback_id}(self.imp.on_reducer(
{reducer_name:?},
Box::new(move |ctx: &super::EventContext, args: &{args_type}| callback(ctx, {unboxed_arg_refs})),
Box::new(move |ctx: &super::EventContext| {{
let super::EventContext {{
event: __sdk::Event::Reducer(__sdk::ReducerEvent {{
reducer: super::Reducer::{enum_variant_name} {{
{arg_names_list}
}},
..
}}),
..
}} = ctx else {{ unreachable!() }};
callback(ctx, {arg_names_list})
}}),
))
}}
fn remove_on_{func_name}(&self, callback: {callback_id}) {{
self.imp.remove_on_reducer::<{args_type}>({reducer_name:?}, callback.0)
self.imp.remove_on_reducer({reducer_name:?}, callback.0)
}}
}}

Expand Down Expand Up @@ -714,10 +766,11 @@ fn define_struct_for_product(
out: &mut Indenter,
name: &str,
elements: &[(Identifier, AlgebraicTypeUse)],
vis: &str,
) {
print_struct_derives(out);

write!(out, "pub struct {name} ");
write!(out, "{vis} struct {name} ");

// TODO: if elements is empty, define a unit struct with no brace-delimited list of fields.
write_struct_type_fields_in_braces(
Expand All @@ -744,14 +797,22 @@ fn table_method_name(table_name: &Identifier) -> String {
table_name.deref().to_case(Case::Snake)
}

fn table_access_trait_name(table_name: &Identifier) -> String {
table_name.deref().to_case(Case::Pascal) + "TableAccess"
}

fn reducer_args_type_name(reducer_name: &Identifier) -> String {
reducer_name.deref().to_case(Case::Pascal)
reducer_name.deref().to_case(Case::Pascal) + "Args"
}

fn reducer_variant_name(reducer_name: &Identifier) -> String {
reducer_name.deref().to_case(Case::Pascal)
}

fn reducer_callback_id_name(reducer_name: &Identifier) -> String {
reducer_name.deref().to_case(Case::Pascal) + "CallbackId"
}

fn reducer_module_name(reducer_name: &Identifier) -> String {
reducer_name.deref().to_case(Case::Snake) + "_reducer"
}
Expand All @@ -760,6 +821,10 @@ fn reducer_function_name(reducer: &ReducerDef) -> String {
reducer.name.deref().to_case(Case::Snake)
}

fn reducer_flags_trait_name(reducer: &ReducerDef) -> String {
format!("set_flags_for_{}", reducer_function_name(reducer))
}

/// Iterate over all of the Rust `mod`s for types, reducers and tables in the `module`.
fn iter_module_names(module: &ModuleDef) -> impl Iterator<Item = String> + '_ {
itertools::chain!(
Expand All @@ -776,10 +841,31 @@ fn print_module_decls(module: &ModuleDef, out: &mut Indenter) {
}
}

/// Print `pub use *` declarations for all the files that will be generated for `items`.
/// Print appropriate reexports for all the files that will be generated for `items`.
fn print_module_reexports(module: &ModuleDef, out: &mut Indenter) {
for module_name in iter_module_names(module) {
writeln!(out, "pub use {module_name}::*;");
for ty in module.types().sorted_by_key(|ty| &ty.name) {
let mod_name = type_module_name(&ty.name);
let type_name = collect_case(Case::Pascal, ty.name.name_segments());
writeln!(out, "pub use {mod_name}::{type_name};")
}
for table in iter_tables(module) {
let mod_name = table_module_name(&table.name);
// TODO: More precise reexport: we want:
// - The trait name.
// - The insert, delete and possibly update callback ids.
// We do not want:
// - The table handle.
writeln!(out, "pub use {mod_name}::*;");
}
for reducer in iter_reducers(module) {
let mod_name = reducer_module_name(&reducer.name);
let reducer_trait_name = reducer_function_name(reducer);
let flags_trait_name = reducer_flags_trait_name(reducer);
let callback_id_name = reducer_callback_id_name(&reducer.name);
writeln!(
out,
"pub use {mod_name}::{{{reducer_trait_name}, {flags_trait_name}, {callback_id_name}}};"
);
}
}

Expand Down Expand Up @@ -814,7 +900,9 @@ fn iter_unique_cols<'a>(
}

fn print_reducer_enum_defn(module: &ModuleDef, out: &mut Indenter) {
print_enum_derives(out);
// Don't derive ser/de on this enum;
// it's not a proper SATS enum and the derive will fail.
writeln!(out, "#[derive(Clone, PartialEq, Debug)]");
writeln!(
out,
"
Expand All @@ -828,13 +916,15 @@ fn print_reducer_enum_defn(module: &ModuleDef, out: &mut Indenter) {
"pub enum Reducer {",
|out| {
for reducer in iter_reducers(module) {
writeln!(
out,
"{}({}::{}),",
reducer_variant_name(&reducer.name),
reducer_module_name(&reducer.name),
reducer_args_type_name(&reducer.name),
);
write!(out, "{} ", reducer_variant_name(&reducer.name));
if !reducer.params_for_generate.elements.is_empty() {
// If the reducer has any arguments, generate a "struct variant,"
// like `Foo { bar: Baz, }`.
// If it doesn't, generate a "unit variant" instead,
// like `Foo,`.
write_struct_type_fields_in_braces(module, out, &reducer.params_for_generate.elements, false);
}
writeln!(out, ",");
}
},
"}\n",
Expand All @@ -859,27 +949,17 @@ impl __sdk::InModule for Reducer {{
"match self {",
|out| {
for reducer in iter_reducers(module) {
writeln!(
out,
"Reducer::{}(_) => {:?},",
reducer_variant_name(&reducer.name),
reducer.name.deref(),
);
}
},
"}\n",
);
},
"}\n",
);
out.delimited_block(
"fn reducer_args(&self) -> &dyn std::any::Any {",
|out| {
out.delimited_block(
"match self {",
|out| {
for reducer in iter_reducers(module) {
writeln!(out, "Reducer::{}(args) => args,", reducer_variant_name(&reducer.name));
write!(out, "Reducer::{}", reducer_variant_name(&reducer.name));
if !reducer.params_for_generate.elements.is_empty() {
// Because we're emitting unit variants when the payload is empty,
gefjon marked this conversation as resolved.
Show resolved Hide resolved
// we will emit different patterns for empty vs non-empty variants.
// This is not strictly required;
// Rust allows matching a struct-like pattern
// against a unit-like enum variant,
// but we prefer the clarity of not including the braces for unit variants.
write!(out, " {{ .. }}");
}
writeln!(out, " => {:?},", reducer.name.deref());
}
},
"}\n",
Expand All @@ -895,6 +975,21 @@ impl __sdk::InModule for Reducer {{
"impl TryFrom<__ws::ReducerCallInfo<__ws::BsatnFormat>> for Reducer {",
|out| {
writeln!(out, "type Error = __anyhow::Error;");
// We define an "args struct" for each reducer in `generate_reducer`.
// This is not user-facing, and is not exported past the "root" `mod.rs`;
// it is an internal helper for serialization and deserialization.
// We actually want to ser/de instances of `enum Reducer`, but:
//
// - `Reducer` will have struct-like variants, which SATS ser/de does not support.
// - The WS format does not contain a BSATN-serialized `Reducer` instance;
// it holds the reducer name or ID separately from the argument bytes.
// We could work up some magic with `DeserializeSeed`
// and/or custom `Serializer` and `Deserializer` types
// to account for this, but it's much easier to just use an intermediate struct per reducer.
//
// As such, we deserialize from the `value.args` bytes into that "args struct,"
// then convert it into a `Reducer` variant via `Into::into`,
// which we also implement in `generate_reducer`.
out.delimited_block(
"fn try_from(value: __ws::ReducerCallInfo<__ws::BsatnFormat>) -> __anyhow::Result<Self> {",
|out| {
Expand All @@ -904,9 +999,10 @@ impl __sdk::InModule for Reducer {{
for reducer in iter_reducers(module) {
writeln!(
out,
"{:?} => Ok(Reducer::{}(__sdk::parse_reducer_args({:?}, &value.args)?)),",
"{:?} => Ok(__sdk::parse_reducer_args::<{}::{}>({:?}, &value.args)?.into()),",
reducer.name.deref(),
reducer_variant_name(&reducer.name),
reducer_module_name(&reducer.name),
reducer_args_type_name(&reducer.name),
reducer.name.deref(),
);
}
Expand Down
Loading
Loading