Improve nn.scan docs #2839

Merged 1 commit on Mar 2, 2023
114 changes: 76 additions & 38 deletions flax/linen/transforms.py
@@ -738,47 +738,85 @@ def scan(target: Target,

Example::

import flax
import flax.linen as nn
from jax import random

class SimpleScan(nn.Module):
  @nn.compact
  def __call__(self, c, xs):
    LSTM = nn.scan(nn.LSTMCell,
                   variable_broadcast="params",
                   split_rngs={"params": False},
                   in_axes=1,
                   out_axes=1)
    return LSTM()(c, xs)

seq_len, batch_size, in_feat, out_feat = 20, 16, 3, 5
key_1, key_2, key_3 = random.split(random.PRNGKey(0), 3)

xs = random.uniform(key_1, (batch_size, seq_len, in_feat))
init_carry = nn.LSTMCell.initialize_carry(key_2, (batch_size,), out_feat)

model = SimpleScan()
variables = model.init(key_3, init_carry, xs)
out_carry, out_val = model.apply(variables, init_carry, xs)

assert out_val.shape == (batch_size, seq_len, out_feat)

>>> import flax.linen as nn
>>> import jax
>>> import jax.numpy as jnp
...
>>> class LSTM(nn.Module):
...   features: int
...
...   @nn.compact
...   def __call__(self, x):
...     batch_size = x.shape[0]
...     ScanLSTMCell = nn.scan(
...         nn.LSTMCell, variable_broadcast="params",
...         split_rngs={"params": False}, in_axes=1, out_axes=1)
...
...     carry = nn.LSTMCell.initialize_carry(
...         jax.random.PRNGKey(0), (batch_size,), self.features)
...     carry, x = ScanLSTMCell()(carry, x)
...     return x
...
>>> x = jnp.ones((4, 12, 7))
>>> module = LSTM(features=32)
>>> y, variables = module.init_with_output(jax.random.PRNGKey(0), x)
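
Because the cell is scanned along axis 1 (``in_axes=1``/``out_axes=1``), the
per-step outputs are stacked back on the time axis, so ``y`` has shape
``(batch, seq_len, features)``. Continuing the example above::

>>> y.shape
(4, 12, 32)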

Note that when providing a function to ``nn.scan``, the scanning happens over
all arguments starting from the third argument, as specified by ``in_axes``.
So in the following example, the inputs being scanned over are ``xs``,
``*args``, and ``**kwargs``::

def body_fn(cls, carry, xs, *args, **kwargs):
  # ``xs``, ``*args`` and ``**kwargs`` are all scanned; ``carry`` is not.
  extended_states = cls.some_fn(xs, carry, *args, **kwargs)
  return extended_states

scan_fn = nn.scan(
    body_fn,
    in_axes=0,  # scan over axis 0 from the third arg of body_fn onwards.
    variable_axes=SCAN_VARIABLE_AXES,  # placeholder for your variable_axes mapping
    split_rngs=SCAN_SPLIT_RNGS)        # placeholder for your split_rngs mapping
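
For a concrete instance of this pattern, here is a minimal sketch that scans a
body function over axis 0 of both ``xs`` and an extra per-step ``scale``
argument. The ``WeightedSum`` module and the ``scale`` input are illustrative,
not part of the Flax API::

>>> class WeightedSum(nn.Module):
...   @nn.compact
...   def __call__(self, carry, xs, scale):
...     dense = nn.Dense(features=xs.shape[-1])
...     def body_fn(dense, carry, x, s):
...       # ``x`` and ``s`` are single slices along axis 0 of ``xs`` and ``scale``.
...       y = dense(x) * s
...       return carry + y, y
...     scan_fn = nn.scan(
...         body_fn,
...         in_axes=0,  # scan over axis 0 from the third arg of body_fn onwards.
...         variable_broadcast='params',
...         split_rngs={'params': False})
...     return scan_fn(dense, carry, xs, scale)
...
>>> carry = jnp.zeros((3,))
>>> xs = jnp.ones((5, 3))       # five steps, three features each
>>> scale = jnp.arange(5.0)     # one extra scanned input per step
>>> model = WeightedSum()
>>> variables = model.init(jax.random.PRNGKey(0), carry, xs, scale)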
The first LSTM example above can also be written in this functional form::

>>> class LSTM(nn.Module):
...   features: int
...
...   @nn.compact
...   def __call__(self, x):
...     batch_size = x.shape[0]
...
...     cell = nn.LSTMCell()
...     def body_fn(cell, carry, x):
...       carry, y = cell(carry, x)
...       return carry, y
...     scan = nn.scan(
...         body_fn, variable_broadcast="params",
...         split_rngs={"params": False}, in_axes=1, out_axes=1)
...
...     carry = nn.LSTMCell.initialize_carry(
...         jax.random.PRNGKey(0), (batch_size,), self.features)
...     carry, x = scan(cell, carry, x)
...     return x
...
>>> module = LSTM(features=32)
>>> variables = module.init(jax.random.PRNGKey(0), jnp.ones((4, 12, 7)))
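
As with the first form, the initialized variables can then be passed to
``Module.apply``; the output again has the per-step results stacked on the
time axis. Continuing the example above::

>>> y = module.apply(variables, jnp.ones((4, 12, 7)))
>>> y.shape
(4, 12, 32)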

You can also use ``scan`` to reduce the compilation time of your JAX program
by merging multiple layers into a single scan loop. This is useful when you
have a sequence of identically structured layers that you want to apply
iteratively to an input. For example::

>>> class ResidualMLPBlock(nn.Module):
...   @nn.compact
...   def __call__(self, x, _):
...     h = nn.Dense(features=2)(x)
...     h = nn.relu(h)
...     return x + h, None
...
>>> class ResidualMLP(nn.Module):
...   n_layers: int = 4
...
...   @nn.compact
...   def __call__(self, x):
...     ScanMLP = nn.scan(
...         ResidualMLPBlock, variable_axes={'params': 0},
...         variable_broadcast=False, split_rngs={'params': True},
...         length=self.n_layers)
...     x, _ = ScanMLP()(x, None)
...     return x
...
>>> model = ResidualMLP(n_layers=4)
>>> variables = model.init(jax.random.PRNGKey(42), jnp.ones((1, 2)))
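
Since ``variable_axes={'params': 0}`` stacks each layer's parameters along a
new leading axis, every parameter leaf gains a leading dimension of length
``n_layers``; here, for instance, the ``Dense`` kernel has shape ``(4, 2, 2)``
instead of ``(2, 2)``. One way to verify this (a sketch)::

>>> shapes = jax.tree_util.tree_map(jnp.shape, variables)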

To reduce both compilation time and memory usage, you can use
:func:`remat_scan`, which additionally checkpoints each layer in the scan
loop, as in the sketch below.
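
A minimal, hypothetical sketch of this (note that :func:`remat_scan` carries
the input itself through the loop, so the block's signature is ``x -> x``
rather than the ``(carry, x)`` pair used by ``scan``, and ``lengths=(2, 2)``
applies 2 * 2 = 4 layers as two nested scans)::

>>> class ResidualBlock(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(features=2)(x)
...     return x + nn.relu(h)
...
>>> class RematResidualMLP(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     RematMLP = nn.remat_scan(ResidualBlock, lengths=(2, 2))
...     return RematMLP()(x)
...
>>> model = RematResidualMLP()
>>> variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 2)))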

Args:
  target: a ``Module`` or a function taking a ``Module``