Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rust SDK: no more reducer args structs #2036

Merged
merged 3 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 119 additions & 52 deletions crates/cli/src/subcommands/generate/rust.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ Requested namespace: {namespace}",
AlgebraicTypeDef::Product(product) => {
gen_and_print_imports(module, out, &product.elements, &[typ.ty]);
out.newline();
define_struct_for_product(module, out, &type_name, &product.elements);
define_struct_for_product(module, out, &type_name, &product.elements, "pub");
}
AlgebraicTypeDef::Sum(sum) => {
gen_and_print_imports(module, out, &sum.variants, &[typ.ty]);
Expand Down Expand Up @@ -130,7 +130,7 @@ Requested namespace: {namespace}",
let table_handle = table_name_pascalcase.clone() + "TableHandle";
let insert_callback_id = table_name_pascalcase.clone() + "InsertCallbackId";
let delete_callback_id = table_name_pascalcase.clone() + "DeleteCallbackId";
let accessor_trait = table_name_pascalcase.clone() + "TableAccess";
let accessor_trait = table_access_trait_name(&table.name);
let accessor_method = table_method_name(&table.name);

write!(
Expand Down Expand Up @@ -353,14 +353,21 @@ Requested namespace: {namespace}",

let reducer_name = reducer.name.deref();
let func_name = reducer_function_name(reducer);
let set_reducer_flags_trait = format!("set_flags_for_{func_name}");
let set_reducer_flags_trait = reducer_flags_trait_name(reducer);
let args_type = reducer_args_type_name(&reducer.name);
let enum_variant_name = reducer_variant_name(&reducer.name);

define_struct_for_product(module, out, &args_type, &reducer.params_for_generate.elements);
define_struct_for_product(
module,
out,
&args_type,
&reducer.params_for_generate.elements,
"pub(super)",
);

out.newline();

let callback_id = args_type.clone() + "CallbackId";
let callback_id = reducer_callback_id_name(&reducer.name);

// The reducer arguments as `ident: ty, ident: ty, ident: ty,`,
// like an argument list.
Expand All @@ -373,10 +380,6 @@ Requested namespace: {namespace}",
// The reducer argument names as `ident, ident, ident`,
// for passing to function call and struct literal expressions.
let mut arg_names_list = String::new();
// The reducer argument names as `&args.ident, &args.ident, &args.ident`,
// for extracting from a structure named `args` by reference
// and passing to a function call.
let mut unboxed_arg_refs = String::new();
for (arg_ident, arg_ty) in &reducer.params_for_generate.elements[..] {
arg_types_ref_list += "&";
write_type(module, &mut arg_types_ref_list, arg_ty).unwrap();
Expand All @@ -385,12 +388,40 @@ Requested namespace: {namespace}",
let arg_name = arg_ident.deref().to_case(Case::Snake);
arg_names_list += &arg_name;
arg_names_list += ", ";

unboxed_arg_refs += "&args.";
unboxed_arg_refs += &arg_name;
unboxed_arg_refs += ", ";
}

write!(out, "impl From<{args_type}> for super::Reducer ");
out.delimited_block(
"{",
|out| {
write!(out, "fn from(args: {args_type}) -> Self ");
out.delimited_block(
"{",
|out| {
write!(out, "Self::{enum_variant_name}");
if !reducer.params_for_generate.elements.is_empty() {
// We generate "struct variants" for reducers with arguments,
// but "unit variants" for reducers of no arguments.
// These use different constructor syntax.
out.delimited_block(
" {",
|out| {
for (arg_ident, _ty) in &reducer.params_for_generate.elements[..] {
let arg_name = arg_ident.deref().to_case(Case::Snake);
writeln!(out, "{arg_name}: args.{arg_name},");
}
},
"}",
);
}
out.newline();
},
"}\n",
);
},
"}\n",
);

// TODO: check for lifecycle reducers and do not generate the invoke method.

writeln!(
Expand Down Expand Up @@ -437,13 +468,24 @@ impl {func_name} for super::RemoteReducers {{
&self,
mut callback: impl FnMut(&super::EventContext, {arg_types_ref_list}) + Send + 'static,
) -> {callback_id} {{
{callback_id}(self.imp.on_reducer::<{args_type}>(
{callback_id}(self.imp.on_reducer(
{reducer_name:?},
Box::new(move |ctx: &super::EventContext, args: &{args_type}| callback(ctx, {unboxed_arg_refs})),
Box::new(move |ctx: &super::EventContext| {{
let super::EventContext {{
event: __sdk::Event::Reducer(__sdk::ReducerEvent {{
reducer: super::Reducer::{enum_variant_name} {{
{arg_names_list}
}},
..
}}),
..
}} = ctx else {{ unreachable!() }};
callback(ctx, {arg_names_list})
}}),
))
}}
fn remove_on_{func_name}(&self, callback: {callback_id}) {{
self.imp.remove_on_reducer::<{args_type}>({reducer_name:?}, callback.0)
self.imp.remove_on_reducer({reducer_name:?}, callback.0)
}}
}}

Expand Down Expand Up @@ -714,10 +756,11 @@ fn define_struct_for_product(
out: &mut Indenter,
name: &str,
elements: &[(Identifier, AlgebraicTypeUse)],
vis: &str,
) {
print_struct_derives(out);

write!(out, "pub struct {name} ");
write!(out, "{vis} struct {name} ");

// TODO: if elements is empty, define a unit struct with no brace-delimited list of fields.
write_struct_type_fields_in_braces(
Expand All @@ -744,14 +787,22 @@ fn table_method_name(table_name: &Identifier) -> String {
table_name.deref().to_case(Case::Snake)
}

fn table_access_trait_name(table_name: &Identifier) -> String {
table_name.deref().to_case(Case::Pascal) + "TableAccess"
}

fn reducer_args_type_name(reducer_name: &Identifier) -> String {
reducer_name.deref().to_case(Case::Pascal)
reducer_name.deref().to_case(Case::Pascal) + "Args"
}

fn reducer_variant_name(reducer_name: &Identifier) -> String {
reducer_name.deref().to_case(Case::Pascal)
}

fn reducer_callback_id_name(reducer_name: &Identifier) -> String {
reducer_name.deref().to_case(Case::Pascal) + "CallbackId"
}

fn reducer_module_name(reducer_name: &Identifier) -> String {
reducer_name.deref().to_case(Case::Snake) + "_reducer"
}
Expand All @@ -760,6 +811,10 @@ fn reducer_function_name(reducer: &ReducerDef) -> String {
reducer.name.deref().to_case(Case::Snake)
}

fn reducer_flags_trait_name(reducer: &ReducerDef) -> String {
format!("set_flags_for_{}", reducer_function_name(reducer))
}

/// Iterate over all of the Rust `mod`s for types, reducers and tables in the `module`.
fn iter_module_names(module: &ModuleDef) -> impl Iterator<Item = String> + '_ {
itertools::chain!(
Expand All @@ -776,10 +831,31 @@ fn print_module_decls(module: &ModuleDef, out: &mut Indenter) {
}
}

/// Print `pub use *` declarations for all the files that will be generated for `items`.
/// Print appropriate reexports for all the files that will be generated for `items`.
fn print_module_reexports(module: &ModuleDef, out: &mut Indenter) {
for module_name in iter_module_names(module) {
writeln!(out, "pub use {module_name}::*;");
for ty in module.types().sorted_by_key(|ty| &ty.name) {
let mod_name = type_module_name(&ty.name);
let type_name = collect_case(Case::Pascal, ty.name.name_segments());
writeln!(out, "pub use {mod_name}::{type_name};")
}
for table in iter_tables(module) {
let mod_name = table_module_name(&table.name);
// TODO: More precise reexport: we want:
// - The trait name.
// - The insert, delete and possibly update callback ids.
// We do not want:
// - The table handle.
writeln!(out, "pub use {mod_name}::*;");
}
for reducer in iter_reducers(module) {
let mod_name = reducer_module_name(&reducer.name);
let reducer_trait_name = reducer_function_name(reducer);
let flags_trait_name = reducer_flags_trait_name(reducer);
let callback_id_name = reducer_callback_id_name(&reducer.name);
writeln!(
out,
"pub use {mod_name}::{{{reducer_trait_name}, {flags_trait_name}, {callback_id_name}}};"
);
}
}

Expand Down Expand Up @@ -814,7 +890,9 @@ fn iter_unique_cols<'a>(
}

fn print_reducer_enum_defn(module: &ModuleDef, out: &mut Indenter) {
print_enum_derives(out);
// Don't derive ser/de on this enum;
// it's not a proper SATS enum and the derive will fail.
writeln!(out, "#[derive(Clone, PartialEq, Debug)]");
writeln!(
out,
"
Expand All @@ -828,13 +906,15 @@ fn print_reducer_enum_defn(module: &ModuleDef, out: &mut Indenter) {
"pub enum Reducer {",
|out| {
for reducer in iter_reducers(module) {
writeln!(
out,
"{}({}::{}),",
reducer_variant_name(&reducer.name),
reducer_module_name(&reducer.name),
reducer_args_type_name(&reducer.name),
);
write!(out, "{} ", reducer_variant_name(&reducer.name));
if !reducer.params_for_generate.elements.is_empty() {
// If the reducer has any arguments, generate a "struct variant,"
// like `Foo { bar: Baz, }`.
// If it doesn't, generate a "unit variant" instead,
// like `Foo,`.
write_struct_type_fields_in_braces(module, out, &reducer.params_for_generate.elements, false);
}
writeln!(out, ",");
}
},
"}\n",
Expand All @@ -859,27 +939,13 @@ impl __sdk::InModule for Reducer {{
"match self {",
|out| {
for reducer in iter_reducers(module) {
writeln!(
out,
"Reducer::{}(_) => {:?},",
reducer_variant_name(&reducer.name),
reducer.name.deref(),
);
}
},
"}\n",
);
},
"}\n",
);
out.delimited_block(
"fn reducer_args(&self) -> &dyn std::any::Any {",
|out| {
out.delimited_block(
"match self {",
|out| {
for reducer in iter_reducers(module) {
writeln!(out, "Reducer::{}(args) => args,", reducer_variant_name(&reducer.name));
write!(out, "Reducer::{}", reducer_variant_name(&reducer.name));
if !reducer.params_for_generate.elements.is_empty() {
// Because we're emitting unit variants when the payload is empty,
// we have to emit different patterns for empty vs non-empty variants.
write!(out, " {{ .. }}");
}
writeln!(out, " => {:?},", reducer.name.deref());
}
},
"}\n",
Expand All @@ -904,9 +970,10 @@ impl __sdk::InModule for Reducer {{
for reducer in iter_reducers(module) {
writeln!(
out,
"{:?} => Ok(Reducer::{}(__sdk::parse_reducer_args({:?}, &value.args)?)),",
"{:?} => Ok(__sdk::parse_reducer_args::<{}::{}>({:?}, &value.args)?.into()),",
reducer.name.deref(),
reducer_variant_name(&reducer.name),
reducer_module_name(&reducer.name),
reducer_args_type_name(&reducer.name),
reducer.name.deref(),
);
}
Expand Down
Loading
Loading