Flax NNX GSPMD guide #4220
Conversation
docs_nnx/guides/flax_gspmd.md (Outdated)

## Flax and `jax.jit` scaled up
Maybe we should change the writing here to talk about `nnx.jit`?
I kinda want to convey the idea that essentially we are using JAX's compilation machinery for the scaling-up work. I renamed the title and added another paragraph explaining this (and mentioning `nnx.jit` there).
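For context, here is a minimal sketch of what "using JAX's compilation machinery" looks like through `nnx.jit`: it plays the same role as `jax.jit`, but accepts stateful NNX modules directly. The `Linear` module, its sizes, and the RNG seed below are illustrative assumptions, not code from the guide.

```python
import jax
import jax.numpy as jnp
from flax import nnx

class Linear(nnx.Module):  # illustrative toy module, not from the guide
  def __init__(self, din, dout, *, rngs: nnx.Rngs):
    self.w = nnx.Param(jax.random.normal(rngs.params(), (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))

  def __call__(self, x):
    return x @ self.w.value + self.b.value

@nnx.jit  # like jax.jit, but understands stateful NNX modules
def forward(model, x):
  return model(x)

model = Linear(4, 8, rngs=nnx.Rngs(0))
y = forward(model, jnp.ones((2, 4)))  # compiled by XLA, just like jax.jit
```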
docs_nnx/guides/flax_gspmd.md (Outdated)
self.w2 = nnx.Param(
    nnx.with_partitioning(init_fn, ('model', None))(
        rngs.params(), (depth, depth))  # RNG key and shape for W2 creation
)
This is a good opportunity to show how to manually add the sharding metadata:
Suggested change:

self.w2 = nnx.Param(
    init_fn(rngs.params(), (depth, depth)),  # RNG key and shape for W2 creation
    sharding=('model', None),
)
Good idea!
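For readers comparing the two spellings, here is a small, self-contained sketch of both ways to attach the sharding metadata. The `depth`, `init_fn`, and RNG seed are placeholder assumptions rather than the guide's actual definitions.

```python
from flax import nnx

depth = 8
init_fn = nnx.initializers.lecun_normal()
rngs = nnx.Rngs(0)

# Option 1: attach the metadata by wrapping the initializer.
w_a = nnx.Param(
    nnx.with_partitioning(init_fn, ('model', None))(
        rngs.params(), (depth, depth))  # RNG key and shape
)

# Option 2: pass the metadata directly to the Param constructor.
w_b = nnx.Param(
    init_fn(rngs.params(), (depth, depth)),  # RNG key and shape
    sharding=('model', None),
)

# Both params now carry sharding == ('model', None) as metadata,
# which nnx.get_partition_spec later turns into a PartitionSpec.
print(w_a.sharding, w_b.sharding)
```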
# In data parallelism, input / intermediate value's first dimension (batch)
# will be sharded on `data` axis
y = jax.lax.with_sharding_constraint(y, PartitionSpec('data', 'model'))
z = jnp.dot(y, self.w2.value)
Variables can be used as JAX arrays thanks to the `__jax_array__` protocol.
Suggested change:

z = jnp.dot(y, self.w2)
For some reason this will fail later when I do:

    with mesh:
      output = sharded_model(input)

with the error `AttributeError: 'tuple' object has no attribute '_device_assignment'`. I'll keep this as-is for now.
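For reference, a minimal sketch of the `__jax_array__` behaviour being discussed, outside of the sharded-`mesh` code path where it reportedly failed; the shapes here are arbitrary assumptions.

```python
import jax.numpy as jnp
from flax import nnx

w = nnx.Param(jnp.ones((4, 3)))
x = jnp.ones((2, 4))

z_explicit = jnp.dot(x, w.value)  # explicit: unwrap the Param first
z_implicit = jnp.dot(x, w)        # implicit: relies on the __jax_array__ protocol

assert (z_explicit == z_implicit).all()
```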
docs_nnx/guides/flax_gspmd.md (Outdated)
print(unsharded_model.w2.value.sharding)  # SingleDeviceSharding

We should leverage JAX's compilation mechanism, aka. `jax.jit`, to create the sharded model. The key is to initialize a model and assign shardings upon the model state within a jitted function:
Suggested change:

We should leverage JAX's compilation mechanism, via `nnx.jit`, to create the sharded model. The key is to initialize a model and assign shardings upon the model state within a jitted function:
Done.
docs_nnx/guides/flax_gspmd.md (Outdated)
1. Call [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) to bind the model state with the sharding annotations. This API tells the top-level `jax.jit` how to shard a variable!

1. Throw away the unsharded state and return the model based upon the sharded state.

1. Compile the whole function with `nnx.jit` instead of `jax.jit` because it allows the output to be a stateful NNX module.

1. Run it under a device mesh context so that JAX knows which devices to shard it to.
Suggestion: replace `jax.jit` with `nnx.jit` in the other points and remove the point where you suggest using `nnx.jit` instead of `jax.jit`.
Suggested change:

1. Call [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) to bind the model state with the sharding annotations. This API tells the top-level `nnx.jit` how to shard a variable!

1. Throw away the unsharded state and return the model based upon the sharded state.

1. Run it under a device mesh context so that JAX knows which devices to shard it to.
Hmm... I think we should still briefly explain why using `nnx.jit` is a better pattern. Especially since we are making transforms closer to JAX style now, we should assume some users have experience with `jax.jit`. I can remove the mentions of `jax.jit` here and direct users more explicitly to `nnx.jit`.
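Putting the listed steps together, a sketch of the whole initialization pattern might look like the following. The `DotReluDot` module is a minimal stand-in mirroring the excerpts above; its exact name and signature, the 2x4 mesh shape, and the `depth` size are assumptions rather than the guide's literal code.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec
from flax import nnx

# A 2x4 device mesh with a 'data' and a 'model' axis (assumes 8 devices are available).
mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('data', 'model'))

class DotReluDot(nnx.Module):  # minimal stand-in mirroring the excerpts above
  def __init__(self, depth, *, rngs: nnx.Rngs):
    init_fn = nnx.initializers.lecun_normal()
    self.w1 = nnx.Param(
        nnx.with_partitioning(init_fn, (None, 'model'))(
            rngs.params(), (depth, depth)))
    self.w2 = nnx.Param(
        init_fn(rngs.params(), (depth, depth)),
        sharding=('model', None))

  def __call__(self, x):
    y = jax.nn.relu(jnp.dot(x, self.w1.value))
    # Constrain the intermediate activation, as in the excerpt above.
    y = jax.lax.with_sharding_constraint(y, PartitionSpec('data', 'model'))
    return jnp.dot(y, self.w2.value)

@nnx.jit
def create_sharded_model():
  model = DotReluDot(depth=1024, rngs=nnx.Rngs(0))  # unsharded at this point
  state = nnx.state(model)
  pspecs = nnx.get_partition_spec(state)            # read the sharding annotations
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
  nnx.update(model, sharded_state)                  # the returned model is sharded
  return model

with mesh:  # the mesh context tells JAX which devices to shard onto
  sharded_model = create_sharded_model()
```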
docs_nnx/guides/flax_gspmd.md (Outdated)
Now, from initialization or from checkpoint, we have a sharded model. To carry out the compiled, scaled up training, we need to shard the inputs as well. In this data parallelism example, the training data has its batch dimension sharded across the `data` device axis, so you should put your data in sharding `('data', None)`. You can use `jax.device_put` for this.

Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without `jax.jit`. See the example below - even without `jax.lax.with_sharding_constraint` on the output `y`, it was still sharded as `('data', None)`.
Suggested change:

Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without `nnx.jit`. See the example below - even without `jax.lax.with_sharding_constraint` on the output `y`, it was still sharded as `('data', None)`.
Done.
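For completeness, a small sketch of sharding the inputs with `jax.device_put` as described above; the mesh construction, batch size, and feature size are placeholder assumptions.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Same kind of ('data', 'model') mesh as before (assumes 8 devices are available).
mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('data', 'model'))

# Placeholder batch of inputs and labels.
input = jax.random.normal(jax.random.key(1), (8, 1024))
label = jax.random.normal(jax.random.key(2), (8, 1024))

# Shard the batch dimension across the 'data' mesh axis, replicate the rest.
data_sharding = NamedSharding(mesh, PartitionSpec('data', None))
input = jax.device_put(input, data_sharding)
label = jax.device_put(label, data_sharding)

print(input.sharding)  # NamedSharding(..., spec=PartitionSpec('data', None))
```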
new_state = block_all(train_step(sharded_model, optimizer, input, label))

## Logical axis annotation
Nice! I didn't know about `sharding_rules`. In `nnx_lm1b` we have this other pattern which maps the mesh axes in the constructor:

- flax/examples/lm1b_nnx/configs/default.py, line 22 in e3772b2: `class MeshRules:`
- flax/examples/lm1b_nnx/models.py, line 214 in e3772b2: `config.axis_rules('mlp'),`
Yeah, I added them recently to align with Linen's `LogicallyPartitioned`. It's just annotations, so there's actually a ton of ways to make them work, and I like how you made it in `nnx_lm1b`!
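For readers who want the constructor-mapping flavour mentioned above, here is a rough, self-contained sketch of a `MeshRules`-style mapping from logical axis names to mesh axes; the field names and the mapping are illustrative and not the exact code from `lm1b_nnx`.

```python
import dataclasses
from typing import Optional

@dataclasses.dataclass
class MeshRules:
  # Map logical axis names to physical mesh axis names (None means replicated).
  embed: Optional[str] = None
  mlp: Optional[str] = None
  data: Optional[str] = None

  def __call__(self, *names):
    # Resolve a tuple of logical names into physical mesh axes.
    return tuple(getattr(self, name) if name is not None else None for name in names)

# Illustrative configuration: shard 'mlp' over the 'model' mesh axis.
rules = MeshRules(embed=None, mlp='model', data='data')

# Inside a module constructor this could feed nnx.with_partitioning, e.g.
#   nnx.with_partitioning(init_fn, rules('embed', 'mlp'))
print(rules('embed', 'mlp'))  # -> (None, 'model')
```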
Add a guide for doing GSPMD-style sharding annotation on NNX models.

Covered everything in the Linen pjit guide, but with better explanations and demonstrations, and more concise code!

Also added a small example of loading a sharded model from a checkpoint.