-
Notifications
You must be signed in to change notification settings - Fork 441
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add enum module support #1337
Add enum module support #1337
Changes from 6 commits
178098e
c92dd71
a3414de
04bebbc
161021b
c2a4bf2
a8ed24a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
use super::{codegen::ModuleCodegen, record_enum::EnumModuleRecordCodegen}; | ||
use crate::shared::enum_variant::{parse_variants, EnumVariant}; | ||
use proc_macro2::{Ident, TokenStream}; | ||
use quote::quote; | ||
|
||
pub(crate) struct EnumModuleCodegen { | ||
pub variants: Vec<EnumVariant>, | ||
} | ||
|
||
impl ModuleCodegen for EnumModuleCodegen { | ||
type RecordCodegen = EnumModuleRecordCodegen; | ||
|
||
fn gen_num_params(&self) -> TokenStream { | ||
let match_body = self.gen_variants_match_fn(|_| { | ||
quote! { | ||
burn::module::Module::<B>::num_params(module) | ||
} | ||
}); | ||
|
||
quote! { | ||
fn num_params(&self) -> usize { | ||
#match_body | ||
} | ||
} | ||
} | ||
|
||
fn gen_visit(&self) -> TokenStream { | ||
let match_body = self.gen_variants_match_fn(|_| { | ||
quote! { | ||
burn::module::Module::visit(module, visitor) | ||
} | ||
}); | ||
|
||
quote! { | ||
fn visit<Visitor: burn::module::ModuleVisitor<B>>(&self, visitor: &mut Visitor) { | ||
#match_body | ||
} | ||
} | ||
} | ||
|
||
fn gen_collect_devices(&self) -> TokenStream { | ||
let match_body = self.gen_variants_match_fn(|_| { | ||
quote! { | ||
burn::module::Module::<B>::collect_devices(module, devices) | ||
} | ||
}); | ||
|
||
quote! { | ||
fn collect_devices( | ||
&self, | ||
devices: burn::module::Devices<B> | ||
) -> burn::module::Devices<B> { | ||
#match_body | ||
} | ||
} | ||
} | ||
|
||
fn gen_to_device(&self) -> TokenStream { | ||
let match_body = self.gen_variants_match_fn(|variant| { | ||
quote! { | ||
Self::#variant(burn::module::Module::<B>::to_device(module, device)) | ||
} | ||
}); | ||
|
||
quote! { | ||
fn to_device(self, device: &B::Device) -> Self { | ||
#match_body | ||
} | ||
} | ||
} | ||
|
||
fn gen_fork(&self) -> TokenStream { | ||
let match_body = self.gen_variants_match_fn(|variant| { | ||
quote! { | ||
Self::#variant(burn::module::Module::<B>::fork(module, device)) | ||
} | ||
}); | ||
|
||
quote! { | ||
fn fork(self, device: &B::Device) -> Self { | ||
#match_body | ||
} | ||
} | ||
} | ||
|
||
fn gen_map(&self) -> TokenStream { | ||
let match_body = self.gen_variants_match_fn(|variant| { | ||
quote! { | ||
Self::#variant(burn::module::Module::<B>::map(module, mapper)) | ||
} | ||
}); | ||
|
||
quote! { | ||
fn map<Mapper: burn::module::ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self { | ||
#match_body | ||
} | ||
} | ||
} | ||
|
||
fn gen_valid(&self) -> TokenStream { | ||
let match_body = self.gen_variants_match_fn(|variant| { | ||
quote! { | ||
Self::InnerModule::#variant(burn::module::AutodiffModule::<B>::valid(module)) | ||
} | ||
}); | ||
|
||
quote! { | ||
fn valid(&self) -> Self::InnerModule { | ||
#match_body | ||
} | ||
} | ||
} | ||
|
||
fn gen_into_record(&self) -> TokenStream { | ||
let match_body = self.gen_variants_match_fn(|variant| { | ||
quote! { | ||
Self::Record::#variant(burn::module::Module::<B>::into_record(module)) | ||
} | ||
}); | ||
|
||
quote! { | ||
fn into_record(self) -> Self::Record { | ||
#match_body | ||
} | ||
} | ||
} | ||
|
||
fn gen_load_record(&self) -> TokenStream { | ||
let match_body = self.gen_variants_match_fn(|variant| { | ||
quote! { | ||
{ | ||
let Self::Record::#variant(r) = record else {panic!("Can't parse record from a different variant");}; | ||
Self::#variant(burn::module::Module::<B>::load_record(module, r)) | ||
} | ||
} | ||
}); | ||
|
||
quote! { | ||
fn load_record(self, record: Self::Record) -> Self { | ||
#match_body | ||
} | ||
} | ||
} | ||
|
||
fn gen_clone(&self) -> TokenStream { | ||
let match_body = self.gen_variants_match_fn(|variant| { | ||
quote! { | ||
Self::#variant(module.clone()) | ||
} | ||
}); | ||
|
||
quote! { | ||
fn clone(&self) -> Self { | ||
#match_body | ||
} | ||
} | ||
} | ||
|
||
fn record_codegen(self) -> Self::RecordCodegen { | ||
EnumModuleRecordCodegen::new(self.variants) | ||
} | ||
} | ||
|
||
impl EnumModuleCodegen { | ||
pub fn from_ast(ast: &syn::DeriveInput) -> Self { | ||
Self { | ||
variants: parse_variants(ast), | ||
} | ||
} | ||
|
||
/// Generate the enum variants' match arm with the provided function | ||
fn gen_variants_match_fn<F>(&self, func: F) -> TokenStream | ||
where | ||
F: Fn(Ident) -> TokenStream, | ||
{ | ||
let mut match_arms = quote! {}; | ||
|
||
for variant in self.variants.iter() { | ||
let name = &variant.ident; | ||
let arm_pattern = quote! {Self::#name(module)}; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looking at this, I think we don't yet support named enum, this can be added later. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah only unnamed enum support in this PR. We could definitely add an issue and improve the support, I don't think it would require that much more work now that the table is set. |
||
let arm_code = func(name.clone()); | ||
|
||
match_arms.extend(quote! {#arm_pattern => #arm_code,}) | ||
} | ||
|
||
quote! { | ||
match self { | ||
#match_arms | ||
} | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
use crate::shared::enum_variant::EnumVariant; | ||
use proc_macro2::{Ident, TokenStream}; | ||
use quote::quote; | ||
use syn::Generics; | ||
|
||
use super::record::ModuleRecordCodegen; | ||
|
||
#[derive(new)] | ||
pub(crate) struct EnumModuleRecordCodegen { | ||
variants: Vec<EnumVariant>, | ||
} | ||
|
||
impl ModuleRecordCodegen for EnumModuleRecordCodegen { | ||
fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream { | ||
let mut variants = quote! {}; | ||
|
||
// Capture the Record enum variant types | ||
for variant in self.variants.iter() { | ||
let ty = &variant.ty; | ||
let name = &variant.ident; | ||
|
||
variants.extend(quote! { | ||
/// The module record associative type. | ||
#name(<#ty as burn::module::Module<B>>::Record), | ||
}); | ||
} | ||
|
||
let (generics, _generics_ty, generics_where) = generics.split_for_impl(); | ||
|
||
quote! { | ||
|
||
/// The record type for the module. | ||
#[derive(burn::record::Record)] | ||
pub enum #record_name #generics #generics_where { | ||
#variants | ||
} | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this supported?
If not maybe we can open an issue and support this in another PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not yet! Your follow-up comment was correct. I only added support for unnamed fields/variants. If you try you should get a compile-time error (I added a panic to check for that specifically).