Stateful now supports vmap #466

Merged
merged 3 commits into from
Sep 12, 2023
106 changes: 101 additions & 5 deletions docs/api/nn/stateful.md
@@ -1,17 +1,113 @@
# Stateful operations

These are the tools that underlie stateful operations, like [`equinox.nn.BatchNorm`][] or [`equinox.nn.SpectralNorm`][].
These are the tools that underlie stateful operations, like [`equinox.nn.BatchNorm`][] or [`equinox.nn.SpectralNorm`][]. These are fairly unusual layers, so most users will not need this part of the API.

See the [stateful example](../../examples/stateful.ipynb) for an example of working with stateful operations.
!!! Example

::: equinox.nn.State
    The [stateful example](../../examples/stateful.ipynb) is a good reference for the typical workflow for stateful layers.

---

::: equinox.nn.make_with_state

## Extra features

Let's explain how this works under the hood. First of all, all stateful layers (`BatchNorm` etc.) include an "index". This is basically just a unique hashable value (used later as a dictionary key), and an initial value for the state:

::: equinox.nn.StateIndex
    selection:
        members:
            - __init__

---

::: equinox.nn.StateIndex
This `State` object that's being passed around is essentially just a dictionary, mapping from `StateIndex`s to PyTrees-of-arrays. Correspondingly this has `.get` and `.set` methods to read and write values to it.

::: equinox.nn.State
    selection:
        members:
            - __init__
            - get
            - set
            - substate
            - update

## Custom stateful layers

Let's use [`equinox.nn.StateIndex`][] to create a custom stateful layer.

```python
import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array

class Counter(eqx.Module):
    index: eqx.nn.StateIndex

    def __init__(self):
        init_state = jnp.array(0)
        self.index = eqx.nn.StateIndex(init_state)

    def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]:
        value = state.get(self.index)
        new_x = x + value
        new_state = state.set(self.index, value + 1)
        return new_x, new_state

counter, state = eqx.nn.make_with_state(Counter)()
x = jnp.array(2.3)

num_calls = state.get(counter.index)
print(f"Called {num_calls} times.") # 0

_, state = counter(x, state)
num_calls = state.get(counter.index)
print(f"Called {num_calls} times.") # 1

_, state = counter(x, state)
num_calls = state.get(counter.index)
print(f"Called {num_calls} times.") # 2
```

## Vmap'd stateful layers

This is an advanced thing to do! Here we'll build on [the ensembling guide](../../../tricks/#ensembling), and see how we can create vmap'd stateful layers.

This follows on from the previous example, in which we define `Counter`.
```python
import jax.random as jr

class Model(eqx.Module):
    linear: eqx.nn.Linear
    counter: Counter
    v_counter: Counter

    def __init__(self, key):
        # Not-stateful layer
        self.linear = eqx.nn.Linear(2, 2, key=key)
        # Stateful layer.
        self.counter = Counter()
        # Vmap'd stateful layer. (Whose initial state will include a batch dimension.)
        self.v_counter = eqx.filter_vmap(Counter, axis_size=2)()

    def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]:
        # This bit happens as normal.
        assert x.shape == (2,)
        x = self.linear(x)
        x, state = self.counter(x, state)

        # For the vmap, we have to restrict our state to just those states we want to
        # vmap, and then update the overall state again afterwards.
        #
        # After all, the state for `self.counter` isn't expecting to be batched, so we
        # have to remove that.
        substate = state.substate(self.v_counter)
        x, substate = eqx.filter_vmap(self.v_counter)(x, substate)
        state = state.update(substate)

        return x, state

key = jr.PRNGKey(0)
model, state = eqx.nn.make_with_state(Model)(key)
x = jnp.array([5.0, -1.0])
model(x, state)
```
7 changes: 6 additions & 1 deletion equinox/nn/__init__.py
@@ -40,4 +40,9 @@
    StatefulLayer as StatefulLayer,
)
from ._spectral_norm import SpectralNorm as SpectralNorm
from ._stateful import State as State, StateIndex as StateIndex
from ._stateful import (
    delete_init_state as delete_init_state,
    make_with_state as make_with_state,
    State as State,
    StateIndex as StateIndex,
)
6 changes: 3 additions & 3 deletions equinox/nn/_batch_norm.py
@@ -92,12 +92,12 @@ def __init__(
        else:
            self.weight = None
            self.bias = None
        self.first_time_index = StateIndex(lambda **_: jnp.array(True))
        make_buffers = lambda **_: (
        self.first_time_index = StateIndex(jnp.array(True))
        init_buffers = (
            jnp.empty((input_size,), dtype=dtype),
            jnp.empty((input_size,), dtype=dtype),
        )
        self.state_index = StateIndex(make_buffers)
        self.state_index = StateIndex(init_buffers)
        self.inference = inference
        self.axis_name = axis_name
        self.input_size = input_size
2 changes: 1 addition & 1 deletion equinox/nn/_spectral_norm.py
@@ -103,7 +103,7 @@ def __init__(
        v0 = jr.normal(vkey, (v_len,))
        for _ in range(15):
            u0, v0 = _power_iteration(weight, u0, v0, eps)
        self.uv_index = StateIndex(lambda **_: (u0, v0))
        self.uv_index = StateIndex((u0, v0))

    @jax.named_scope("eqx.nn.SpectralNorm")
    def __call__(