-
Notifications
You must be signed in to change notification settings - Fork 440
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add enum module support #1337
Add enum module support #1337
Conversation
Should we add something in the book to state that we support enum modules? |
Codecov ReportAttention:
Additional details and impacted files@@ Coverage Diff @@
## main #1337 +/- ##
==========================================
+ Coverage 78.77% 78.87% +0.09%
==========================================
Files 551 555 +4
Lines 61836 62176 +340
==========================================
+ Hits 48712 49040 +328
- Misses 13124 13136 +12 ☔ View full report in Codecov by Sentry. |
Now that I think about it, for simpler cases like For example, fn num_params(&self) -> usize {
match self {
Self::Basic(module) | Self::Composed(module) => burn::module::Module::<B>::num_params(module),
}
} Lmk if you have a preference. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, I think we can open issues to support all enum types, but for now it's not limiting, you can safely extract anything into modules and use a basic enum.
#[derive(Module, Debug)] | ||
enum ModuleEnumNested<B: Backend> { | ||
AnotherEnum(ModuleEnum<B>), | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this supported?
#[derive(Module, Debug)]
enum ModuleEnumNamed<B: Backend> {
Variant {
fc1: nn::Linear<B>,
fc2: nn::Linear<B>,
},
}
If not maybe we can open an issue and support this in another PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not yet! Your follow-up comment was correct. I only added support for unnamed fields/variants. If you try you should get a compile-time error (I added a panic to check for that specifically).
|
||
for variant in self.variants.iter() { | ||
let name = &variant.ident; | ||
let arm_pattern = quote! {Self::#name(module)}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking at this, I think we don't yet support named enum, this can be added later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah only unnamed enum support in this PR. We could definitely add an issue and improve the support, I don't think it would require that much more work now that the table is set.
Checklist
run-checks all
script has been executed.Related Issues/PRs
Issue #726
Changes
Added support for
#[derive(Module)]
withenum
s.In both cases, the generated methods defer the calls to the concrete types.
For example:
Testing
Added enum module definitions in derive module tests and transposed some existing unit tests for enum modules.