-
Notifications
You must be signed in to change notification settings - Fork 648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[nnx] add GraphNode base class #3790
Conversation
5d9c822
to
e217b27
Compare
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #3790 +/- ##
==========================================
+ Coverage 59.56% 59.62% +0.06%
==========================================
Files 101 101
Lines 12624 12655 +31
==========================================
+ Hits 7519 7546 +27
- Misses 5105 5109 +4 ☔ View full report in Codecov by Sentry. |
state = deepcopy(state) | ||
return graphdef.merge(state) | ||
|
||
def __hash__(self) -> int: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also override __eq__
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm more tented to remove this implementation of __hash__
and just have the base one.
@@ -860,6 +887,101 @@ def _graph_update_static( | |||
node_impl.set_key(node, name, value_updates) | |||
|
|||
|
|||
@tp.overload |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I personally would probably leave the overloading out. It is verbose and only marginally useful.
I understand that you are trying to preserve tuple cardinality, but I'm not sure how common such a mistake is to justify all this boilerplate.
34de63f
to
0db3362
Compare
0db3362
to
6948138
Compare
What does this PR do?
GraphNode
as a base class forModule
and other types that need a graph flatten/unflatten implementation according to Module's current definition.nnx.split
andnnx.update
to work with any graph node.nnx.merge
to match the spread signature returned by.split