Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] stabilize unsafe_pytree #4030

Merged
merged 1 commit into from
Jul 7, 2024
Merged

[nnx] stabilize unsafe_pytree #4030

merged 1 commit into from
Jul 7, 2024

Conversation

cgarciae
Copy link
Collaborator

What does this PR do?

  • Renames experimental_pytree to unsafe_pytree, making it a stable feature.
  • Adds the "Modules as Pytrees" section to the NNX Basics guide.

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Copy link

google-cla bot commented Jun 26, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@cgarciae cgarciae changed the base branch from Module-split-merge to main June 26, 2024 09:22
@ASEM000
Copy link

ASEM000 commented Jun 28, 2024

Maybe I am missing something, is there is a discussion why is it unsafe? My guess would be because of lost refs, but I am not sure in the context of nnx.

@cgarciae cgarciae marked this pull request as ready for review June 28, 2024 20:03
ma, mb = jax.tree.map(lambda x: x, (ma, mb))

print(f'After: {ma.shared is mb.shared = }')
```
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ASEM000 here is the explanation, indeed its because of sharing. The unsafe naming hopefully gives users a visual queue that they shouldn't use this feature lightly.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is more clear now, thanks.

@@ -199,8 +199,8 @@ def is_initializing(self) -> bool:

return self._object__state._initializing

def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
super().__init_subclass__(experimental_pytree)
def __init_subclass__(cls, unsafe_pytree: bool = False) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a link or docstring around here to explain this feature so that it's available somewhere in the API reference? I think this feature warrants explicit explanation (since its caveat is a bit hidden).

reference semantics are broken by JAX's referential transparency, this
is specially problematic when there is shared state between NNX graph nodes
as reference identity is lost. Use `unsafe_pytree` only when there' i's only
a single top-level Module or there is no shared state between top-level Modules.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when there' i's only a single top-level Module or there is no shared state between top-level Modules.

Should this be an "or" or an "and"?
Also nit: "when there is only a single..."

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! It's an "or". Changed the wording a bit.

@cgarciae cgarciae force-pushed the nnx-pytree branch 2 times, most recently from 3660b20 to fe6a538 Compare July 5, 2024 17:15
@cgarciae cgarciae requested a review from IvyZX July 5, 2024 17:15
@cgarciae cgarciae force-pushed the nnx-pytree branch 2 times, most recently from 6f1d9d7 to 45e2cd2 Compare July 5, 2024 17:31
@cgarciae cgarciae force-pushed the nnx-pytree branch 2 times, most recently from b5fea8b to d2ca0ef Compare July 5, 2024 21:02
@copybara-service copybara-service bot merged commit 96bb44c into main Jul 7, 2024
18 checks passed
@copybara-service copybara-service bot deleted the nnx-pytree branch July 7, 2024 21:49
copybara-service bot pushed a commit that referenced this pull request Jul 11, 2024
Exploring better alternatives that will not clash with the assumption that graph nodes are not pytrees present in many functions.

PiperOrigin-RevId: 651347331
copybara-service bot pushed a commit that referenced this pull request Jul 11, 2024
Exploring better alternatives that will not clash with the assumption that graph nodes are not pytrees present in many functions.

PiperOrigin-RevId: 651347331
copybara-service bot pushed a commit that referenced this pull request Jul 11, 2024
Exploring better alternatives that will not clash with the assumption that graph nodes are not pytrees present in many functions.

PiperOrigin-RevId: 651347331
copybara-service bot pushed a commit that referenced this pull request Jul 11, 2024
Exploring better alternatives that will not clash with the assumption that graph nodes are not pytrees present in many functions.

PiperOrigin-RevId: 651347331
copybara-service bot pushed a commit that referenced this pull request Jul 11, 2024
Exploring better alternatives that will not clash with the assumption that graph nodes are not pytrees present in many functions.

PiperOrigin-RevId: 651419435
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants