Merge pull request #4157 from google:nnx-improve-scan
PiperOrigin-RevId: 670690186
Flax Authors committed Sep 3, 2024
2 parents 98dff5e + aeeda79 commit 35c4edf
Showing 5 changed files with 608 additions and 264 deletions.
2 changes: 1 addition & 1 deletion docs/nnx/haiku_linen_vs_nnx.rst
@@ -1209,7 +1209,7 @@ in ``__init__`` to scan over the sequence.
     scan_fn = lambda carry, cell, x: cell(carry, x)
     carry = self.cell.initial_state(x.shape[0])
     carry, y = nnx.scan(
-      scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=1
+      scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
     )(carry, self.cell, x)
 
     return y
10 changes: 5 additions & 5 deletions docs/nnx/nnx_basics.ipynb
@@ -375,7 +375,7 @@
 "1. The `create_model` function takes in a key and returns an `MLP` object, since we create 5 keys\n",
 "   and use `nnx.vmap` over `create_model` a stack of 5 `MLP` objects is created.\n",
 "2. We use `nnx.scan` to iteratively apply each `MLP` in the stack to the input `x`.\n",
-"3. The `nnx.scan` API (consciously) deviates from `jax.lax.scan` and instead mimicks `vmap` which is\n",
+"3. The `nnx.scan` API (consciously) deviates from `jax.lax.scan` and instead mimics `vmap` which is\n",
 "   more expressive. `nnx.scan` allows specifying multiple inputs, the scan axes of each input/output,\n",
 "   and the position of the carry.\n",
 "4. State updates for the `BatchNorm` and `Dropout` layers are automatically propagated\n",
@@ -415,13 +415,13 @@
 "keys = jax.random.split(jax.random.key(0), 5)\n",
 "model = create_model(keys)\n",
 "\n",
-"@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=0)\n",
-"def forward(x, model: MLP):\n",
+"@nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)\n",
+"def forward(model: MLP, x):\n",
 "  x = model(x)\n",
-"  return x, None\n",
+"  return x\n",
 "\n",
 "x = jnp.ones((3, 10))\n",
-"y, _ = forward(x, model)\n",
+"y = forward(model, x)\n",
 "\n",
 "print(f'{y.shape = }')\n",
 "nnx.display(model)"
10 changes: 5 additions & 5 deletions docs/nnx/nnx_basics.md
@@ -229,7 +229,7 @@ Notice the following:
 1. The `create_model` function takes in a key and returns an `MLP` object, since we create 5 keys
    and use `nnx.vmap` over `create_model` a stack of 5 `MLP` objects is created.
 2. We use `nnx.scan` to iteratively apply each `MLP` in the stack to the input `x`.
-3. The `nnx.scan` API (consciously) deviates from `jax.lax.scan` and instead mimicks `vmap` which is
+3. The `nnx.scan` API (consciously) deviates from `jax.lax.scan` and instead mimics `vmap` which is
    more expressive. `nnx.scan` allows specifying multiple inputs, the scan axes of each input/output,
    and the position of the carry.
 4. State updates for the `BatchNorm` and `Dropout` layers are automatically propagated
@@ -243,13 +243,13 @@ def create_model(key: jax.Array):
 keys = jax.random.split(jax.random.key(0), 5)
 model = create_model(keys)
 
-@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=0)
-def forward(x, model: MLP):
+@nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
+def forward(model: MLP, x):
   x = model(x)
-  return x, None
+  return x
 
 x = jnp.ones((3, 10))
-y, _ = forward(x, model)
+y = forward(model, x)
 
 print(f'{y.shape = }')
 nnx.display(model)
