-
Notifications
You must be signed in to change notification settings - Fork 648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update Flax NNX Randomness #4279
base: main
Are you sure you want to change the base?
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
88d53d9
to
0823b16
Compare
Also updating: - Here is a list of the main PRNG-related types in Flax NNX:
+ Here are the main PRNG-related types in Flax NNX: Added the JAX
Added a link to JAX Working with pytrees for new users:
Added
|
0823b16
to
363e8ab
Compare
rebasing after #4281 |
4ab4cfc
to
11f3616
Compare
57f8d34
to
3798306
Compare
@cgarciae PTAL, thanks! |
@@ -318,8 +344,14 @@ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"## Splitting Rngs\n", | |||
"When interacting with transforms like `vmap` or `pmap` it is often necessary to split the random state such that each replica has its own unique state. This can be done in two way, either by manually splitting a key before passing it to one of the `Rngs` streams, or using the `nnx.split_rngs` decorator which will automatically split the random state of any `RngStream`s found in the inputs of the function and automatically \"lower\" them once the function call ends. `split_rngs` is more convenient as it works nicely with transforms so we'll show an example of that here:" | |||
"## Splitting `nnx.Rngs`\n", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not splitting nnx.Rngs
directly. Maybe say "Splitting RNG keys".
@cgarciae @IvyZX PTAL
Preview: https://flax--4279.org.readthedocs.build/en/4279/guides/randomness.html