-
Notifications
You must be signed in to change notification settings - Fork 648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[nnx] add submodule iterator #3581
Conversation
Codecov ReportAttention:
Additional details and impacted files@@ Coverage Diff @@
## main #3581 +/- ##
==========================================
+ Coverage 56.31% 56.46% +0.15%
==========================================
Files 100 100
Lines 11973 11994 +21
==========================================
+ Hits 6742 6773 +31
+ Misses 5231 5221 -10 ☔ View full report in Codecov by Sentry. |
@@ -482,29 +482,10 @@ def sow( | |||
reduced_value = reduce_fn(init_fn(), value) | |||
setattr(self, name, variable_type(reduced_value)) | |||
|
|||
def for_each( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be called submodules
and not module
? Or better yet iter_submodules
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Called it modules() as it would be familiar to Pytorch users. Wondering if we should try to follow their conventions when possible?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would recommend adding a few tests for the new API.
def iter_nodes(node: tp.Any) -> tp.Iterator[tuple[Path, tp.Any]]: | ||
visited: set[int] = set() | ||
path_parts: PathParts = () | ||
yield from _iter_nodes(node, visited, path_parts) | ||
|
||
|
||
def _iter_nodes( | ||
node: tp.Any, visited: set[int], path_parts: PathParts |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can the node
args have a more specific type annotation than tp.Any
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
problem is that a node could be of any type that is registered.
43eb65a
to
b220ec4
Compare
b220ec4
to
1e75509
Compare
@superbobry added test. |
What does this PR do?
Adds method to iterator over all unique submodules.