Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] improve transforms guide #4333

Merged
merged 1 commit into from
Oct 29, 2024
Merged

Conversation

cgarciae
Copy link
Collaborator

What does this PR do?

  • Adds section about passing Modules by closure.

Note: tracer level checks have been deactivated by JAX so the expected error is not showing up in the guide.

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@@ -51,14 +51,14 @@ weights = Weights(
kernel=random.uniform(random.key(0), (10, 2, 3)),
bias=jnp.zeros((10, 3)),
)
x = jax.random.normal(random.key(1), (10, 2))
count = jax.random.normal(random.key(1), (10, 2))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why renamed it to count? x makes more sense as a model input.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this! There's a vscode bug when renaming variables in jupyter notebooks. Reverted back to x.


While Python allows for passing objects as closures to functions, this is generally not supported by Flax NNX transforms. The reason is that because Modules are mutable, it is very easy to insert a tracer into a captured Module and get tracer leakage. To avoid this, Flax NNX Module and Variables check that the current context is valid every time they are performing an update operation.

For example, if we a have stateful Module such as `Counter` that increments a counter every time it is called, and we try to pass it as a closure to a function decorated with `nnx.jit`, we would be leaking the tracer, however Flax NNX will raise an error instead to prevent this:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: ", however" -> ". However"

While Python allows for passing objects as closures to functions, this is generally not supported by Flax NNX transforms. The reason is that because Modules are mutable, it is very easy to insert a tracer into a captured Module and get tracer leakage. To avoid this, Flax NNX Module and Variables check that the current context is valid every time they are performing an update operation.

For example, if we a have stateful Module such as `Counter` that increments a counter every time it is called, and we try to pass it as a closure to a function decorated with `nnx.jit`, we would be leaking the tracer, however Flax NNX will raise an error instead to prevent this:

Copy link
Collaborator

@IvyZX IvyZX Oct 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's worth to add an explicit TL;DR line saying this can be solved by explicitly pass the module counter as an argument of f, to make sure the module is completely being traced.

The concepts of "tracer leakage" and "context" could be confusing for most users, who don't necessarily understand the internals of Flax and JAX.

Copy link
Collaborator Author

@cgarciae cgarciae Oct 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. I simplified the language and added a comment at the end with the solution.

@copybara-service copybara-service bot merged commit 58d1e53 into main Oct 29, 2024
19 checks passed
@copybara-service copybara-service bot deleted the improve-transforms-guide branch October 29, 2024 21:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants