From fe61e7edc5afe65fb96cb95bdbe7e640b088b0e7 Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Tue, 1 Oct 2024 22:24:38 +0000 Subject: [PATCH] Upgrade Flax NNX Basics doc --- docs_nnx/nnx_basics.ipynb | 249 +++++++++++++++++-------------------- docs_nnx/nnx_basics.md | 253 +++++++++++++++++--------------------- 2 files changed, 228 insertions(+), 274 deletions(-) diff --git a/docs_nnx/nnx_basics.ipynb b/docs_nnx/nnx_basics.ipynb index 249c937d81..693c74a919 100644 --- a/docs_nnx/nnx_basics.ipynb +++ b/docs_nnx/nnx_basics.ipynb @@ -4,14 +4,26 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Flax Basics\n", + "# Flax basics\n", "\n", - "Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug,\n", - "and analyze neural networks in JAX. It achieves this by adding first class support\n", - "for Python reference semantics, allowing users to express their models using regular\n", - "Python objects, which are modeled as PyGraphs (instead of PyTrees), enabling reference\n", - "sharing and mutability. This design should should make PyTorch or Keras users feel at\n", - "home." + "Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in JAX. It achieves this by adding first class support for Python reference semantics. This allows users to express their models using regular Python objects, which are modeled as PyGraphs (instead of pytrees), enabling reference sharing and mutability. Such API design should make PyTorch or Keras users feel at home.\n", + "\n", + "In this guide you will learn about:\n", + "\n", + "- The Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) system: An example of creating and initializing a custom `Linear` layer.\n", + " - Stateful computation: An example of creating a Flax [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and updating its value (such as state updates needed during the forward pass).\n", + " - Nested [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s: An MLP example with `Linear`, [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout), and [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm layers.\n", + " - Model surgery: An example of replacing custom `Linear` layers inside a model with custom `LoraLinear` layers.\n", + "- Flax transformations: An example of using [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) for automatic state management.\n", + " - [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) over layers.\n", + "- The Flax NNX Functional API: An example of a custom `StatefulLinear` layer with [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s with fine-grained control over the state.\n", + " - [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and [`GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef).\n", + " - [`split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge), and `update`\n", + " - Fine-grained [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) control: An example of using [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) type `Filter`s to split into multiple [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s.\n", + "\n", + "## Setup\n", + "\n", + "Install Flax with `pip` and impost necessary dependencies:" ] }, { @@ -42,18 +54,19 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## The Module System\n", - "To begin lets see how to create a `Linear` Module using Flax. The main difference between\n", - "Flax NNX and Module systems like Haiku or Flax Linen is that everything is **explicit**. This\n", - "means among other things that 1) the Module itself holds the state (e.g. parameters) directly,\n", - "2) the RNG state is threaded by the user, and 3) all shape information must be provided on\n", - "initialization (no shape inference).\n", + "## The Flax `nnx.Module` system\n", + "\n", + "The main difference between the Flax[`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) and other `Module` systems in [Flax Linen](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html) or [Haiku](https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html#Built-in-Haiku-nets-and-nested-modules) is that in NNX everything is **explicit**. This means, among other things, that:\n", "\n", - "As shown next, dynamic state is usually stored in `nnx.Param`s, and static state\n", - "(all types not handled by Flax) such as integers or strings are stored directly.\n", - "Attributes of type `jax.Array` and `numpy.ndarray` are also treated as dynamic\n", - "state, although storing them inside `nnx.Variable`s such as `Param` is preferred.\n", - "Also, `nnx.Rngs` can be used to get new unique keys starting from a root key." + "1) The [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) itself holds the state (such as parameters) directly.\n", + "2) The [PRNG](https://jax.readthedocs.io/en/latest/random-numbers.html) state is threaded by the user.\n", + "3) All shape information must be provided on initialization (no shape inference).\n", + "\n", + "Let's begin by creating a `Linear` [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The following code shows that:\n", + "\n", + "- Dynamic state is usually stored in [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s, and static state (all types not handled by NNX), such as integers or strings are stored directly.\n", + "- Attributes of type [`jax.Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html) and `numpy.ndarray` are also treated as dynamic states, although storing them inside [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)s, such as `Param`, is preferred.\n", + "- The [`nnx.Rngs`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/rnglib.html#flax.nnx.Rngs) object can be used to get new unique keys based on a root PRNG key passed to the constructor." ] }, { @@ -77,14 +90,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "`nnx.Variable`'s inner values can be accessed using the `.value` property, however\n", - "for convenience they implement all numeric operators and can be used directly in\n", - "arithmetic expressions (as shown above).\n", + "Also note that:\n", + "\n", + "- The inner values of [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) can be accessed using the `value` property, but for convenience they implement all numeric operators and can be used directly in arithmetic expressions (as shown in the code above).\n", "\n", - "To actually initialize a Module you simply call the constructor, all the parameters\n", - "of a Module are usually created eagerly. Since Modules hold their own state methods\n", - "can be called directly without the need for a separate `apply` method, this is very\n", - "convenient for debugging as entire structure of the model can be inspected directly." + "To initialize a Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html), you just call the constructor, and all the parameters of a `Module` are usually created eagerly. Since [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s hold their own state methods, you can call them directly without the need for a separate `apply` method.\n", + "This can be very convenient for debugging, allowing you to directly inspect the entire structure of the model." ] }, { @@ -124,18 +135,17 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The above visualization by `nnx.display` is generated using the awesome [Treescope](https://treescope.readthedocs.io/en/stable/index.html#) library." + "The above visualization by `nnx.display` is generated using the awesome\n", + "[Treescope](https://treescope.readthedocs.io/en/stable/index.html#) library." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Stateful Computation\n", + "### Stateful computation\n", "\n", - "Implementing layers such as `BatchNorm` requires performing state updates during the\n", - "forward pass. To implement this in Flax you just create a `Variable` and update its\n", - "`.value` during the forward pass." + "Implementing layers, such as [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm), requires performing state updates during a forward pass. In Flax NNX, you just need to create a [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and update its `.value` during the forward pass." ] }, { @@ -172,20 +182,19 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Mutable references are usually avoided in JAX, however as we'll see in later sections\n", - "Flax provides sound mechanisms to handle them." + "Mutable references are usually avoided in JAX. But Flax NNX provides sound mechanisms\n", + "to handle them, as demonstrated in later sections of this guide." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Nested Modules\n", + "### Nested `nnx.Module`s\n", "\n", - "As expected, Modules can be used to compose other Modules in a nested structure, these can\n", - "be assigned directly as attributes, or inside an attribute of any (nested) pytree type e.g.\n", - " `list`, `dict`, `tuple`, etc. In the example below we define a simple `MLP` Module that\n", - "consists of two `Linear` layers, a `Dropout` layer, and a `BatchNorm` layer." + "Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s can be used to compose other `Module`s in a nested structure. These can be assigned directly as attributes, or inside an attribute of any (nested) pytree type, such as a `list`, `dict`, `tuple`, and so on.\n", + "\n", + "The example below shows how to define a simple `MLP` by subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The model consists of two `Linear` layers, an [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer:" ] }, { @@ -229,22 +238,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In Flax `Dropout` is a stateful module that stores an `Rngs` object so that it can generate\n", - "new masks during the forward pass without the need for the user to pass a new key each time." + "In Flax, [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) is a stateful module that stores an [`nnx.Rngs`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/rnglib.html#flax.nnx.Rngs) object, so that it can generate new masks during the forward pass without the need for the user to pass a new key each time." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "#### Model Surgery\n", - "Flax NNX Modules are mutable by default, this means their structure can be changed at any time,\n", - "this makes model surgery quite easy as any submodule attribute can be replaced with anything\n", - "else e.g. new Modules, existing shared Modules, Modules of different types, etc. More over,\n", - "`Variable`s can also be modified or replaced / shared.\n", + "### Model surgery\n", + "\n", + "Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s are mutable by default. This means that their structure can be changed at any time, which makes model surgery quite easy as any sub-`Module` attribute can be replaced with anything else, such as new `Module`s, existing shared `Module`s, `Module`s of different types, and so on. Moreover, [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)s can also be modified or replaced/shared.\n", "\n", - "The following example shows how to replace the `Linear` layers in the `MLP` model\n", - "from before with `LoraLinear` layers." + "The following example shows how to replace the `Linear` layers in the `MLP` model from the previous example with `LoraLinear` layers:" ] }, { @@ -280,7 +285,7 @@ "rngs = nnx.Rngs(0)\n", "model = MLP(2, 32, 5, rngs=rngs)\n", "\n", - "# model surgery\n", + "# Model surgery.\n", "model.linear1 = LoraLinear(model.linear1, 4, rngs=rngs)\n", "model.linear2 = LoraLinear(model.linear2, 4, rngs=rngs)\n", "\n", @@ -293,22 +298,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Transforms\n", + "## Flax transformations\n", "\n", - "Flax Transforms extend JAX transforms to support Modules and other objects.\n", - "They are supersets of their equivalent JAX counterpart with the addition of\n", - "being aware of the object's state and providing additional APIs to transform\n", - "it. One of the main features of Flax Transforms is the preservation of reference semantics,\n", - "meaning that any mutation of the object graph that occurs inside the transform is\n", - "propagated outisde as long as its legal within the transform rules. In practice this\n", - "means that Flax programs can be express using imperative code, highly simplifying\n", - "the user experience.\n", + "[Flax NNX transformations (transforms)](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html) extend [JAX transforms](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations) to support [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s and other objects. They serve as supersets of their equivalent JAX counterparts with the addition of being aware of the object's state and providing additional APIs to transform it.\n", "\n", - "In the following example we define a `train_step` function that takes a `MLP` model,\n", - "an `Optimizer`, and a batch of data, and returns the loss for that step. The loss\n", - "and the gradients are computed using the `nnx.value_and_grad` transform over the\n", - "`loss_fn`. The gradients are passed to the optimizer's `update` method to update\n", - "the `model`'s parameters." + "One of the main features of Flax Transforms is the preservation of reference semantics, meaning that any mutation of the object graph that occurs inside the transform is propagated outside as long as it is legal within the transform rules. In practice this means that Flax programs can be express using imperative code, highly simplifying the user experience.\n", + "\n", + "In the following example, you define a `train_step` function that takes a `MLP` model, an [`nnx.Optimizer`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/training/optimizer.html#module-flax.nnx.optimizer), and a batch of data, and returns the loss for that step. The loss and the gradients are computed using the [`nnx.value_and_grad`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.value_and_grad) transform over the `loss_fn`. The gradients are passed to the optimizer's [`nnx.Optimizer.update`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/training/optimizer.html#flax.nnx.optimizer.Optimizer.update) method to update the `model`'s parameters." ] }, { @@ -328,18 +324,18 @@ "source": [ "import optax\n", "\n", - "# MLP contains 2 Linear layers, 1 Dropout layer, 1 BatchNorm layer\n", + "# An MLP containing 2 custom `Linear` layers, 1 `nnx.Dropout` layer, 1 `nnx.BatchNorm` layer.\n", "model = MLP(2, 16, 10, rngs=nnx.Rngs(0))\n", "optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing\n", "\n", - "@nnx.jit # automatic state management\n", + "@nnx.jit # Automatic state management\n", "def train_step(model, optimizer, x, y):\n", " def loss_fn(model: MLP):\n", " y_pred = model(x)\n", " return jnp.mean((y_pred - y) ** 2)\n", "\n", " loss, grads = nnx.value_and_grad(loss_fn)(model)\n", - " optimizer.update(grads) # inplace updates\n", + " optimizer.update(grads) # In place updates.\n", "\n", " return loss\n", "\n", @@ -354,29 +350,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Theres a couple of things happening in this example that are worth mentioning:\n", - "1. The updates to the `BatchNorm` and `Dropout` layer's state is automatically propagated\n", - " from within `loss_fn` to `train_step` all the way to the `model` reference outside.\n", - "2. `optimizer` holds a mutable reference to `model`, this relationship is preserved\n", - " inside the `train_step` function making it possible to update the model's parameters\n", - " using the optimizer alone.\n", + "There are two things happening in this example that are worth mentioning:\n", + "\n", + "1. The updates to each of the [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) and [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer's state is automatically propagated from within `loss_fn` to `train_step` all the way to the `model` reference outside.\n", + "2. The `optimizer` holds a mutable reference to the `model` - this relationship is preserved inside the `train_step` function making it possible to update the model's parameters using the optimizer alone.\n", + "\n", + "### `nnx.scan` over layers\n", "\n", - "#### Scan over layers\n", - "Next lets take a look at a different example, which uses\n", - "[nnx.vmap](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap)\n", - "to create a stack of multiple MLP layers and\n", - "[nnx.scan](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan)\n", - "to iteratively apply each layer of the stack to the input.\n", + "The next example uses Flax [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap) to create a stack of multiple MLP layers and [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) to iteratively apply each layer of the stack to the input.\n", "\n", - "Notice the following:\n", - "1. The `create_model` function takes in a key and returns an `MLP` object, since we create 5 keys\n", - " and use `nnx.vmap` over `create_model` a stack of 5 `MLP` objects is created.\n", - "2. We use `nnx.scan` to iteratively apply each `MLP` in the stack to the input `x`.\n", - "3. The `nnx.scan` API (consciously) deviates from `jax.lax.scan` and instead mimics `vmap` which is\n", - " more expressive. `nnx.scan` allows specifying multiple inputs, the scan axes of each input/output,\n", - " and the position of the carry.\n", - "4. State updates for the `BatchNorm` and `Dropout` layers are automatically propagated\n", - " by `nnx.scan`." + "In the code below notice the following:\n", + "\n", + "1. The custom `create_model` function takes in a key and returns an `MLP` object, since you create five keys and use [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap) over `create_model` a stack of 5 `MLP` objects is created.\n", + "2. The [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.Scan) is used to iteratively apply each `MLP` in the stack to the input `x`.\n", + "3. The [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.Scan) (consciously) deviates from [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan) and instead mimics [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap), which is more expressive. [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.Scan) allows specifying multiple inputs, the scan axes of each input/output, and the position of the carry.\n", + "4. [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) updates for the [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) and [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layers are automatically propagated by [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.Scan)." ] }, { @@ -428,25 +416,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "How do Flax transforms achieve this? To understand how Flax objects interact with\n", - "JAX transforms lets take a look at the Functional API." + "How do Flax NNX transforms achieve this? To understand how Flax NNX objects interact with JAX transforms, the next section explains the Flax NNX Functional API." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## The Functional API\n", + "## The Flax Functional API\n", + "\n", + "The Flax NNX Functional API establishes a clear boundary between reference/object semantics and value/pytree semantics. It also allows the same amount of fine-grained control over the state that Flax Linen and Haiku users are used to. The Flax NNX Functional API consists of three basic methods: [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge), and [`nnx.update`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.update).\n", "\n", - "The Functional API establishes a clear boundary between reference/object semantics and\n", - "value/pytree semantics. It also allows same amount of fine-grained control over the\n", - "state that Linen/Haiku users are used to. The Functional API consists of 3 basic methods:\n", - "`split`, `merge`, and `update`.\n", + "Below is an example of of `StatefulLinear` [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) that uses the Functional API. It contains:\n", "\n", - "The `StatefulLinear` Module shown below will serve as an example for the use of the\n", - "Functional API. It contains some `nnx.Param` Variables and a custom `Count` Variable\n", - "type which is used to keep track of integer scalar state that increases on every\n", - "forward pass." + "- Some [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)sl and\n", + "- A custom `Count()` [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) type, which is used to track the integer scalar state that increases on every forward pass." ] }, { @@ -490,12 +474,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### State and GraphDef\n", + "### `State` and `GraphDef`\n", "\n", - "A Module can be decomposed into `GraphDef` and `State` using the\n", - "`split` function. State is a Mapping from strings to Variables or nested\n", - "States. GraphDef contains all the static information needed to reconstruct\n", - "a Module graph, it is analogous to JAX's `PyTreeDef`." + "A Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) can be decomposed into [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and [`nnx.GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) using the [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) function:\n", + "\n", + "- [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) is a `Mapping` from strings to `Variable`s or nested [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s.\n", + "- [`nnx.GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) contains all the static information needed to reconstruct a [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) graph, it is analogous to [JAX's `PyTreeDef`](https://jax.readthedocs.io/en/latest/pytrees.html#internal-pytree-handling)." ] }, { @@ -538,13 +522,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Split, Merge, and Update\n", + "### `split`, `merge`, and `update`\n", + "\n", + "Flax's [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) is the reverse of [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split). It takes the [`nnx.GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) + [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and reconstructs the [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The example below demonstrates this as follows:\n", "\n", - "`merge` is the reverse of `split`, it takes the GraphDef + State and reconstructs\n", - "the Module. As shown in the example below, by using `split` and `merge` in sequence\n", - "any Module can be lifted to be used in any JAX transform. `update` can\n", - "update an object inplace with the content of a given State. This pattern is used to\n", - "propagate the state from a transform back to the source object outside." + "- By using [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) and [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) in sequence any `Module` can be lifted to be used in any JAX transform.\n", + "- [`nnx.update`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.update) can update an object in place with the content of a given [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State).\n", + "- This pattern is used to propagate the state from a transform back to the source object outside." ] }, { @@ -564,21 +548,21 @@ "source": [ "print(f'{model.count.value = }')\n", "\n", - "# 1. Use split to create a pytree representation of the Module\n", + "# 1. Use `nnx.split` to create a pytree representation of the `nnx.Module`.\n", "graphdef, state = nnx.split(model)\n", "\n", "@jax.jit\n", "def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array) -> tuple[jax.Array, nnx.State]:\n", - " # 2. Use merge to create a new model inside the JAX transformation\n", + " # 2. Use `nnx.merge` to create a new model inside the JAX transformation.\n", " model = nnx.merge(graphdef, state)\n", - " # 3. Call the Module\n", + " # 3. Call the `nnx.Module`\n", " y = model(x)\n", - " # 4. Use split to propagate State updates\n", + " # 4. Use `nnx.split` to propagate `nnx.State` updates.\n", " _, state = nnx.split(model)\n", " return y, state\n", "\n", "y, state = forward(graphdef, state, x=jnp.ones((1, 3)))\n", - "# 5. Update the state of the original Module\n", + "# 5. Update the state of the original `nnx.Module`.\n", "nnx.update(model, state)\n", "\n", "print(f'{model.count.value = }')" @@ -588,34 +572,27 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The key insight of this pattern is that using mutable references is\n", - "fine within a transform context (including the base eager interpreter)\n", - "but its necessary to use the Functional API when crossing boundaries.\n", + "The key insight of this pattern is that using mutable references is fine within a transform context (including the base eager interpreter) but it is necessary to use the Functional API when crossing boundaries.\n", "\n", - "**Why aren't Module's just Pytrees?** The main reason is that it is very\n", - "easy to lose track of shared references by accident this way, for example\n", - "if you pass two Module that have a shared Module through a JAX boundary\n", - "you will silently lose that sharing. The Functional API makes this\n", - "behavior explicit, and thus it is much easier to reason about." + "**Why aren't Flax `nnx.Module`s just pytrees?** The main reason is that it is very easy to lose track of shared references by accident this way, for example if you pass two [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s that have a shared `Module` through a JAX boundary, you will silently lose that sharing. Flax's Functional API makes this behavior explicit, and thus it is much easier to reason about." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Fine-grained State Control\n", + "### Fine-grained `State` control\n", + "\n", + "Experienced [Flax Linen](https://flax-linen.readthedocs.io/) or [Haiku](https://dm-haiku.readthedocs.io/) API users may recognize that having all the states in a single structure is not always the best choice as there are cases in which you may want to handle different subsets of the state differently. This a common occurrence when interacting with [JAX transforms](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations).\n", + "\n", + "For example:\n", + "\n", + "- Not every model state can or should be differentiated when interacting with [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad).\n", + "- Or, sometimes, there is a need to specify what part of the model's state is a carry and what part is not when using [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan).\n", "\n", - "Seasoned Linen and Haiku users might recognize that having all the state in\n", - "a single structure is not always the best choice as there are cases in which\n", - "you might want to handle different subsets of the state differently. This a\n", - "common occurrence when interacting with JAX transforms, for example, not all\n", - "the model's state can or should be differentiated when interacting which `grad`,\n", - "or sometimes there is a need to specify what part of the model's state is a\n", - "carry and what part is not when using `scan`.\n", + "To address this, the Flax NNX API has [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), which allows you to pass one or more [`nnx.filterlib.Filter`s](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) to partition the [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)s into mutually exclusive [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s. Flax NNX uses `Filter` create `State` groups in APIs (such as [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`nnx.state()`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.state), and many of NNX transforms).\n", "\n", - "To solve this, `split` allows you to pass one or more `Filter`s to partition\n", - "the Variables into mutually exclusive States. The most common Filter being\n", - "types as shown below." + "The example below shows the most common `Filter`s:" ] }, { @@ -649,7 +626,7 @@ } ], "source": [ - "# use Variable type filters to split into multiple States\n", + "# Use `nnx.Variable` type `Filter`s to split into multiple `nnx.State`s.\n", "graphdef, params, counts = nnx.split(model, nnx.Param, Count)\n", "\n", "nnx.display(params, counts)" @@ -659,9 +636,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Note that filters must be exhaustive, if a value is not matched an error will be raised.\n", + "**Note:** [`nnx.filterlib.Filter`s](https://flax.readthedocs.io/en/latest/guides/filters_guide.html)s must be exhaustive, if a value is not matched an error will be raised.\n", "\n", - "As expected the `merge` and `update` methods naturally consume multiple States:" + "As expected, the [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) and [`nnx.update`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.update) methods naturally consume multiple [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s:" ] }, { @@ -670,9 +647,9 @@ "metadata": {}, "outputs": [], "source": [ - "# merge multiple States\n", + "# Merge multiple `State`s\n", "model = nnx.merge(graphdef, params, counts)\n", - "# update with multiple States\n", + "# Update with multiple `State`s\n", "nnx.update(model, params, counts)" ] } diff --git a/docs_nnx/nnx_basics.md b/docs_nnx/nnx_basics.md index 0b2bf564fb..c0d37cd98e 100644 --- a/docs_nnx/nnx_basics.md +++ b/docs_nnx/nnx_basics.md @@ -8,14 +8,26 @@ jupytext: jupytext_version: 1.13.8 --- -# Flax Basics +# Flax basics -Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, -and analyze neural networks in JAX. It achieves this by adding first class support -for Python reference semantics, allowing users to express their models using regular -Python objects, which are modeled as PyGraphs (instead of PyTrees), enabling reference -sharing and mutability. This design should should make PyTorch or Keras users feel at -home. +Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in JAX. It achieves this by adding first class support for Python reference semantics. This allows users to express their models using regular Python objects, which are modeled as PyGraphs (instead of pytrees), enabling reference sharing and mutability. Such API design should make PyTorch or Keras users feel at home. + +In this guide you will learn about: + +- The Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) system: An example of creating and initializing a custom `Linear` layer. + - Stateful computation: An example of creating a Flax [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and updating its value (such as state updates needed during the forward pass). + - Nested [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s: An MLP example with `Linear`, [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout), and [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm layers. + - Model surgery: An example of replacing custom `Linear` layers inside a model with custom `LoraLinear` layers. +- Flax transformations: An example of using [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) for automatic state management. + - [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) over layers. +- The Flax NNX Functional API: An example of a custom `StatefulLinear` layer with [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s with fine-grained control over the state. + - [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and [`GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef). + - [`split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge), and `update` + - Fine-grained [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) control: An example of using [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) type `Filter`s to split into multiple [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s. + +## Setup + +Install Flax with `pip` and impost necessary dependencies: ```{code-cell} ipython3 :tags: [skip-execution] @@ -29,18 +41,19 @@ import jax import jax.numpy as jnp ``` -## The Module System -To begin lets see how to create a `Linear` Module using Flax. The main difference between -Flax NNX and Module systems like Haiku or Flax Linen is that everything is **explicit**. This -means among other things that 1) the Module itself holds the state (e.g. parameters) directly, -2) the RNG state is threaded by the user, and 3) all shape information must be provided on -initialization (no shape inference). +## The Flax `nnx.Module` system + +The main difference between the Flax[`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) and other `Module` systems in [Flax Linen](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html) or [Haiku](https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html#Built-in-Haiku-nets-and-nested-modules) is that in NNX everything is **explicit**. This means, among other things, that: + +1) The [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) itself holds the state (such as parameters) directly. +2) The [PRNG](https://jax.readthedocs.io/en/latest/random-numbers.html) state is threaded by the user. +3) All shape information must be provided on initialization (no shape inference). -As shown next, dynamic state is usually stored in `nnx.Param`s, and static state -(all types not handled by Flax) such as integers or strings are stored directly. -Attributes of type `jax.Array` and `numpy.ndarray` are also treated as dynamic -state, although storing them inside `nnx.Variable`s such as `Param` is preferred. -Also, `nnx.Rngs` can be used to get new unique keys starting from a root key. +Let's begin by creating a `Linear` [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The following code shows that: + +- Dynamic state is usually stored in [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s, and static state (all types not handled by NNX), such as integers or strings are stored directly. +- Attributes of type [`jax.Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html) and `numpy.ndarray` are also treated as dynamic states, although storing them inside [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)s, such as `Param`, is preferred. +- The [`nnx.Rngs`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/rnglib.html#flax.nnx.Rngs) object can be used to get new unique keys based on a root PRNG key passed to the constructor. ```{code-cell} ipython3 class Linear(nnx.Module): @@ -54,14 +67,12 @@ class Linear(nnx.Module): return x @ self.w + self.b ``` -`nnx.Variable`'s inner values can be accessed using the `.value` property, however -for convenience they implement all numeric operators and can be used directly in -arithmetic expressions (as shown above). +Also note that: + +- The inner values of [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) can be accessed using the `value` property, but for convenience they implement all numeric operators and can be used directly in arithmetic expressions (as shown in the code above). -To actually initialize a Module you simply call the constructor, all the parameters -of a Module are usually created eagerly. Since Modules hold their own state methods -can be called directly without the need for a separate `apply` method, this is very -convenient for debugging as entire structure of the model can be inspected directly. +To initialize a Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html), you just call the constructor, and all the parameters of a `Module` are usually created eagerly. Since [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s hold their own state methods, you can call them directly without the need for a separate `apply` method. +This can be very convenient for debugging, allowing you to directly inspect the entire structure of the model. ```{code-cell} ipython3 model = Linear(2, 5, rngs=nnx.Rngs(params=0)) @@ -71,15 +82,14 @@ print(y) nnx.display(model) ``` -The above visualization by `nnx.display` is generated using the awesome [Treescope](https://treescope.readthedocs.io/en/stable/index.html#) library. +The above visualization by `nnx.display` is generated using the awesome +[Treescope](https://treescope.readthedocs.io/en/stable/index.html#) library. +++ -### Stateful Computation +### Stateful computation -Implementing layers such as `BatchNorm` requires performing state updates during the -forward pass. To implement this in Flax you just create a `Variable` and update its -`.value` during the forward pass. +Implementing layers, such as [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm), requires performing state updates during a forward pass. In Flax NNX, you just need to create a [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and update its `.value` during the forward pass. ```{code-cell} ipython3 class Count(nnx.Variable): pass @@ -97,17 +107,16 @@ counter() print(f'{counter.count.value = }') ``` -Mutable references are usually avoided in JAX, however as we'll see in later sections -Flax provides sound mechanisms to handle them. +Mutable references are usually avoided in JAX. But Flax NNX provides sound mechanisms +to handle them, as demonstrated in later sections of this guide. +++ -### Nested Modules +### Nested `nnx.Module`s + +Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s can be used to compose other `Module`s in a nested structure. These can be assigned directly as attributes, or inside an attribute of any (nested) pytree type, such as a `list`, `dict`, `tuple`, and so on. -As expected, Modules can be used to compose other Modules in a nested structure, these can -be assigned directly as attributes, or inside an attribute of any (nested) pytree type e.g. - `list`, `dict`, `tuple`, etc. In the example below we define a simple `MLP` Module that -consists of two `Linear` layers, a `Dropout` layer, and a `BatchNorm` layer. +The example below shows how to define a simple `MLP` by subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The model consists of two `Linear` layers, an [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer: ```{code-cell} ipython3 class MLP(nnx.Module): @@ -128,19 +137,15 @@ y = model(x=jnp.ones((3, 2))) nnx.display(model) ``` -In Flax `Dropout` is a stateful module that stores an `Rngs` object so that it can generate -new masks during the forward pass without the need for the user to pass a new key each time. +In Flax, [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) is a stateful module that stores an [`nnx.Rngs`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/rnglib.html#flax.nnx.Rngs) object, so that it can generate new masks during the forward pass without the need for the user to pass a new key each time. +++ -#### Model Surgery -Flax NNX Modules are mutable by default, this means their structure can be changed at any time, -this makes model surgery quite easy as any submodule attribute can be replaced with anything -else e.g. new Modules, existing shared Modules, Modules of different types, etc. More over, -`Variable`s can also be modified or replaced / shared. +### Model surgery -The following example shows how to replace the `Linear` layers in the `MLP` model -from before with `LoraLinear` layers. +Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s are mutable by default. This means that their structure can be changed at any time, which makes model surgery quite easy as any sub-`Module` attribute can be replaced with anything else, such as new `Module`s, existing shared `Module`s, `Module`s of different types, and so on. Moreover, [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)s can also be modified or replaced/shared. + +The following example shows how to replace the `Linear` layers in the `MLP` model from the previous example with `LoraLinear` layers: ```{code-cell} ipython3 class LoraParam(nnx.Param): pass @@ -157,7 +162,7 @@ class LoraLinear(nnx.Module): rngs = nnx.Rngs(0) model = MLP(2, 32, 5, rngs=rngs) -# model surgery +# Model surgery. model.linear1 = LoraLinear(model.linear1, 4, rngs=rngs) model.linear2 = LoraLinear(model.linear2, 4, rngs=rngs) @@ -166,38 +171,29 @@ y = model(x=jnp.ones((3, 2))) nnx.display(model) ``` -## Transforms +## Flax transformations + +[Flax NNX transformations (transforms)](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html) extend [JAX transforms](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations) to support [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s and other objects. They serve as supersets of their equivalent JAX counterparts with the addition of being aware of the object's state and providing additional APIs to transform it. -Flax Transforms extend JAX transforms to support Modules and other objects. -They are supersets of their equivalent JAX counterpart with the addition of -being aware of the object's state and providing additional APIs to transform -it. One of the main features of Flax Transforms is the preservation of reference semantics, -meaning that any mutation of the object graph that occurs inside the transform is -propagated outisde as long as its legal within the transform rules. In practice this -means that Flax programs can be express using imperative code, highly simplifying -the user experience. +One of the main features of Flax Transforms is the preservation of reference semantics, meaning that any mutation of the object graph that occurs inside the transform is propagated outside as long as it is legal within the transform rules. In practice this means that Flax programs can be express using imperative code, highly simplifying the user experience. -In the following example we define a `train_step` function that takes a `MLP` model, -an `Optimizer`, and a batch of data, and returns the loss for that step. The loss -and the gradients are computed using the `nnx.value_and_grad` transform over the -`loss_fn`. The gradients are passed to the optimizer's `update` method to update -the `model`'s parameters. +In the following example, you define a `train_step` function that takes a `MLP` model, an [`nnx.Optimizer`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/training/optimizer.html#module-flax.nnx.optimizer), and a batch of data, and returns the loss for that step. The loss and the gradients are computed using the [`nnx.value_and_grad`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.value_and_grad) transform over the `loss_fn`. The gradients are passed to the optimizer's [`nnx.Optimizer.update`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/training/optimizer.html#flax.nnx.optimizer.Optimizer.update) method to update the `model`'s parameters. ```{code-cell} ipython3 import optax -# MLP contains 2 Linear layers, 1 Dropout layer, 1 BatchNorm layer +# An MLP containing 2 custom `Linear` layers, 1 `nnx.Dropout` layer, 1 `nnx.BatchNorm` layer. model = MLP(2, 16, 10, rngs=nnx.Rngs(0)) optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing -@nnx.jit # automatic state management +@nnx.jit # Automatic state management def train_step(model, optimizer, x, y): def loss_fn(model: MLP): y_pred = model(x) return jnp.mean((y_pred - y) ** 2) loss, grads = nnx.value_and_grad(loss_fn)(model) - optimizer.update(grads) # inplace updates + optimizer.update(grads) # In place updates. return loss @@ -208,29 +204,21 @@ print(f'{loss = }') print(f'{optimizer.step.value = }') ``` -Theres a couple of things happening in this example that are worth mentioning: -1. The updates to the `BatchNorm` and `Dropout` layer's state is automatically propagated - from within `loss_fn` to `train_step` all the way to the `model` reference outside. -2. `optimizer` holds a mutable reference to `model`, this relationship is preserved - inside the `train_step` function making it possible to update the model's parameters - using the optimizer alone. - -#### Scan over layers -Next lets take a look at a different example, which uses -[nnx.vmap](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap) -to create a stack of multiple MLP layers and -[nnx.scan](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) -to iteratively apply each layer of the stack to the input. - -Notice the following: -1. The `create_model` function takes in a key and returns an `MLP` object, since we create 5 keys - and use `nnx.vmap` over `create_model` a stack of 5 `MLP` objects is created. -2. We use `nnx.scan` to iteratively apply each `MLP` in the stack to the input `x`. -3. The `nnx.scan` API (consciously) deviates from `jax.lax.scan` and instead mimics `vmap` which is - more expressive. `nnx.scan` allows specifying multiple inputs, the scan axes of each input/output, - and the position of the carry. -4. State updates for the `BatchNorm` and `Dropout` layers are automatically propagated - by `nnx.scan`. +There are two things happening in this example that are worth mentioning: + +1. The updates to each of the [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) and [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer's state is automatically propagated from within `loss_fn` to `train_step` all the way to the `model` reference outside. +2. The `optimizer` holds a mutable reference to the `model` - this relationship is preserved inside the `train_step` function making it possible to update the model's parameters using the optimizer alone. + +### `nnx.scan` over layers + +The next example uses Flax [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap) to create a stack of multiple MLP layers and [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) to iteratively apply each layer of the stack to the input. + +In the code below notice the following: + +1. The custom `create_model` function takes in a key and returns an `MLP` object, since you create five keys and use [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap) over `create_model` a stack of 5 `MLP` objects is created. +2. The [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.Scan) is used to iteratively apply each `MLP` in the stack to the input `x`. +3. The [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.Scan) (consciously) deviates from [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan) and instead mimics [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap), which is more expressive. [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.Scan) allows specifying multiple inputs, the scan axes of each input/output, and the position of the carry. +4. [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) updates for the [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) and [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layers are automatically propagated by [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.Scan). ```{code-cell} ipython3 @nnx.vmap(in_axes=0, out_axes=0) @@ -252,22 +240,18 @@ print(f'{y.shape = }') nnx.display(model) ``` -How do Flax transforms achieve this? To understand how Flax objects interact with -JAX transforms lets take a look at the Functional API. +How do Flax NNX transforms achieve this? To understand how Flax NNX objects interact with JAX transforms, the next section explains the Flax NNX Functional API. +++ -## The Functional API +## The Flax Functional API + +The Flax NNX Functional API establishes a clear boundary between reference/object semantics and value/pytree semantics. It also allows the same amount of fine-grained control over the state that Flax Linen and Haiku users are used to. The Flax NNX Functional API consists of three basic methods: [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge), and [`nnx.update`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.update). -The Functional API establishes a clear boundary between reference/object semantics and -value/pytree semantics. It also allows same amount of fine-grained control over the -state that Linen/Haiku users are used to. The Functional API consists of 3 basic methods: -`split`, `merge`, and `update`. +Below is an example of of `StatefulLinear` [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) that uses the Functional API. It contains: -The `StatefulLinear` Module shown below will serve as an example for the use of the -Functional API. It contains some `nnx.Param` Variables and a custom `Count` Variable -type which is used to keep track of integer scalar state that increases on every -forward pass. +- Some [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)sl and +- A custom `Count()` [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) type, which is used to track the integer scalar state that increases on every forward pass. ```{code-cell} ipython3 class Count(nnx.Variable): pass @@ -288,12 +272,12 @@ y = model(jnp.ones((1, 3))) nnx.display(model) ``` -### State and GraphDef +### `State` and `GraphDef` -A Module can be decomposed into `GraphDef` and `State` using the -`split` function. State is a Mapping from strings to Variables or nested -States. GraphDef contains all the static information needed to reconstruct -a Module graph, it is analogous to JAX's `PyTreeDef`. +A Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) can be decomposed into [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and [`nnx.GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) using the [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) function: + +- [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) is a `Mapping` from strings to `Variable`s or nested [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s. +- [`nnx.GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) contains all the static information needed to reconstruct a [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) graph, it is analogous to [JAX's `PyTreeDef`](https://jax.readthedocs.io/en/latest/pytrees.html#internal-pytree-handling). ```{code-cell} ipython3 graphdef, state = nnx.split(model) @@ -301,77 +285,70 @@ graphdef, state = nnx.split(model) nnx.display(graphdef, state) ``` -### Split, Merge, and Update +### `split`, `merge`, and `update` + +Flax's [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) is the reverse of [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split). It takes the [`nnx.GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) + [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and reconstructs the [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The example below demonstrates this as follows: -`merge` is the reverse of `split`, it takes the GraphDef + State and reconstructs -the Module. As shown in the example below, by using `split` and `merge` in sequence -any Module can be lifted to be used in any JAX transform. `update` can -update an object inplace with the content of a given State. This pattern is used to -propagate the state from a transform back to the source object outside. +- By using [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) and [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) in sequence any `Module` can be lifted to be used in any JAX transform. +- [`nnx.update`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.update) can update an object in place with the content of a given [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State). +- This pattern is used to propagate the state from a transform back to the source object outside. ```{code-cell} ipython3 print(f'{model.count.value = }') -# 1. Use split to create a pytree representation of the Module +# 1. Use `nnx.split` to create a pytree representation of the `nnx.Module`. graphdef, state = nnx.split(model) @jax.jit def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array) -> tuple[jax.Array, nnx.State]: - # 2. Use merge to create a new model inside the JAX transformation + # 2. Use `nnx.merge` to create a new model inside the JAX transformation. model = nnx.merge(graphdef, state) - # 3. Call the Module + # 3. Call the `nnx.Module` y = model(x) - # 4. Use split to propagate State updates + # 4. Use `nnx.split` to propagate `nnx.State` updates. _, state = nnx.split(model) return y, state y, state = forward(graphdef, state, x=jnp.ones((1, 3))) -# 5. Update the state of the original Module +# 5. Update the state of the original `nnx.Module`. nnx.update(model, state) print(f'{model.count.value = }') ``` -The key insight of this pattern is that using mutable references is -fine within a transform context (including the base eager interpreter) -but its necessary to use the Functional API when crossing boundaries. +The key insight of this pattern is that using mutable references is fine within a transform context (including the base eager interpreter) but it is necessary to use the Functional API when crossing boundaries. -**Why aren't Module's just Pytrees?** The main reason is that it is very -easy to lose track of shared references by accident this way, for example -if you pass two Module that have a shared Module through a JAX boundary -you will silently lose that sharing. The Functional API makes this -behavior explicit, and thus it is much easier to reason about. +**Why aren't Flax `nnx.Module`s just pytrees?** The main reason is that it is very easy to lose track of shared references by accident this way, for example if you pass two [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s that have a shared `Module` through a JAX boundary, you will silently lose that sharing. Flax's Functional API makes this behavior explicit, and thus it is much easier to reason about. +++ -### Fine-grained State Control +### Fine-grained `State` control + +Experienced [Flax Linen](https://flax-linen.readthedocs.io/) or [Haiku](https://dm-haiku.readthedocs.io/) API users may recognize that having all the states in a single structure is not always the best choice as there are cases in which you may want to handle different subsets of the state differently. This a common occurrence when interacting with [JAX transforms](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations). + +For example: + +- Not every model state can or should be differentiated when interacting with [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad). +- Or, sometimes, there is a need to specify what part of the model's state is a carry and what part is not when using [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan). -Seasoned Linen and Haiku users might recognize that having all the state in -a single structure is not always the best choice as there are cases in which -you might want to handle different subsets of the state differently. This a -common occurrence when interacting with JAX transforms, for example, not all -the model's state can or should be differentiated when interacting which `grad`, -or sometimes there is a need to specify what part of the model's state is a -carry and what part is not when using `scan`. +To address this, the Flax NNX API has [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), which allows you to pass one or more [`nnx.filterlib.Filter`s](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) to partition the [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)s into mutually exclusive [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s. Flax NNX uses `Filter` create `State` groups in APIs (such as [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`nnx.state()`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.state), and many of NNX transforms). -To solve this, `split` allows you to pass one or more `Filter`s to partition -the Variables into mutually exclusive States. The most common Filter being -types as shown below. +The example below shows the most common `Filter`s: ```{code-cell} ipython3 -# use Variable type filters to split into multiple States +# Use `nnx.Variable` type `Filter`s to split into multiple `nnx.State`s. graphdef, params, counts = nnx.split(model, nnx.Param, Count) nnx.display(params, counts) ``` -Note that filters must be exhaustive, if a value is not matched an error will be raised. +**Note:** [`nnx.filterlib.Filter`s](https://flax.readthedocs.io/en/latest/guides/filters_guide.html)s must be exhaustive, if a value is not matched an error will be raised. -As expected the `merge` and `update` methods naturally consume multiple States: +As expected, the [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) and [`nnx.update`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.update) methods naturally consume multiple [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s: ```{code-cell} ipython3 -# merge multiple States +# Merge multiple `State`s model = nnx.merge(graphdef, params, counts) -# update with multiple States +# Update with multiple `State`s nnx.update(model, params, counts) ```