From 2bf7f9d195da13612f1a93f8328b4bed225c6904 Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Tue, 1 Oct 2024 15:54:13 +0000 Subject: [PATCH] Update Flax NNX Scale Up SPMD guide --- docs_nnx/guides/flax_gspmd.ipynb | 116 +++++++++++++++++-------------- docs_nnx/guides/flax_gspmd.md | 116 +++++++++++++++++-------------- 2 files changed, 124 insertions(+), 108 deletions(-) diff --git a/docs_nnx/guides/flax_gspmd.ipynb b/docs_nnx/guides/flax_gspmd.ipynb index e2650edd06..5ea92a675d 100644 --- a/docs_nnx/guides/flax_gspmd.ipynb +++ b/docs_nnx/guides/flax_gspmd.ipynb @@ -6,7 +6,7 @@ "source": [ "# Scale up on multiple devices\n", "\n", - "This guide shows how to scale up [Flax NNX Modules](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) on multiple devices and hosts using [JAX just-in-time compilation machinery](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)." + "This guide demonstrates how to scale up [Flax NNX `Module`s](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) on multiple devices and hosts, such as GPUs, Google TPUs, and CPUs, using [JAX just-in-time compilation machinery (`jax.jit`)](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)." ] }, { @@ -16,15 +16,20 @@ "source": [ "## Overview\n", "\n", - "Flax relies on JAX for numeric computations and for scaling the computations up across multiple devices (GPU, TPU, etc), and the core of scaling up is using [JAX's just-in-time compiler](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). Throughout this guide, we will mainly use [`nnx.jit`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) which wraps `jax.jit` and works more conveniently with NNX modules.\n", + "Flax relies on [JAX](https://jax.readthedocs.io) for numeric computations and scaling the computations up across multiple devices, such as GPU and TPUs. At the core of scaling up is the [JAX just-in-time compiler `jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). Throughout this guide, you will be using Flax’s own [`nnx.jit`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) which wraps around [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and works more conveniently with Flax NNX `Module`s.\n", "\n", - "JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. This means you write Python code as if it runs only on one device, and `jit` will automatically compile and run it on multiple devices.\n", + "JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. This means you write Python code as if it runs only on one device, and `jax.jit` will automatically compile and run it on multiple devices.\n", "\n", - "To ensure the compilation performance, you often need to tell JAX how your model's variables are supposed to be sharded across devices. This is where Flax's [Sharding Metadata API](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) comes in - to help you annotate your model variables with this information.\n", + "To ensure the compilation performance, you often need to instruct JAX how your model's variables need to be sharded across devices. 
This is where Flax NNX's Sharding Metadata API - [`flax.nnx.spmd`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) - comes in. It helps you annotate your model variables with this information.\n", "\n", - "> **NOTE to Flax Linen users**: this API is pretty much the same with what you may have learnt in [the Linen version](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html) on model definition level, but the top-level code will be simpler due to the benefits of NNX, and some text explanation will be more updated and clear.\n", + "> **Note to Flax Linen users**: The [`flax.nnx.spmd`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) API is similar to what is described in [the Linen Flax on `(p)jit` guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html) on the model definition level. However, the top-level code in Flax NNX is simpler due to the benefits brought by Flax NNX, and some text explanations will be more updated and clearer.\n", "\n", - "You can learn more about JAX APIs for scaling up in [JAX in multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html) and [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) on JAX's documentation site." + "If you are new parallelization in JAX, you can learn more about its APIs for scaling up in the following tutorials:\n", + "\n", + "- [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html): A 101 level tutorial covering the basics of automatic parallelization with `jax.jit`, semi-automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html), and manual sharding with [`shard_map`](https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html#jax.experimental.shard_map.shard_map).\n", + "- [JAX in multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html).\n", + "- [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html): A more detailed tutorial about parallelization with `jax.jit` and `jax.lax.with_sharding_constraint`. Study it after the [101]([Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html).\n", + "- [Manual parallelism with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html): Another more in-depth doc that follows the [101]([Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html)." ] }, { @@ -81,7 +86,7 @@ } ], "source": [ - "print(f'We have 8 fake JAX devices now: {jax.devices()}')" + "print(f'You have 8 “fake” JAX devices now: {jax.devices()}')" ] }, { @@ -91,9 +96,10 @@ "source": [ "The code below shows how to import and set up the JAX-level device API, following JAX's [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) guide:\n", "\n", - "1. Start a 2x4 device `mesh` (8 devices) using JAX's `mesh_utils.create_device_mesh`. 
This layout is the same as the one of a [TPU v3-8](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#single_tpu_board).\n", + "1. Start a 2x4 device `mesh` (8 devices) using the JAX [`jax.sharding.Mesh`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Mesh). This layout is the same as on a [TPU v3-8](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#single_tpu_board) (also 8 devices).\n", + "\n", + "2. Annotate each axis with a name using the `axis_names` parameter. A typical way to annotate axis names is `axis_name=('data', 'model')`, where:\n", "\n", - "2. Annotate each axis with a name using the `axis_names` parameter in `jax.sharding.Mesh`. A typical way to annotate axis names is `axis_name=('data', 'model')`, where:\n", " * `'data'`: the mesh dimension used for data-parallel sharding of the batch dimension of inputs and activations.\n", " * `'model'`: the mesh dimension used for sharding parameters of the model across devices." ] @@ -125,11 +131,11 @@ "source": [ "## Define a model with specified sharding\n", "\n", - "Create an example layer called `DotReluDot`. This layer carries out two dot product multiplication upon the input `x`, and uses the `jax.nn.relu` (ReLU) activation function in-between.\n", + "Next, create an example layer called `DotReluDot` that subclasses Flax [`nnx.Module`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module). This layer carries out two dot product multiplication upon the input `x`, and uses the `jax.nn.relu` (ReLU) activation function in-between.\n", "\n", - "To annotate a model variable with their ideal sharding, you can use [`flax.nnx.with_partitioning`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to wrap over its initializer function. Essentially, this calls [`flax.nnx.with_metadata`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.with_metadata) which adds a `.sharding` attribute field to the corresponding `nnx.Variable`.\n", + "To annotate a model variable with their ideal sharding, you can use [`flax.nnx.with_partitioning`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to wrap over its initializer function. Essentially, this calls [`flax.nnx.with_metadata`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.with_metadata) which adds a `.sharding` attribute field to the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable).\n", "\n", - "Note that this annotation will be [preserved and adjusted accordingly across lifted transformations](https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html#axes-metadata). This means if you use sharding annotations along with any transform that modifies axes (like `nnx.vmap`, `nnx.scan`), you need to provide sharding of that additional axis via the `transform_metadata` arg. Check out the [NNX transforms guide](https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html) to learn more." + "> **Note:** This annotation will be [preserved and adjusted accordingly across lifted transformations in Flax NNX](https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html#axes-metadata). 
This means if you use sharding annotations along with any transform that modifies axes (like [`nnx.vmap`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html), [`nnx.scan`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html)), you need to provide sharding of that additional axis via the `transform_metadata` arg. Check out the [Flax NNX transformations (transforms) guide](https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html) to learn more." ] }, { @@ -142,15 +148,15 @@ " def __init__(self, depth: int, rngs: nnx.Rngs):\n", " init_fn = nnx.initializers.lecun_normal()\n", "\n", - " # Initialize a sublayer `self.dot1` and annotate its kernel with\n", - " # sharding (None, 'model')\n", + " # Initialize a sublayer `self.dot1` and annotate its kernel with.\n", + " # `sharding (None, 'model')`.\n", " self.dot1 = nnx.Linear(\n", " depth, depth,\n", " kernel_init=nnx.with_partitioning(init_fn, (None, 'model')),\n", " use_bias=False, # or use `bias_init` to give it annotation too\n", " rngs=rngs)\n", "\n", - " # Initialize a weight param `w2` and annotate with sharding ('model', None)\n", + " # Initialize a weight param `w2` and annotate with sharding ('model', None).\n", " # Note that this is simply adding `.sharding` to the variable as metadata!\n", " self.w2 = nnx.Param(\n", " init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation\n", @@ -179,7 +185,7 @@ "So, when you define `W1` with shape `(depth, depth)` and annotate as `(None, 'model')`:\n", "\n", "* The first dimension will be replicated across all devices.\n", - "* The second dimension will be sharded over the `'model'` axis of the device mesh. This means `W1` will be sharded 4-way on devices `(0, 4)`, `(1, 5)`, `(2, 6)` and `(3, 7)`, on this dimension.\n", + "* The second dimension will be sharded over the `'model'` axis of the device mesh. This means `W1` will be sharded 4-way on devices `(0, 4)`, `(1, 5)`, `(2, 6)` and `(3, 7)`, in this dimension.\n", "\n", "JAX's [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) guide offers more examples and explanations." ] @@ -191,7 +197,7 @@ "source": [ "## Initialize a sharded model\n", "\n", - "Now we have annotations attached to the `nnx.Variable`, but the actual weights haven't been sharded yet. If you just go ahead and create this model, all JAX arrays are still stuck in device 0. In practice, you'd want to avoid this, because a large model will OOM in this situation, and all the other devices are not utilized." + "Now, you have annotations attached to the [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), but the actual weights haven't been sharded yet. If you just go ahead and create this model, all JAX arrays are still stuck in device `0`. In practice, you'd want to avoid this, because a large model will \"OOM\" (will cause the device to run out of memory) in this situation, while all the other devices are not utilized." ] }, { @@ -213,11 +219,11 @@ "source": [ "unsharded_model = DotReluDot(1024, rngs=nnx.Rngs(0))\n", "\n", - "# We have annotations sticked there, yay!\n", + "# You have annotations sticked there, yay!\n", "print(unsharded_model.dot1.kernel.sharding) # (None, 'model')\n", "print(unsharded_model.w2.sharding) # ('model', None)\n", "\n", - "# But the actual arrays are not sharded... 
wut?\n", + "# But the actual arrays are not sharded?\n", "print(unsharded_model.dot1.kernel.value.sharding) # SingleDeviceSharding\n", "print(unsharded_model.w2.value.sharding) # SingleDeviceSharding" ] @@ -226,9 +232,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We should leverage JAX's compilation mechanism, via `nnx.jit`, to create the sharded model. The key is to intialize a model and assign shardings upon the model state within a jitted function:\n", + "Here, you should leverage JAX's compilation mechanism, via [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit), to create the sharded model. The key is to initialize a model and assign shardings upon the model state within a `jit`ted function:\n", "\n", - "1. Use [`nnx.get_partition_spec`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) to strip out the `.sharding` annotations attached upon model variables;\n", + "1. Use [`nnx.get_partition_spec`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) to strip out the `.sharding` annotations attached upon model variables.\n", "\n", "1. Call [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) to bind the model state with the sharding annotations. This API tells the top-level `jit` how to shard a variable!\n", "\n", @@ -238,7 +244,7 @@ "\n", "1. Run it under a device mesh context so that JAX knows which devices to shard it to.\n", "\n", - "The whole compiled `create_sharded_model` will directly generate a model with sharded JAX arrays, and no single-device OOM will happen!" + "The entire compiled `create_sharded_model()` function will directly generate a model with sharded JAX arrays, and no single-device \"OOM\" will happen!" ] }, { @@ -258,17 +264,17 @@ "source": [ "@nnx.jit\n", "def create_sharded_model():\n", - " model = DotReluDot(1024, rngs=nnx.Rngs(0)) # unsharded at this moment\n", - " state = nnx.state(model) # the model's state, a pure pytree\n", - " pspecs = nnx.get_partition_spec(state) # strip out the annotations from state\n", + " model = DotReluDot(1024, rngs=nnx.Rngs(0)) # Unsharded at this moment.\n", + " state = nnx.state(model) # The model's state, a pure pytree.\n", + " pspecs = nnx.get_partition_spec(state) # Strip out the annotations from state.\n", " sharded_state = jax.lax.with_sharding_constraint(state, pspecs)\n", - " nnx.update(model, sharded_state) # model is sharded now!\n", + " nnx.update(model, sharded_state) # The model is sharded now!\n", " return model\n", "\n", "with mesh:\n", " sharded_model = create_sharded_model()\n", "\n", - "# They are some `GSPMDSharding` now - not single device!\n", + "# They are some `GSPMDSharding` now - not a single device!\n", "print(sharded_model.dot1.kernel.value.sharding)\n", "print(sharded_model.w2.value.sharding)\n", "\n", @@ -387,24 +393,26 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### On `jax.lax.with_sharding_constraint`\n", + "### On `jax.lax.with_sharding_constraint` (semi-automatic parallelization)\n", "\n", - "The key to shard a JAX array is to call `jax.lax.with_sharding_constraint` inside a jitted function. 
Note that it will throw an error if not under a JAX device mesh context.\n", + "The key to shard a JAX array is to call [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) inside a `jax.jit`ted function. Note that it will throw an error if not under a JAX device mesh context.\n", "\n", - "You may have noticed we also used `jax.lax.with_sharding_constraint` once in the model definition too, to contraint the sharding of an intermediate value. This is just to show that you can always use it orthogonally with the Flax API, if you want to explicitly shard values that are not model variables.\n", + "> **Note:** Both [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html) and [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) in the JAX documentation cover automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html), and semi-automatic parallelization with `jax.jit` and [`jax.lax.with_sharding_constraint](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) in greater detail.\n", "\n", - "This brings a question: Why use the Flax annotation API then? Why not just add JAX sharding constraints inside the model definition? The most important reason is that you still need the explicit annotations to load a sharded model from an on-disk checkpoint. See the section below." + "You may have noticed you also used [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) once in the model definition too, to constraint the sharding of an intermediate value. This is just to show that you can always use it orthogonally with the Flax NNX API, if you want to explicitly shard values that are not model variables.\n", + "\n", + "This brings a question: Why use the Flax NNX Annotation API then? Why not just add JAX sharding constraints inside the model definition? The most important reason is that you still need the explicit annotations to load a sharded model from an on-disk checkpoint. This is described in the next section." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Load sharded model from checkpoint\n", + "## Load sharded model from a checkpoint\n", "\n", - "Now we can initialize a sharded model without OOM, but what about loading it from a checkpoint on disk? JAX checkpointing libraries like [Orbax](https://orbax.readthedocs.io/en/latest/) usually supports loading it sharded if a sharding pytree is given.\n", + "Now you can initialize a sharded model without OOM, but what about loading it from a checkpoint on disk? JAX checkpointing libraries, such as [Orbax](https://orbax.readthedocs.io/en/latest/), usually support loading it sharded if a sharding pytree is given.\n", "\n", - "You can generate such as sharding pytree with `nnx.get_named_sharding`. To avoid any real memory allocation, use `nnx.eval_shape` to generate a model of abstract JAX arrays, and only use its `.sharding` annotations to obtain the sharding tree.\n", + "You can generate such as sharding pytree with [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). 
To avoid any real memory allocation, use the [`nnx.eval_shape`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.eval_shape) transform to generate a model of abstract JAX arrays, and only use its `.sharding` annotations to obtain the sharding tree.\n", "\n", "Below is an example demonstration using Orbax's `StandardCheckpointer` API. Check out [Orbax website](https://orbax.readthedocs.io/en/latest/) to learn their latest updates and recommended APIs." ] @@ -486,13 +494,13 @@ "source": [ "import orbax.checkpoint as ocp\n", "\n", - "# Save the sharded state\n", + "# Save the sharded state.\n", "sharded_state = nnx.state(sharded_model)\n", "path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')\n", "checkpointer = ocp.StandardCheckpointer()\n", "checkpointer.save(path / 'checkpoint_name', sharded_state)\n", "\n", - "# Load a sharded state from checkpoint, without `sharded_model` or `sharded_state`\n", + "# Load a sharded state from checkpoint, without `sharded_model` or `sharded_state`.\n", "abs_model = nnx.eval_shape(lambda: DotReluDot(1024, rngs=nnx.Rngs(0)))\n", "abs_state = nnx.state(abs_model)\n", "# Orbax API expects a tree of abstract `jax.ShapeDtypeStruct`\n", @@ -513,11 +521,11 @@ "source": [ "## Compile the training loop\n", "\n", - "Now, from initialization or from checkpoint, we have a sharded model. To carry out the compiled, scaled up training, we need to shard the inputs as well. In this data parallelism example, the training data has its batch dimension sharded across `data` device axis, so you should put your data in sharding `('data', None)`. You can use `jax.device_put` for this.\n", + "Now, from initialization or from checkpoint, you have a sharded model. To carry out the compiled, scaled up training, you need to shard the inputs as well. In this data parallelism example, the training data has its batch dimension sharded across the `data` device axis, so you should put your data in sharding `('data', None)`. You can use [`jax.device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html#jax.device_put) for this.\n", "\n", - "Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without JIT compilation. See the example below - even without `jax.lax.with_sharding_constraint` on the output `y`, it was still sharded as `('data', None)`.\n", + "Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without JIT compilation. In the example below even without [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) on the output `y`, it was still sharded as `('data', None)`.\n", "\n", - "> If you are interested in why: The second matmul of `DotReluDot.__call__` has two inputs of sharding `('data', 'model')` and `('model', None)`, in which both inputs' contraction axis are `model`. So a reduce-scatter matmul happened and will natually shard the output as `('data', None)`. Check out the [JAX shard map collective guide](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#example-2-psum-scatter-the-result) and its examples if you want to learn mathematically how it happened at low level." + "> If you are interested in why: The second matmul of `DotReluDot.__call__` has two inputs of sharding `('data', 'model')` and `('model', None)`, in which both inputs' contraction axis are `model`. 
So a reduce-scatter matmul happened and will naturally shard the output as `('data', None)`. Check out the [JAX shard map collective guide](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#example-2-psum-scatter-the-result) and its examples if you want to learn mathematically how it happened at low level." ] }, { @@ -569,7 +577,7 @@ } ], "source": [ - "# In data parallelism, the first dimension (batch) will be sharded on `data` axis\n", + "# In data parallelism, the first dimension (batch) will be sharded on `data` axis.\n", "data_sharding = NamedSharding(mesh, PartitionSpec('data', None))\n", "input = jax.device_put(jnp.ones((8, 1024)), data_sharding)\n", "\n", @@ -585,7 +593,7 @@ "source": [ "Now the rest of the training loop is pretty conventional - almost the same as the example in [NNX Basics](https://flax-nnx.readthedocs.io/en/latest/nnx_basics.html#transforms), except that the inputs and labels are also explicitly sharded.\n", "\n", - "`nnx.jit` will adjust and automatically choose the best layout based on how its inputs are already sharded, so try out different shardings for your own model and inputs." + "[`nnx.jit`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) will adjust and automatically choose the best layout based on how its inputs are already sharded, so try out different shardings for your own model and inputs." ] }, { @@ -625,7 +633,7 @@ "with mesh:\n", " for i in range(5):\n", " loss = train_step(sharded_model, optimizer, input, label)\n", - " print(loss) # model (over-)fitting to the labels quickly" + " print(loss) # Model (over-)fitting to the labels quickly." ] }, { @@ -634,7 +642,7 @@ "source": [ "## Profiling\n", "\n", - "If you are running on a TPU pod or a pod slice, you can use a custom `block_all` utility function, as defined below, to measure the performance:" + "If you are running on a TPU pod or a pod slice, you can create a custom `block_all()` utility function, as defined below, to measure the performance:" ] }, { @@ -668,9 +676,9 @@ "source": [ "## Logical axis annotation\n", "\n", - "JAX's automatic SPMD encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you have the option to annotate with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`), as long as you provide a mapping from your alias to the device mesh axes.\n", + "JAX's [automatic](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) [SPMD]((https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD)) encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you have the option to annotate with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`), as long as you provide a mapping from your alias to the device mesh axes.\n", "\n", - "You can provide the mapping along with the annotation as another metadata of the corresponding `nnx.Variable`, or overwrite it at top-level. Check out the `LogicalDotReluDot` example below." + "You can provide the mapping along with the annotation as another metadata of the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), or overwrite it at top-level. Check out the `LogicalDotReluDot` example below." 
] }, { @@ -679,25 +687,25 @@ "metadata": {}, "outputs": [], "source": [ - "# The mapping from alias annotation to device mesh\n", + "# The mapping from alias annotation to the device mesh.\n", "sharding_rules = (('batch', 'data'), ('hidden', 'model'), ('embed', None))\n", "\n", "class LogicalDotReluDot(nnx.Module):\n", " def __init__(self, depth: int, rngs: nnx.Rngs):\n", " init_fn = nnx.initializers.lecun_normal()\n", "\n", - " # Initialize a sublayer `self.dot1`\n", + " # Initialize a sublayer `self.dot1`.\n", " self.dot1 = nnx.Linear(\n", " depth, depth,\n", " kernel_init=nnx.with_metadata(\n", - " # We provide the sharding rules here\n", + " # Provide the sharding rules here.\n", " init_fn, sharding=('embed', 'hidden'), sharding_rules=sharding_rules),\n", " use_bias=False,\n", " rngs=rngs)\n", "\n", - " # Initialize a weight param `w2`\n", + " # Initialize a weight param `w2`.\n", " self.w2 = nnx.Param(\n", - " # We didn't provide the sharding rules here - to show you how to overwrite it later\n", + " # Didn't provide the sharding rules here to show you how to overwrite it later.\n", " nnx.with_metadata(init_fn, sharding=('hidden', 'embed'))(\n", " rngs.params(), (depth, depth))\n", " )\n", @@ -705,7 +713,7 @@ " def __call__(self, x: jax.Array):\n", " y = self.dot1(x)\n", " y = jax.nn.relu(y)\n", - " # Unfortunately the logical aliasing doesn't work on lower-level JAX calls\n", + " # Unfortunately the logical aliasing doesn't work on lower-level JAX calls.\n", " y = jax.lax.with_sharding_constraint(y, PartitionSpec('data', None))\n", " z = jnp.dot(y, self.w2.value)\n", " return z" @@ -715,7 +723,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "If you didn't provide all `sharding_rule` annotations in model definition, you can write a few lines to add it to `nnx.State` of the model, before the call of `nnx.get_partition_spec` or `nnx.get_named_sharding`." + "If you didn't provide all `sharding_rule` annotations in model definition, you can write a few lines to add it to [`nnx.State`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) of the model, before the call of [`nnx.get_partition_spec`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) or [`nnx.get_named_sharding`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding)." ] }, { @@ -814,7 +822,7 @@ "jax.debug.visualize_array_sharding(sharded_logical_model.dot1.kernel.value)\n", "jax.debug.visualize_array_sharding(sharded_logical_model.w2.value)\n", "\n", - "# Check out their equivalency with some easier-to-read sharding descriptions\n", + "# Check out their equivalency with some easier-to-read sharding descriptions.\n", "assert sharded_logical_model.dot1.kernel.value.sharding.is_equivalent_to(\n", " NamedSharding(mesh, PartitionSpec(None, 'model')), ndim=2\n", ")\n", @@ -840,7 +848,7 @@ "\n", "* **Device mesh axis**:\n", "\n", - " * For simpler model, this can save you a few extra lines of code of converting the logical naming back to the device naming.\n", + " * For a simpler model, this can save you a few extra lines of code of converting the logical naming back to the device naming.\n", "\n", " * Shardings of intermediate *activation* values can only be done via `jax.lax.with_sharding_constraint` and device mesh axis. 
So if you want super fine-grained control over your model's sharding, directly using device mesh axis names everywhere might be less confusing.\n", "\n", diff --git a/docs_nnx/guides/flax_gspmd.md b/docs_nnx/guides/flax_gspmd.md index ca6874a5cf..5a9c1f6a32 100644 --- a/docs_nnx/guides/flax_gspmd.md +++ b/docs_nnx/guides/flax_gspmd.md @@ -10,21 +10,26 @@ jupytext: # Scale up on multiple devices -This guide shows how to scale up [Flax NNX Modules](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) on multiple devices and hosts using [JAX just-in-time compilation machinery](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html). +This guide demonstrates how to scale up [Flax NNX `Module`s](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) on multiple devices and hosts, such as GPUs, Google TPUs, and CPUs, using [JAX just-in-time compilation machinery (`jax.jit`)](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html). +++ ## Overview -Flax relies on JAX for numeric computations and for scaling the computations up across multiple devices (GPU, TPU, etc), and the core of scaling up is using [JAX's just-in-time compiler](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). Throughout this guide, we will mainly use [`nnx.jit`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) which wraps `jax.jit` and works more conveniently with NNX modules. +Flax relies on [JAX](https://jax.readthedocs.io) for numeric computations and scaling the computations up across multiple devices, such as GPU and TPUs. At the core of scaling up is the [JAX just-in-time compiler `jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). Throughout this guide, you will be using Flax’s own [`nnx.jit`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) which wraps around [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and works more conveniently with Flax NNX `Module`s. -JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. This means you write Python code as if it runs only on one device, and `jit` will automatically compile and run it on multiple devices. +JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. This means you write Python code as if it runs only on one device, and `jax.jit` will automatically compile and run it on multiple devices. -To ensure the compilation performance, you often need to tell JAX how your model's variables are supposed to be sharded across devices. This is where Flax's [Sharding Metadata API](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) comes in - to help you annotate your model variables with this information. +To ensure the compilation performance, you often need to instruct JAX how your model's variables need to be sharded across devices. This is where Flax NNX's Sharding Metadata API - [`flax.nnx.spmd`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) - comes in. It helps you annotate your model variables with this information. 
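As a minimal sketch of what such an annotation looks like (the layer size and the `'model'` axis name below are placeholders, and the full setup is developed step by step in the rest of this guide), the annotation is simply metadata attached to a variable - nothing is actually sharded until the model is created under `nnx.jit`:

```python
from flax import nnx

# `nnx.with_partitioning` wraps the initializer so that the resulting kernel
# carries a `.sharding` attribute; the array itself still lives on one device.
layer = nnx.Linear(
    8, 8,
    kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, 'model')),
    rngs=nnx.Rngs(0),
)
print(layer.kernel.sharding)  # (None, 'model')
```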
-> **NOTE to Flax Linen users**: this API is pretty much the same with what you may have learnt in [the Linen version](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html) on model definition level, but the top-level code will be simpler due to the benefits of NNX, and some text explanation will be more updated and clear. +> **Note to Flax Linen users**: The [`flax.nnx.spmd`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) API is similar to what is described in [the Linen Flax on `(p)jit` guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html) on the model definition level. However, the top-level code in Flax NNX is simpler due to the benefits brought by Flax NNX, and some text explanations will be more updated and clearer. -You can learn more about JAX APIs for scaling up in [JAX in multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html) and [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) on JAX's documentation site. +If you are new parallelization in JAX, you can learn more about its APIs for scaling up in the following tutorials: + +- [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html): A 101 level tutorial covering the basics of automatic parallelization with `jax.jit`, semi-automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html), and manual sharding with [`shard_map`](https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html#jax.experimental.shard_map.shard_map). +- [JAX in multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html). +- [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html): A more detailed tutorial about parallelization with `jax.jit` and `jax.lax.with_sharding_constraint`. Study it after the [101]([Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html). +- [Manual parallelism with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html): Another more in-depth doc that follows the [101]([Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html). +++ @@ -53,14 +58,15 @@ import optax # Optax for common losses and optimizers. ``` ```{code-cell} ipython3 -print(f'We have 8 fake JAX devices now: {jax.devices()}') +print(f'You have 8 “fake” JAX devices now: {jax.devices()}') ``` The code below shows how to import and set up the JAX-level device API, following JAX's [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) guide: -1. Start a 2x4 device `mesh` (8 devices) using JAX's `mesh_utils.create_device_mesh`. This layout is the same as the one of a [TPU v3-8](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#single_tpu_board). +1. Start a 2x4 device `mesh` (8 devices) using the JAX [`jax.sharding.Mesh`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Mesh). 
This layout is the same as on a [TPU v3-8](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#single_tpu_board) (also 8 devices). + +2. Annotate each axis with a name using the `axis_names` parameter. A typical way to annotate axis names is `axis_name=('data', 'model')`, where: -2. Annotate each axis with a name using the `axis_names` parameter in `jax.sharding.Mesh`. A typical way to annotate axis names is `axis_name=('data', 'model')`, where: * `'data'`: the mesh dimension used for data-parallel sharding of the batch dimension of inputs and activations. * `'model'`: the mesh dimension used for sharding parameters of the model across devices. @@ -73,26 +79,26 @@ print(mesh) ## Define a model with specified sharding -Create an example layer called `DotReluDot`. This layer carries out two dot product multiplication upon the input `x`, and uses the `jax.nn.relu` (ReLU) activation function in-between. +Next, create an example layer called `DotReluDot` that subclasses Flax [`nnx.Module`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module). This layer carries out two dot product multiplication upon the input `x`, and uses the `jax.nn.relu` (ReLU) activation function in-between. -To annotate a model variable with their ideal sharding, you can use [`flax.nnx.with_partitioning`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to wrap over its initializer function. Essentially, this calls [`flax.nnx.with_metadata`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.with_metadata) which adds a `.sharding` attribute field to the corresponding `nnx.Variable`. +To annotate a model variable with their ideal sharding, you can use [`flax.nnx.with_partitioning`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to wrap over its initializer function. Essentially, this calls [`flax.nnx.with_metadata`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.with_metadata) which adds a `.sharding` attribute field to the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable). -Note that this annotation will be [preserved and adjusted accordingly across lifted transformations](https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html#axes-metadata). This means if you use sharding annotations along with any transform that modifies axes (like `nnx.vmap`, `nnx.scan`), you need to provide sharding of that additional axis via the `transform_metadata` arg. Check out the [NNX transforms guide](https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html) to learn more. +> **Note:** This annotation will be [preserved and adjusted accordingly across lifted transformations in Flax NNX](https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html#axes-metadata). This means if you use sharding annotations along with any transform that modifies axes (like [`nnx.vmap`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html), [`nnx.scan`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html)), you need to provide sharding of that additional axis via the `transform_metadata` arg. Check out the [Flax NNX transformations (transforms) guide](https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html) to learn more. 
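For instance, the sketch below stacks several annotated `nnx.Linear` layers with `nnx.vmap`. It assumes `nnx.PARTITION_NAME` is the `transform_metadata` key (as used in the transforms guide linked above), and the sizes and axis names are chosen only to illustrate the mechanics:

```python
from functools import partial
import jax
from flax import nnx

@partial(nnx.vmap, in_axes=0, transform_metadata={nnx.PARTITION_NAME: 'data'})
def create_stacked_linear(key: jax.Array):
  # Each copy's kernel is annotated (None, 'model'); `nnx.vmap` stacks the copies
  # and `transform_metadata` names the new leading axis 'data' (purely illustrative).
  return nnx.Linear(
      4, 4,
      kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, 'model')),
      use_bias=False,
      rngs=nnx.Rngs(key))

keys = jax.random.split(jax.random.key(0), 4)
stacked = create_stacked_linear(keys)
print(stacked.kernel.sharding)  # Expected: ('data', None, 'model')
```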
```{code-cell} ipython3 class DotReluDot(nnx.Module): def __init__(self, depth: int, rngs: nnx.Rngs): init_fn = nnx.initializers.lecun_normal() - # Initialize a sublayer `self.dot1` and annotate its kernel with - # sharding (None, 'model') + # Initialize a sublayer `self.dot1` and annotate its kernel with. + # `sharding (None, 'model')`. self.dot1 = nnx.Linear( depth, depth, kernel_init=nnx.with_partitioning(init_fn, (None, 'model')), use_bias=False, # or use `bias_init` to give it annotation too rngs=rngs) - # Initialize a weight param `w2` and annotate with sharding ('model', None) + # Initialize a weight param `w2` and annotate with sharding ('model', None). # Note that this is simply adding `.sharding` to the variable as metadata! self.w2 = nnx.Param( init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation @@ -116,7 +122,7 @@ The so-called "sharding annotations" are essentially tuples of device axis names So, when you define `W1` with shape `(depth, depth)` and annotate as `(None, 'model')`: * The first dimension will be replicated across all devices. -* The second dimension will be sharded over the `'model'` axis of the device mesh. This means `W1` will be sharded 4-way on devices `(0, 4)`, `(1, 5)`, `(2, 6)` and `(3, 7)`, on this dimension. +* The second dimension will be sharded over the `'model'` axis of the device mesh. This means `W1` will be sharded 4-way on devices `(0, 4)`, `(1, 5)`, `(2, 6)` and `(3, 7)`, in this dimension. JAX's [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) guide offers more examples and explanations. @@ -124,23 +130,23 @@ JAX's [Distributed arrays and automatic parallelization](https://jax.readthedocs ## Initialize a sharded model -Now we have annotations attached to the `nnx.Variable`, but the actual weights haven't been sharded yet. If you just go ahead and create this model, all JAX arrays are still stuck in device 0. In practice, you'd want to avoid this, because a large model will OOM in this situation, and all the other devices are not utilized. +Now, you have annotations attached to the [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), but the actual weights haven't been sharded yet. If you just go ahead and create this model, all JAX arrays are still stuck in device `0`. In practice, you'd want to avoid this, because a large model will "OOM" (will cause the device to run out of memory) in this situation, while all the other devices are not utilized. ```{code-cell} ipython3 unsharded_model = DotReluDot(1024, rngs=nnx.Rngs(0)) -# We have annotations sticked there, yay! +# You have annotations sticked there, yay! print(unsharded_model.dot1.kernel.sharding) # (None, 'model') print(unsharded_model.w2.sharding) # ('model', None) -# But the actual arrays are not sharded... wut? +# But the actual arrays are not sharded? print(unsharded_model.dot1.kernel.value.sharding) # SingleDeviceSharding print(unsharded_model.w2.value.sharding) # SingleDeviceSharding ``` -We should leverage JAX's compilation mechanism, via `nnx.jit`, to create the sharded model. The key is to intialize a model and assign shardings upon the model state within a jitted function: +Here, you should leverage JAX's compilation mechanism, via [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit), to create the sharded model. 
The key is to initialize a model and assign shardings upon the model state within a `jit`ted function: -1. Use [`nnx.get_partition_spec`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) to strip out the `.sharding` annotations attached upon model variables; +1. Use [`nnx.get_partition_spec`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) to strip out the `.sharding` annotations attached upon model variables. 1. Call [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) to bind the model state with the sharding annotations. This API tells the top-level `jit` how to shard a variable! @@ -150,22 +156,22 @@ We should leverage JAX's compilation mechanism, via `nnx.jit`, to create the sha 1. Run it under a device mesh context so that JAX knows which devices to shard it to. -The whole compiled `create_sharded_model` will directly generate a model with sharded JAX arrays, and no single-device OOM will happen! +The entire compiled `create_sharded_model()` function will directly generate a model with sharded JAX arrays, and no single-device "OOM" will happen! ```{code-cell} ipython3 @nnx.jit def create_sharded_model(): - model = DotReluDot(1024, rngs=nnx.Rngs(0)) # unsharded at this moment - state = nnx.state(model) # the model's state, a pure pytree - pspecs = nnx.get_partition_spec(state) # strip out the annotations from state + model = DotReluDot(1024, rngs=nnx.Rngs(0)) # Unsharded at this moment. + state = nnx.state(model) # The model's state, a pure pytree. + pspecs = nnx.get_partition_spec(state) # Strip out the annotations from state. sharded_state = jax.lax.with_sharding_constraint(state, pspecs) - nnx.update(model, sharded_state) # model is sharded now! + nnx.update(model, sharded_state) # The model is sharded now! return model with mesh: sharded_model = create_sharded_model() -# They are some `GSPMDSharding` now - not single device! +# They are some `GSPMDSharding` now - not a single device! print(sharded_model.dot1.kernel.value.sharding) print(sharded_model.w2.value.sharding) @@ -187,34 +193,36 @@ print("sharded_model.w2 ('model', None) :") jax.debug.visualize_array_sharding(sharded_model.w2.value) ``` -### On `jax.lax.with_sharding_constraint` +### On `jax.lax.with_sharding_constraint` (semi-automatic parallelization) -The key to shard a JAX array is to call `jax.lax.with_sharding_constraint` inside a jitted function. Note that it will throw an error if not under a JAX device mesh context. +The key to shard a JAX array is to call [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) inside a `jax.jit`ted function. Note that it will throw an error if not under a JAX device mesh context. -You may have noticed we also used `jax.lax.with_sharding_constraint` once in the model definition too, to contraint the sharding of an intermediate value. This is just to show that you can always use it orthogonally with the Flax API, if you want to explicitly shard values that are not model variables. 
+> **Note:** Both [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html) and [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) in the JAX documentation cover automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html), and semi-automatic parallelization with `jax.jit` and [`jax.lax.with_sharding_constraint](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) in greater detail. -This brings a question: Why use the Flax annotation API then? Why not just add JAX sharding constraints inside the model definition? The most important reason is that you still need the explicit annotations to load a sharded model from an on-disk checkpoint. See the section below. +You may have noticed you also used [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) once in the model definition too, to constraint the sharding of an intermediate value. This is just to show that you can always use it orthogonally with the Flax NNX API, if you want to explicitly shard values that are not model variables. + +This brings a question: Why use the Flax NNX Annotation API then? Why not just add JAX sharding constraints inside the model definition? The most important reason is that you still need the explicit annotations to load a sharded model from an on-disk checkpoint. This is described in the next section. +++ -## Load sharded model from checkpoint +## Load sharded model from a checkpoint -Now we can initialize a sharded model without OOM, but what about loading it from a checkpoint on disk? JAX checkpointing libraries like [Orbax](https://orbax.readthedocs.io/en/latest/) usually supports loading it sharded if a sharding pytree is given. +Now you can initialize a sharded model without OOM, but what about loading it from a checkpoint on disk? JAX checkpointing libraries, such as [Orbax](https://orbax.readthedocs.io/en/latest/), usually support loading it sharded if a sharding pytree is given. -You can generate such as sharding pytree with `nnx.get_named_sharding`. To avoid any real memory allocation, use `nnx.eval_shape` to generate a model of abstract JAX arrays, and only use its `.sharding` annotations to obtain the sharding tree. +You can generate such as sharding pytree with [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). To avoid any real memory allocation, use the [`nnx.eval_shape`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.eval_shape) transform to generate a model of abstract JAX arrays, and only use its `.sharding` annotations to obtain the sharding tree. Below is an example demonstration using Orbax's `StandardCheckpointer` API. Check out [Orbax website](https://orbax.readthedocs.io/en/latest/) to learn their latest updates and recommended APIs. ```{code-cell} ipython3 import orbax.checkpoint as ocp -# Save the sharded state +# Save the sharded state. 
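# `nnx.state` extracts the model's variables as a pytree whose arrays are already sharded across the mesh.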
sharded_state = nnx.state(sharded_model) path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/') checkpointer = ocp.StandardCheckpointer() checkpointer.save(path / 'checkpoint_name', sharded_state) -# Load a sharded state from checkpoint, without `sharded_model` or `sharded_state` +# Load a sharded state from checkpoint, without `sharded_model` or `sharded_state`. abs_model = nnx.eval_shape(lambda: DotReluDot(1024, rngs=nnx.Rngs(0))) abs_state = nnx.state(abs_model) # Orbax API expects a tree of abstract `jax.ShapeDtypeStruct` @@ -231,14 +239,14 @@ jax.debug.visualize_array_sharding(loaded_sharded.w2.value) ## Compile the training loop -Now, from initialization or from checkpoint, we have a sharded model. To carry out the compiled, scaled up training, we need to shard the inputs as well. In this data parallelism example, the training data has its batch dimension sharded across `data` device axis, so you should put your data in sharding `('data', None)`. You can use `jax.device_put` for this. +Now, from initialization or from checkpoint, you have a sharded model. To carry out the compiled, scaled up training, you need to shard the inputs as well. In this data parallelism example, the training data has its batch dimension sharded across the `data` device axis, so you should put your data in sharding `('data', None)`. You can use [`jax.device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html#jax.device_put) for this. -Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without JIT compilation. See the example below - even without `jax.lax.with_sharding_constraint` on the output `y`, it was still sharded as `('data', None)`. +Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without JIT compilation. In the example below even without [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) on the output `y`, it was still sharded as `('data', None)`. -> If you are interested in why: The second matmul of `DotReluDot.__call__` has two inputs of sharding `('data', 'model')` and `('model', None)`, in which both inputs' contraction axis are `model`. So a reduce-scatter matmul happened and will natually shard the output as `('data', None)`. Check out the [JAX shard map collective guide](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#example-2-psum-scatter-the-result) and its examples if you want to learn mathematically how it happened at low level. +> If you are interested in why: The second matmul of `DotReluDot.__call__` has two inputs of sharding `('data', 'model')` and `('model', None)`, in which both inputs' contraction axis are `model`. So a reduce-scatter matmul happened and will naturally shard the output as `('data', None)`. Check out the [JAX shard map collective guide](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#example-2-psum-scatter-the-result) and its examples if you want to learn mathematically how it happened at low level. ```{code-cell} ipython3 -# In data parallelism, the first dimension (batch) will be sharded on `data` axis +# In data parallelism, the first dimension (batch) will be sharded on `data` axis. 
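# `PartitionSpec('data', None)` shards dimension 0 over the `data` mesh axis and replicates dimension 1.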
data_sharding = NamedSharding(mesh, PartitionSpec('data', None)) input = jax.device_put(jnp.ones((8, 1024)), data_sharding) @@ -250,7 +258,7 @@ jax.debug.visualize_array_sharding(output) # Also sharded as ('data', None) Now the rest of the training loop is pretty conventional - almost the same as the example in [NNX Basics](https://flax-nnx.readthedocs.io/en/latest/nnx_basics.html#transforms), except that the inputs and labels are also explicitly sharded. -`nnx.jit` will adjust and automatically choose the best layout based on how its inputs are already sharded, so try out different shardings for your own model and inputs. +[`nnx.jit`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) will adjust and automatically choose the best layout based on how its inputs are already sharded, so try out different shardings for your own model and inputs. ```{code-cell} ipython3 optimizer = nnx.Optimizer(sharded_model, optax.adam(1e-3)) # reference sharing @@ -272,12 +280,12 @@ label = jax.device_put(jax.random.normal(jax.random.key(2), (8, 1024)), data_sha with mesh: for i in range(5): loss = train_step(sharded_model, optimizer, input, label) - print(loss) # model (over-)fitting to the labels quickly + print(loss) # Model (over-)fitting to the labels quickly. ``` ## Profiling -If you are running on a TPU pod or a pod slice, you can use a custom `block_all` utility function, as defined below, to measure the performance: +If you are running on a TPU pod or a pod slice, you can create a custom `block_all()` utility function, as defined below, to measure the performance: ```{code-cell} ipython3 %%timeit @@ -292,30 +300,30 @@ with mesh: ## Logical axis annotation -JAX's automatic SPMD encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you have the option to annotate with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`), as long as you provide a mapping from your alias to the device mesh axes. +JAX's [automatic](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) [SPMD]((https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD)) encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you have the option to annotate with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`), as long as you provide a mapping from your alias to the device mesh axes. -You can provide the mapping along with the annotation as another metadata of the corresponding `nnx.Variable`, or overwrite it at top-level. Check out the `LogicalDotReluDot` example below. +You can provide the mapping along with the annotation as another metadata of the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), or overwrite it at top-level. Check out the `LogicalDotReluDot` example below. ```{code-cell} ipython3 -# The mapping from alias annotation to device mesh +# The mapping from alias annotation to the device mesh. sharding_rules = (('batch', 'data'), ('hidden', 'model'), ('embed', None)) class LogicalDotReluDot(nnx.Module): def __init__(self, depth: int, rngs: nnx.Rngs): init_fn = nnx.initializers.lecun_normal() - # Initialize a sublayer `self.dot1` + # Initialize a sublayer `self.dot1`. 
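    # Its kernel is annotated with logical names ('embed', 'hidden'), which `sharding_rules` maps to mesh axes.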
self.dot1 = nnx.Linear( depth, depth, kernel_init=nnx.with_metadata( - # We provide the sharding rules here + # Provide the sharding rules here. init_fn, sharding=('embed', 'hidden'), sharding_rules=sharding_rules), use_bias=False, rngs=rngs) - # Initialize a weight param `w2` + # Initialize a weight param `w2`. self.w2 = nnx.Param( - # We didn't provide the sharding rules here - to show you how to overwrite it later + # Didn't provide the sharding rules here to show you how to overwrite it later. nnx.with_metadata(init_fn, sharding=('hidden', 'embed'))( rngs.params(), (depth, depth)) ) @@ -323,13 +331,13 @@ class LogicalDotReluDot(nnx.Module): def __call__(self, x: jax.Array): y = self.dot1(x) y = jax.nn.relu(y) - # Unfortunately the logical aliasing doesn't work on lower-level JAX calls + # Unfortunately the logical aliasing doesn't work on lower-level JAX calls. y = jax.lax.with_sharding_constraint(y, PartitionSpec('data', None)) z = jnp.dot(y, self.w2.value) return z ``` -If you didn't provide all `sharding_rule` annotations in model definition, you can write a few lines to add it to `nnx.State` of the model, before the call of `nnx.get_partition_spec` or `nnx.get_named_sharding`. +If you didn't provide all `sharding_rule` annotations in model definition, you can write a few lines to add it to [`nnx.State`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) of the model, before the call of [`nnx.get_partition_spec`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) or [`nnx.get_named_sharding`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). ```{code-cell} ipython3 def add_sharding_rule(vs: nnx.VariableState) -> nnx.VariableState: @@ -353,7 +361,7 @@ with mesh: jax.debug.visualize_array_sharding(sharded_logical_model.dot1.kernel.value) jax.debug.visualize_array_sharding(sharded_logical_model.w2.value) -# Check out their equivalency with some easier-to-read sharding descriptions +# Check out their equivalency with some easier-to-read sharding descriptions. assert sharded_logical_model.dot1.kernel.value.sharding.is_equivalent_to( NamedSharding(mesh, PartitionSpec(None, 'model')), ndim=2 ) @@ -374,7 +382,7 @@ Choosing when to use a device or logical axis depends on how much you want to co * **Device mesh axis**: - * For simpler model, this can save you a few extra lines of code of converting the logical naming back to the device naming. + * For a simpler model, this can save you a few extra lines of code of converting the logical naming back to the device naming. * Shardings of intermediate *activation* values can only be done via `jax.lax.with_sharding_constraint` and device mesh axis. So if you want super fine-grained control over your model's sharding, directly using device mesh axis names everywhere might be less confusing.