-
Notifications
You must be signed in to change notification settings - Fork 648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[nnx] stabilize unsafe_pytree #4030
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
Maybe I am missing something, is there is a discussion why is it unsafe? My guess would be because of lost refs, but I am not sure in the context of nnx. |
ma, mb = jax.tree.map(lambda x: x, (ma, mb)) | ||
|
||
print(f'After: {ma.shared is mb.shared = }') | ||
``` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ASEM000 here is the explanation, indeed its because of sharing. The unsafe
naming hopefully gives users a visual queue that they shouldn't use this feature lightly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is more clear now, thanks.
@@ -199,8 +199,8 @@ def is_initializing(self) -> bool: | |||
|
|||
return self._object__state._initializing | |||
|
|||
def __init_subclass__(cls, experimental_pytree: bool = False) -> None: | |||
super().__init_subclass__(experimental_pytree) | |||
def __init_subclass__(cls, unsafe_pytree: bool = False) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add a link or docstring around here to explain this feature so that it's available somewhere in the API reference? I think this feature warrants explicit explanation (since its caveat is a bit hidden).
docs/nnx/nnx_basics.md
Outdated
reference semantics are broken by JAX's referential transparency, this | ||
is specially problematic when there is shared state between NNX graph nodes | ||
as reference identity is lost. Use `unsafe_pytree` only when there' i's only | ||
a single top-level Module or there is no shared state between top-level Modules. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
when there' i's only a single top-level Module or there is no shared state between top-level Modules.
Should this be an "or" or an "and"?
Also nit: "when there is only a single..."
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! It's an "or". Changed the wording a bit.
3660b20
to
fe6a538
Compare
6f1d9d7
to
45e2cd2
Compare
b5fea8b
to
d2ca0ef
Compare
Exploring better alternatives that will not clash with the assumption that graph nodes are not pytrees present in many functions. PiperOrigin-RevId: 651347331
Exploring better alternatives that will not clash with the assumption that graph nodes are not pytrees present in many functions. PiperOrigin-RevId: 651347331
Exploring better alternatives that will not clash with the assumption that graph nodes are not pytrees present in many functions. PiperOrigin-RevId: 651347331
Exploring better alternatives that will not clash with the assumption that graph nodes are not pytrees present in many functions. PiperOrigin-RevId: 651347331
Exploring better alternatives that will not clash with the assumption that graph nodes are not pytrees present in many functions. PiperOrigin-RevId: 651419435
What does this PR do?
experimental_pytree
tounsafe_pytree
, making it a stable feature.