Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add enum module support #1337

Merged
merged 7 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions crates/burn-core/tests/derive_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,23 @@ struct ModuleWithGenericModule<B: Backend, M> {
_backend: PhantomData<B>,
}

#[derive(Module, Debug)]
enum ModuleEnum<B: Backend> {
Basic(ModuleBasic<B>),
Composed(ModuleComposed<B>),
}

#[derive(Module, Debug)]
enum ModuleEnumNested<B: Backend> {
AnotherEnum(ModuleEnum<B>),
}
Comment on lines +43 to +46
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this supported?

#[derive(Module, Debug)]
enum ModuleEnumNamed<B: Backend> {
    Variant {
       fc1: nn::Linear<B>,
       fc2: nn::Linear<B>,
    },
}

If not maybe we can open an issue and support this in another PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not yet! Your follow-up comment was correct. I only added support for unnamed fields/variants. If you try you should get a compile-time error (I added a panic to check for that specifically).


#[derive(Module, Debug)]
enum ModuleEnumWithGenericModule<B: Backend, M: Module<B>> {
Basic(ModuleBasic<B>),
Generic(ModuleWithGenericModule<B, M>),
}

#[derive(Module, Debug)]
pub struct ModuleComposed<B: Backend> {
weight: Param<Tensor<B, 2>>,
Expand Down Expand Up @@ -95,6 +112,46 @@ mod state {
module_2.basic.weight_basic.to_data()
);
}

#[test]
fn should_load_from_record_enum() {
let device = <TestBackend as Backend>::Device::default();
let module_1 = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));
let mut module_2 = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));
let state_1 = module_1.clone().into_record();

let ModuleEnum::Basic(module_1_basic) = module_1 else {
panic!("Invalid module type")
};
let ModuleEnum::Basic(module_2_basic) = module_2.clone() else {
panic!("Invalid module type")
};
assert_ne!(
module_1_basic.weight_basic.to_data(),
module_2_basic.weight_basic.to_data()
);

module_2 = module_2.load_record(state_1);

let ModuleEnum::Basic(module_2_basic) = module_2 else {
panic!("Invalid module type")
};
assert_eq!(
module_1_basic.weight_basic.to_data(),
module_2_basic.weight_basic.to_data()
);
}

#[test]
#[should_panic(expected = "Can't parse record from a different variant")]
fn should_panic_load_from_incorrect_enum_variant() {
let device = <TestBackend as Backend>::Device::default();
let module_1 = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));
let module_2 = ModuleEnum::Composed(ModuleComposed::<TestBackend>::new(&device));
let state_1 = module_1.clone().into_record();

module_2.load_record(state_1);
}
}

mod num_params {
Expand All @@ -113,6 +170,16 @@ mod num_params {
let module = ModuleComposed::<TestBackend>::new(&device);
assert_eq!(4 * 20 * 20, module.num_params());
}

#[test]
fn should_calculate_num_params_enum() {
let device = <TestBackend as Backend>::Device::default();
let module = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));
assert_eq!(20 * 20, module.num_params());

let module = ModuleEnum::Composed(ModuleComposed::<TestBackend>::new(&device));
assert_eq!(4 * 20 * 20, module.num_params());
}
}

#[cfg(feature = "std")]
Expand Down
4 changes: 3 additions & 1 deletion crates/burn-derive/src/module/base.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::{
codegen::{generate_module_const, generate_module_standard},
codegen_enum::EnumModuleCodegen,
codegen_struct::StructModuleCodegen,
};
use proc_macro::TokenStream;
Expand All @@ -22,7 +23,8 @@ pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream {
}
syn::Data::Enum(_data) => {
if has_backend {
panic!("Enum modules aren't supported yet.")
generate_module_standard(ast, EnumModuleCodegen::from_ast(ast))
// panic!("Enum modules aren't supported yet.")
laggui marked this conversation as resolved.
Show resolved Hide resolved
} else {
generate_module_const(ast)
}
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-derive/src/module/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(

let record = codegen.record_codegen();
let record_name = Ident::new(format!("{}Record", name).as_str(), name.span());
let record_struct = record.gen_record_type(&record_name, &generics.module);
let record_type = record.gen_record_type(&record_name, &generics.module);

let (generics_module, generics_ty_module, generics_where_module) =
generics.module.split_for_impl();
Expand Down Expand Up @@ -86,7 +86,7 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
#clone_fn
}

#record_struct
#record_type
};

gen
Expand Down
192 changes: 192 additions & 0 deletions crates/burn-derive/src/module/codegen_enum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
use super::{codegen::ModuleCodegen, record_enum::EnumModuleRecordCodegen};
use crate::shared::enum_variant::{parse_variants, EnumVariant};
use proc_macro2::{Ident, TokenStream};
use quote::quote;

pub(crate) struct EnumModuleCodegen {
pub variants: Vec<EnumVariant>,
}

impl ModuleCodegen for EnumModuleCodegen {
type RecordCodegen = EnumModuleRecordCodegen;

fn gen_num_params(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|_| {
quote! {
burn::module::Module::<B>::num_params(module)
}
});

quote! {
fn num_params(&self) -> usize {
#match_body
}
}
}

fn gen_visit(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|_| {
quote! {
burn::module::Module::visit(module, visitor)
}
});

quote! {
fn visit<Visitor: burn::module::ModuleVisitor<B>>(&self, visitor: &mut Visitor) {
#match_body
}
}
}

fn gen_collect_devices(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|_| {
quote! {
burn::module::Module::<B>::collect_devices(module, devices)
}
});

quote! {
fn collect_devices(
&self,
devices: burn::module::Devices<B>
) -> burn::module::Devices<B> {
#match_body
}
}
}

fn gen_to_device(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|variant| {
quote! {
Self::#variant(burn::module::Module::<B>::to_device(module, device))
}
});

quote! {
fn to_device(self, device: &B::Device) -> Self {
#match_body
}
}
}

fn gen_fork(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|variant| {
quote! {
Self::#variant(burn::module::Module::<B>::fork(module, device))
}
});

quote! {
fn fork(self, device: &B::Device) -> Self {
#match_body
}
}
}

fn gen_map(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|variant| {
quote! {
Self::#variant(burn::module::Module::<B>::map(module, mapper))
}
});

quote! {
fn map<Mapper: burn::module::ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self {
#match_body
}
}
}

fn gen_valid(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|variant| {
quote! {
Self::InnerModule::#variant(burn::module::AutodiffModule::<B>::valid(module))
}
});

quote! {
fn valid(&self) -> Self::InnerModule {
#match_body
}
}
}

fn gen_into_record(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|variant| {
quote! {
Self::Record::#variant(burn::module::Module::<B>::into_record(module))
}
});

quote! {
fn into_record(self) -> Self::Record {
#match_body
}
}
}

fn gen_load_record(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|variant| {
quote! {
{
let Self::Record::#variant(r) = record else {panic!("Can't parse record from a different variant");};
Self::#variant(burn::module::Module::<B>::load_record(module, r))
}
}
});

quote! {
fn load_record(self, record: Self::Record) -> Self {
#match_body
}
}
}

fn gen_clone(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|variant| {
quote! {
Self::#variant(module.clone())
}
});

quote! {
fn clone(&self) -> Self {
#match_body
}
}
}

fn record_codegen(self) -> Self::RecordCodegen {
EnumModuleRecordCodegen::new(self.variants)
}
}

impl EnumModuleCodegen {
pub fn from_ast(ast: &syn::DeriveInput) -> Self {
Self {
variants: parse_variants(ast),
}
}

/// Generate the enum variants' match arm with the provided function
fn gen_variants_match_fn<F>(&self, func: F) -> TokenStream
where
F: Fn(Ident) -> TokenStream,
{
let mut match_arms = quote! {};

for variant in self.variants.iter() {
let name = &variant.ident;
let arm_pattern = quote! {Self::#name(module)};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at this, I think we don't yet support named enum, this can be added later.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah only unnamed enum support in this PR. We could definitely add an issue and improve the support, I don't think it would require that much more work now that the table is set.

let arm_code = func(name.clone());

match_arms.extend(quote! {#arm_pattern => #arm_code,})
}

quote! {
match self {
#match_arms
}
}
}
}
2 changes: 2 additions & 0 deletions crates/burn-derive/src/module/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
pub(crate) mod codegen;
pub(crate) mod codegen_enum;
pub(crate) mod codegen_struct;
pub(crate) mod display;
pub(crate) mod record;
pub(crate) mod record_enum;
pub(crate) mod record_struct;

mod base;
Expand Down
39 changes: 39 additions & 0 deletions crates/burn-derive/src/module/record_enum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
use crate::shared::enum_variant::EnumVariant;
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::Generics;

use super::record::ModuleRecordCodegen;

#[derive(new)]
pub(crate) struct EnumModuleRecordCodegen {
variants: Vec<EnumVariant>,
}

impl ModuleRecordCodegen for EnumModuleRecordCodegen {
fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream {
let mut variants = quote! {};

// Capture the Record enum variant types
for variant in self.variants.iter() {
let ty = &variant.ty;
let name = &variant.ident;

variants.extend(quote! {
/// The module record associative type.
#name(<#ty as burn::module::Module<B>>::Record),
});
}

let (generics, _generics_ty, generics_where) = generics.split_for_impl();

quote! {

/// The record type for the module.
#[derive(burn::record::Record)]
pub enum #record_name #generics #generics_where {
#variants
}
}
}
}
Loading
Loading