Skip to content

Commit

Permalink
improve nn.scan docs
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Feb 3, 2023
1 parent 06323f2 commit bf34f90
Showing 1 changed file with 76 additions and 38 deletions.
114 changes: 76 additions & 38 deletions flax/linen/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,47 +738,85 @@ def scan(target: Target,
Example::
import flax
import flax.linen as nn
from jax import random
class SimpleScan(nn.Module):
@nn.compact
def __call__(self, c, xs):
LSTM = nn.scan(nn.LSTMCell,
variable_broadcast="params",
split_rngs={"params": False},
in_axes=1,
out_axes=1)
return LSTM()(c, xs)
seq_len, batch_size, in_feat, out_feat = 20, 16, 3, 5
key_1, key_2, key_3 = random.split(random.PRNGKey(0), 3)
xs = random.uniform(key_1, (batch_size, seq_len, in_feat))
init_carry = nn.LSTMCell.initialize_carry(key_2, (batch_size,), out_feat)
model = SimpleScan()
variables = model.init(key_3, init_carry, xs)
out_carry, out_val = model.apply(variables, init_carry, xs)
assert out_val.shape == (batch_size, seq_len, out_feat)
>>> import flax.linen as nn
>>> import jax
>>> import jax.numpy as jnp
...
>>> class LSTM(nn.Module):
... features: int
...
... @nn.compact
... def __call__(self, x):
... batch_size = x.shape[0]
... ScanLSTMCell = nn.scan(
... nn.LSTMCell, variable_broadcast="params",
... split_rngs={"params": False}, in_axes=1, out_axes=1)
...
... carry = nn.LSTMCell.initialize_carry(
... jax.random.PRNGKey(0), (batch_size,), self.features)
... carry, x = ScanLSTMCell()(carry, x)
... return x
...
>>> x = jnp.ones((4, 12, 7))
>>> module = LSTM(features=32)
>>> y, variables = module.init_with_output(jax.random.PRNGKey(0), x)
Note that when providing a function to ``nn.scan``, the scanning happens over
all arguments starting from the third argument, as specified by ``in_axes``.
So in the following example, the input that are being scanned over are ``xs``,
``*args*``, and ``**kwargs``::
def body_fn(cls, carry, xs, *args, **kwargs):
extended_states = cls.some_fn(xs, carry, *args, **kwargs)
return extended_states
scan_fn = nn.scan(
body_fn,
in_axes=0, # scan over axis 0 from third arg of body_fn onwards.
variable_axes=SCAN_VARIABLE_AXES,
split_rngs=SCAN_SPLIT_RNGS)
The previous example could also be written using the functional form as::
>>> class LSTM(nn.Module):
... features: int
...
... @nn.compact
... def __call__(self, x):
... batch_size = x.shape[0]
...
... cell = nn.LSTMCell()
... def body_fn(cell, carry, x):
... carry, y = cell(carry, x)
... return carry, y
... scan = nn.scan(
... body_fn, variable_broadcast="params",
... split_rngs={"params": False}, in_axes=1, out_axes=1)
...
... carry = nn.LSTMCell.initialize_carry(
... jax.random.PRNGKey(0), (batch_size,), self.features)
... carry, x = scan(cell, carry, x)
... return x
...
>>> module = LSTM(features=32)
>>> variables = module.init(jax.random.PRNGKey(0), jnp.ones((4, 12, 7)))
You can also use ``scan`` to reduce the compilation time of your JAX program
by merging multiple layers into a single scan loop, you can do this when
you have a sequence of identical layers that you want to apply iteratively
to an input. For example::
>>> class ResidualMLPBlock(nn.Module):
... @nn.compact
... def __call__(self, x, _):
... h = nn.Dense(features=2)(x)
... h = nn.relu(h)
... return x + h, None
...
>>> class ResidualMLP(nn.Module):
... n_layers: int = 4
...
... @nn.compact
... def __call__(self, x):
... ScanMLP = nn.scan(
... ResidualMLPBlock, variable_axes={'params': 0},
... variable_broadcast=False, split_rngs={'params': True},
... length=self.n_layers)
... x, _ = ScanMLP()(x, None)
... return x
...
>>> model = ResidualMLP(n_layers=4)
>>> variables = model.init(jax.random.PRNGKey(42), jnp.ones((1, 2)))
To reduce both compilation and memory usage, you can use :func:`remat_scan`
which will in addition checkpoint each layer in the scan loop.
Args:
target: a ``Module`` or a function taking a ``Module``
Expand Down

0 comments on commit bf34f90

Please sign in to comment.