-
Notifications
You must be signed in to change notification settings - Fork 648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[nnx] improve transforms guide #4333
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
docs_nnx/guides/transforms.md
Outdated
@@ -51,14 +51,14 @@ weights = Weights( | |||
kernel=random.uniform(random.key(0), (10, 2, 3)), | |||
bias=jnp.zeros((10, 3)), | |||
) | |||
x = jax.random.normal(random.key(1), (10, 2)) | |||
count = jax.random.normal(random.key(1), (10, 2)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why renamed it to count
? x
makes more sense as a model input.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for catching this! There's a vscode bug when renaming variables in jupyter notebooks. Reverted back to x
.
docs_nnx/guides/transforms.md
Outdated
|
||
While Python allows for passing objects as closures to functions, this is generally not supported by Flax NNX transforms. The reason is that because Modules are mutable, it is very easy to insert a tracer into a captured Module and get tracer leakage. To avoid this, Flax NNX Module and Variables check that the current context is valid every time they are performing an update operation. | ||
|
||
For example, if we a have stateful Module such as `Counter` that increments a counter every time it is called, and we try to pass it as a closure to a function decorated with `nnx.jit`, we would be leaking the tracer, however Flax NNX will raise an error instead to prevent this: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: ", however" -> ". However"
While Python allows for passing objects as closures to functions, this is generally not supported by Flax NNX transforms. The reason is that because Modules are mutable, it is very easy to insert a tracer into a captured Module and get tracer leakage. To avoid this, Flax NNX Module and Variables check that the current context is valid every time they are performing an update operation. | ||
|
||
For example, if we a have stateful Module such as `Counter` that increments a counter every time it is called, and we try to pass it as a closure to a function decorated with `nnx.jit`, we would be leaking the tracer, however Flax NNX will raise an error instead to prevent this: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's worth to add an explicit TL;DR line saying this can be solved by explicitly pass the module counter
as an argument of f
, to make sure the module is completely being traced.
The concepts of "tracer leakage" and "context" could be confusing for most users, who don't necessarily understand the internals of Flax and JAX.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea. I simplified the language and added a comment at the end with the solution.
4c90748
to
8570051
Compare
8570051
to
a8e2ab1
Compare
What does this PR do?
Note: tracer level checks have been deactivated by JAX so the expected error is not showing up in the guide.