From 2d4c553e46723105f762361a7be7215ea820cb25 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Thu, 31 Aug 2023 09:07:55 -0700 Subject: [PATCH 1/3] Updated stateful operations to be able to support vmap'ing. This is actually a minor breaking change: the API has changed from StateIndex(callable_returning_value) to just StateIndex(value). This change also introduces substantially updated documentation on stateful layers, which should help a lot. --- docs/api/nn/stateful.md | 106 ++++++++++++++- equinox/nn/__init__.py | 7 +- equinox/nn/_batch_norm.py | 6 +- equinox/nn/_spectral_norm.py | 2 +- equinox/nn/_stateful.py | 256 +++++++++++++++++++++++++++++++---- examples/stateful.ipynb | 68 +++++++--- tests/test_stateful.py | 18 +++ 7 files changed, 404 insertions(+), 59 deletions(-) create mode 100644 tests/test_stateful.py diff --git a/docs/api/nn/stateful.md b/docs/api/nn/stateful.md index 51ad63bb..49c1d340 100644 --- a/docs/api/nn/stateful.md +++ b/docs/api/nn/stateful.md @@ -1,17 +1,113 @@ # Stateful operations -These are the tools that underly stateful operations, like [`equinox.nn.BatchNorm`][] or [`equinox.nn.SpectralNorm`][]. +These are the tools that underly stateful operations, like [`equinox.nn.BatchNorm`][] or [`equinox.nn.SpectralNorm`][]. These are fairly unusual layers, so most users will not need this part of the API. -See the [stateful example](../../examples/stateful.ipynb) for an example of working with stateful operations. +!!! Example -::: equinox.nn.State + The [stateful example](../../examples/stateful.ipynb) is a good reference for the typical workflow for stateful layers. + +--- + +::: equinox.nn.make_with_state + +## Extra features + +Let's explain how this works under the hood. First of all, all stateful layers (`BatchNorm` etc.) include an "index". 
This is basically just a unique hashable value (used later as a dictionary key), and an initial value for the state: + +::: equinox.nn.StateIndex selection: members: - __init__ --- -::: equinox.nn.StateIndex +This `State` object that's being passed around is essentially just a dictionary, mapping from `StateIndex`s to PyTrees-of-arrays. Correspondingly this has `.get` and `.set` methods to read and write values to it. + +::: equinox.nn.State selection: members: - - __init__ + - get + - set + - substate + - update + +## Custom stateful layers + +Let's use [`equinox.nn.StateIndex`][] to create a custom stateful layer. + +```python +import equinox as eqx +import jax.numpy as jnp +from jaxtyping import Array + +class Counter(eqx.Module): + index: eqx.nn.StateIndex + + def __init__(self): + init_state = jnp.array(0) + self.index = eqx.nn.StateIndex(init_state) + + def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]: + value = state.get(self.index) + new_x = x + value + new_state = state.set(self.index, value + 1) + return new_x, new_state + +counter, state = eqx.nn.make_with_state(Counter)() +x = jnp.array(2.3) + +num_calls = state.get(counter.index) +print(f"Called {num_calls} times.") # 0 + +_, state = counter(x, state) +num_calls = state.get(counter.index) +print(f"Called {num_calls} times.") # 1 + +_, state = counter(x, state) +num_calls = state.get(counter.index) +print(f"Called {num_calls} times.") # 2 +``` + +## Vmap'd stateful layers + +This is an advanced thing to do! Here we'll build on [the ensembling guide](../../../tricks/#ensembling), and see how we can create vmap'd stateful layers. + +This follows on from the previous example, in which we define `Counter`. +```python +import jax.random as jr + +class Model(eqx.Module): + linear: eqx.nn.Linear + counter: Counter + v_counter: Counter + + def __init__(self, key): + # Not-stateful layer + self.linear = eqx.nn.Linear(2, 2, key=key) + # Stateful layer. 
+ self.counter = Counter() + # Vmap'd stateful layer. (Whose initial state will include a batch dimension.) + self.v_counter = eqx.filter_vmap(Counter, axis_size=2)() + + def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]: + # This bit happens as normal. + assert x.shape == (2,) + x = self.linear(x) + x, state = self.counter(x, state) + + # For the vmap, we have to restrict our state to just those states we want to + # vmap, and then update the overall state again afterwards. + # + # After all, the state for `self.counter` isn't expecting to be batched, so we + # have to remove that. + substate = state.substate(self.v_counter) + x, substate = eqx.filter_vmap(self.v_counter)(x, substate) + state = state.update(substate) + + return x, state + +key = jr.PRNGKey(0) +model, state = eqx.nn.make_with_state(Model)(key) +x = jnp.array([5.0, -1.0]) +model(x, state) +``` diff --git a/equinox/nn/__init__.py b/equinox/nn/__init__.py index 96b60232..27b48808 100644 --- a/equinox/nn/__init__.py +++ b/equinox/nn/__init__.py @@ -40,4 +40,9 @@ StatefulLayer as StatefulLayer, ) from ._spectral_norm import SpectralNorm as SpectralNorm -from ._stateful import State as State, StateIndex as StateIndex +from ._stateful import ( + delete_init_state as delete_init_state, + make_with_state as make_with_state, + State as State, + StateIndex as StateIndex, +) diff --git a/equinox/nn/_batch_norm.py b/equinox/nn/_batch_norm.py index b12c30a6..dca0353c 100644 --- a/equinox/nn/_batch_norm.py +++ b/equinox/nn/_batch_norm.py @@ -92,12 +92,12 @@ def __init__( else: self.weight = None self.bias = None - self.first_time_index = StateIndex(lambda **_: jnp.array(True)) - make_buffers = lambda **_: ( + self.first_time_index = StateIndex(jnp.array(True)) + init_buffers = ( jnp.empty((input_size,), dtype=dtype), jnp.empty((input_size,), dtype=dtype), ) - self.state_index = StateIndex(make_buffers) + self.state_index = StateIndex(init_buffers) self.inference = inference 
self.axis_name = axis_name self.input_size = input_size diff --git a/equinox/nn/_spectral_norm.py b/equinox/nn/_spectral_norm.py index d4615738..618df99f 100644 --- a/equinox/nn/_spectral_norm.py +++ b/equinox/nn/_spectral_norm.py @@ -103,7 +103,7 @@ def __init__( v0 = jr.normal(vkey, (v_len,)) for _ in range(15): u0, v0 = _power_iteration(weight, u0, v0, eps) - self.uv_index = StateIndex(lambda **_: (u0, v0)) + self.uv_index = StateIndex((u0, v0)) @jax.named_scope("eqx.nn.SpectralNorm") def __call__( diff --git a/equinox/nn/_stateful.py b/equinox/nn/_stateful.py index b1e5a5aa..f3990876 100644 --- a/equinox/nn/_stateful.py +++ b/equinox/nn/_stateful.py @@ -1,6 +1,6 @@ -import types from collections.abc import Callable from typing import Any, Generic, TypeVar +from typing_extensions import ParamSpec import jax import jax.numpy as jnp @@ -9,31 +9,47 @@ from .._module import Module from .._pretty_print import bracketed, named_objs, text, tree_pformat +from .._tree import tree_at _Value = TypeVar("_Value") +_P = ParamSpec("_P") +_T = TypeVar("_T") class StateIndex(Module, Generic[_Value]): - """This is an advanced feature, used when creating custom stateful layers. + """This wraps together (a) a unique dictionary key used for looking up a stateful + value, and (b) how that stateful value should be initialised. - This wraps a dictionary key used for looking up the stateful value. + !!! Example - See the source code of [`equinox.nn.BatchNorm`][] for an example. - """ + ```python + class MyStatefulLayer(eqx.Module): + index: eqx.nn.StateIndex + + def __init__(self): + init_state = jnp.array(0) + self.index = eqx.nn.StateIndex(init_state) + + def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]: + current_state = state.get(self.index) + new_x = x + current_state + new_state = state.set(self.index, current_state + 1) + return new_x, new_state + ``` + + See also e.g. the source code of built-in stateful layers like + [`equinox.nn.BatchNorm`][] for further reference. 
+ """ # noqa: E501 marker: object - init: types.FunctionType + init: _Value - def __init__(self, init: Callable[..., _Value]): + def __init__(self, init: _Value): """**Arguments:** - - `init`: A function that is used to initialise the state of the layer. Should - be a function that returns a PyTree of JAX arrays. It will be called with - all keyword arguments passed to [`equinox.nn.State.__init__`]. + - `init`: The initial value for the state. """ - if not isinstance(init, types.FunctionType): - raise TypeError("`StateIndex(init=...)` must be a function") self.marker = object() self.init = init @@ -42,37 +58,46 @@ def _is_index(x: Any) -> bool: return isinstance(x, StateIndex) +# Used as a sentinel in two ways: keeping track of updated `State`s, and keeping track +# of deleted initial states. _sentinel = object() _state_error = """ Attempted to use old state. Probably you have done something like: - +``` x, state2 = layer1(x, state1) x, state3 = layer1(x, state1) # bug! Re-used state1 instead of using state2. +``` + +If you have done this intentionally, because you want to use an old state, then you can +avoid this error by making a clone of the state: +``` +leaves, treedef = jax.tree_util.tree_flatten(state) +state_clone = jax.tree_util.tree_unflatten(treedef, leaves) +``` """.strip() -# Basically just a dictionary which (a) works only with Markers, and which (b) works -# around a JAX bug that prevents flattening dicts with `object()` keys, and which (c) -# does error-checking that you're using the most up-to-date version of it. +# Basically just a dictionary which (a) works only with StateIndex-s, and which (b) +# works around a JAX bug that prevents flattening dicts with `object()` keys, and which +# (c) does error-checking that you're using the most up-to-date version of it. @jtu.register_pytree_node_class class State: """Stores the state of a model. For example, the running statistics of all [`equinox.nn.BatchNorm`][] layers in the model. 
- Most models won't need this. (As most models don't have any stateful layers.) - If used, the state will be passed to each layer at call time; see the - [stateful example](../../examples/stateful.ipynb). + This is essentially a dictionary mapping from [`equinox.nn.StateIndex`][]s to + PyTrees of arrays. + + This class should be initialised via [`equinox.nn.make_with_state`][]. """ - def __init__(self, model: PyTree, **kwargs): + def __init__(self, model: PyTree): """**Arguments:** - `model`: any PyTree. All stateful layers (e.g. [`equinox.nn.BatchNorm`][]) - will have their state initialised and stored inside the `State` object. - - `**kwargs`: all keyword arguments are forwarded to the `init` function of - `equinox.nn.StateIndex(init=...)` (used inside each stateful layer). + will have their initial state stored inside the `State` object. """ # Note that de/serialisation depends on the ordered-ness of this dictionary, # between serialisation and deserialisation. @@ -80,30 +105,115 @@ def __init__(self, model: PyTree, **kwargs): leaves = jtu.tree_leaves(model, is_leaf=_is_index) for leaf in leaves: if _is_index(leaf): - value = leaf.init(**kwargs) - value = jtu.tree_map(jnp.asarray, value) - state[leaf.marker] = value + if leaf.init is _sentinel: + raise ValueError( + "Cannot call `eqx.nn.State(eqx.nn.delete_init_state(model))`. " + "You should call `eqx.nn.State(model)`, using the original " + "model." + ) + state[leaf.marker] = jtu.tree_map(jnp.asarray, leaf.init) self._state = state def get(self, item: StateIndex[_Value]) -> _Value: + """Given an [`equinox.nn.StateIndex`][], returns the value of its state. + + **Arguments:** + + - `item`: an [`equinox.nn.StateIndex`][]. + + **Returns:** + + The current state associated with that index. 
+ """ if self._state is _sentinel: raise ValueError(_state_error) if type(item) is not StateIndex: - raise ValueError("Can only use `eqx.nn.Marker`s as state keys.") + raise ValueError("Can only use `eqx.nn.StateIndex`s as state keys.") return self._state[item.marker] # pyright: ignore def set(self, item: StateIndex[_Value], value: _Value) -> "State": + """Sets a new value for an [`equinox.nn.StateIndex`][], **and returns the + updated state**. + + **Arguments:** + + - `item`: an [`equinox.nn.StateIndex`][]. + - `value`: the new value associated with that index. + + **Returns:** + + A new [`equinox.nn.State`][] object, with the update. + + As a safety guard against accidentally writing `state.set(item, value)` without + assigning it to a new value, then the old object (`self`) will become invalid. + """ if self._state is _sentinel: raise ValueError(_state_error) if type(item) is not StateIndex: - raise ValueError("Can only use `eqx.nn.Marker`s as state keys.") + raise ValueError("Can only use `eqx.nn.StateIndex`s as state keys.") old_value = self._state[item.marker] # pyright: ignore value = jtu.tree_map(jnp.asarray, value) if jax.eval_shape(lambda: old_value) != jax.eval_shape(lambda: value): raise ValueError("Old and new values have different structures.") - self._state[item.marker] = value # pyright: ignore + state = self._state.copy() # pyright: ignore + state[item.marker] = value new_self = object.__new__(State) - new_self._state = self._state + new_self._state = state + self._state = _sentinel + return new_self + + def substate(self, pytree: PyTree) -> "State": + """Creates a smaller `State` object, that tracks only the states of some smaller + part of the overall model. + + **Arguments:** + + - `pytree`: any PyTree. It will be iterated over to check for + [`equinox.nn.StateIndex`]s. + + **Returns:** + + A new [`equinox.nn.State`][] object, which tracks only some of the overall + states. 
+ """ + if self._state is _sentinel: + raise ValueError(_state_error) + leaves = jtu.tree_leaves(pytree, is_leaf=_is_index) + markers = [x.marker for x in leaves if _is_index(x)] + substate = {k: self._state[k] for k in markers} # pyright: ignore + subself = object.__new__(State) + subself._state = substate + return subself + + def update(self, substate: "State") -> "State": + """Takes a smaller `State` object (typically produced via `.substate`), and + updates states by using all of its values. + + **Arguments:** + + - `substate`: a `State` object whose keys are a subset of the keys of `self`. + + **Returns:** + + A new [`equinox.nn.State`][] object, containing all of the updated values. + + As a safety guard against accidentally writing `state.update(substate)` without + assigning it to a new value, then the old object (`self`) will become invalid. + """ + if self._state is _sentinel: + raise ValueError(_state_error) + if type(substate) is not State: + raise ValueError("Can only use `eqx.nn.State`s in `update`.") + state = self._state.copy() # pyright: ignore + for key, value in substate._state.items(): # pyright: ignore + if key not in state: + raise ValueError( + "Cannot call `state1.update(state2)` unless `state2` is a substate " + "of `state1`." + ) + state[key] = value + new_self = object.__new__(State) + new_self._state = state self._state = _sentinel return new_self @@ -145,3 +255,89 @@ def tree_unflatten(cls, keys, values): state[key] = value self._state = state return self + + +def _delete_init_state(x): + if _is_index(x): + return tree_at(lambda y: y.init, x, _sentinel) + else: + return x + + +def delete_init_state(model: PyTree) -> PyTree: + """For memory efficiency, this deletes the initial state stored within a model. + + Every stateful layer keeps a copy of the initial value of its state. This is then + collected by [`equinox.nn.State`][], when it is called on the model. 
However, this + means that the model must keep a copy of the initial state around, in case + `eqx.nn.State` is called on it again. This extra copy consumes extra memory. + + But in practice, it is quite common to only need to initialise the state once. In + this case, we can use this function to delete this extra copy, and in doing so save + some memory. + + !!! Example + + Here is the typical pattern in which this is used: + ```python + model_and_state = eqx.nn.BatchNorm(...) + state = eqx.nn.State(model_and_state) + model = eqx.nn.delete_init_state(model_and_state) + del model_and_state # ensure this goes out of scope and is garbage collected + ``` + Indeed the above is precisely what [`equinox.nn.make_with_state`][] does. + + **Arguments:** + + - `model`: any PyTree. + + **Returns:** + + A copy of `model`, with all the initial states stripped out. (As in the example + above, you should then dispose of the original `model` object.) + """ + return jtu.tree_map(_delete_init_state, model, is_leaf=_is_index) + + +def make_with_state(make_model: Callable[_P, _T]) -> Callable[_P, tuple[_T, State]]: + """This function is the most common API for working with stateful models. This + initialises both the parameters and the state of a stateful model. + + `eqx.nn.make_with_state(Model)(*args, **kwargs)` simply calls + `model_with_state = Model(*args, **kwargs)`, and then partitions the resulting + PyTree into two pieces: the parameters, and the state. + + **Arguments:** + + - `make_model`: some callable returning a PyTree. + + **Returns:** + + A callable, which when evaluated returns a 2-tuple of `(model, state)`, where + `model` is the result of `make_model(*args, **kwargs)` but with all of the initial + states stripped out, and `state` is an [`equinox.nn.State`][] object encapsulating + the initial states. + + !!! Example + + See [the stateful example](../../examples/stateful.ipynb) for a runnable + example. 
+ + ```python + class Model(eqx.Module): + def __init__(self, foo, bar): + ... + + ... + + model, state = eqx.nn.make_with_state(Model)(foo=3, bar=4) + ``` + """ + + def make_with_state_impl(*args: _P.args, **kwargs: _P.kwargs) -> tuple[_T, State]: + model = make_model(*args, **kwargs) + state = State(model) + model = delete_init_state(model) + return model, state + + return make_with_state_impl diff --git a/examples/stateful.ipynb b/examples/stateful.ipynb index ab43867c..ef872523 100644 --- a/examples/stateful.ipynb +++ b/examples/stateful.ipynb @@ -11,6 +11,8 @@ "\n", "This just means that we need to plumb an extra input and output through our models. This example demonstrates both [`equinox.nn.BatchNorm`][] and [`equinox.nn.SpectralNorm`][].\n", "\n", + "See also the [stateful API reference](../api/nn/stateful.md).\n", + "\n", "This example is available as a Jupyter notebook [here](https://github.com/patrick-kidger/equinox/blob/main/examples/stateful.ipynb)." ] }, @@ -37,8 +39,8 @@ "metadata": {}, "outputs": [], "source": [ - "# This model is just a weird mish-mash of layers for demonstration purposes, it isn't\n", - "# doing any clever.\n", + "# This model is just a weird mish-mash of stateful and non-stateful layers for\n", + "# demonstration purposes, it isn't doing anything clever.\n", "class Model(eqx.Module):\n", " norm1: eqx.nn.BatchNorm\n", " spectral_linear: eqx.nn.SpectralNorm[eqx.nn.Linear]\n", @@ -69,6 +71,24 @@ " return x, state" ] }, + { + "cell_type": "markdown", + "id": "2b931647-01ff-4eb9-972f-55d3d50771c6", + "metadata": {}, + "source": [ + "We see from the above that we just define our models like normal. As advertised, we just need to thread the additional `state` object in and out of every call. An updated state object is returned.\n", + "\n", + "There's really nothing special here about stateful layers. Equinox isn't special-casing them in any way. We thread `state` in and out, just like we thread `x` in and out. 
In fact calling it \"state\" is really just a matter of how it's advertised!\n", + "\n", + "---\n", + "\n", + "Alright, now let's see how we might train this model. This is also much like normal.\n", + "\n", + "Note the use of `in_axes` and `out_axes`: our data is batched, but our model state isn't batched -- just like how our model parameters aren't batched.\n", + "\n", + "Note how the `axis_name` argument matches the `axis_name` argument that the `BatchNorm` layers were initialised with. This tells `BatchNorm` which vmap'd axis it should compute statistics over. (This is a detail specific to `BatchNorm`, and is unrelated to stateful operations in general.)" + ] + }, { "cell_type": "code", "execution_count": 3, @@ -77,9 +97,6 @@ "outputs": [], "source": [ "def compute_loss(model, state, xs, ys):\n", - " # The `axis_name` argument is needed specifically for `BatchNorm`: so it knows\n", - " # what axis to compute batch statistics over.\n", - " # The `in_axes` and `out_axes` are needed so that `state` isn't batched.\n", " batch_model = jax.vmap(\n", " model, axis_name=\"batch\", in_axes=(0, None), out_axes=(0, None)\n", " )\n", @@ -96,6 +113,16 @@ " return model, state, opt_state" ] }, + { + "cell_type": "markdown", + "id": "d85113ae-7d63-4051-83bd-e9e921dfcae1", + "metadata": {}, + "source": [ + "---\n", + "\n", + "And now, let's see how we initialise this model, and initialise its state." 
+ ] + }, { "cell_type": "code", "execution_count": 4, @@ -112,15 +139,16 @@ "\n", "key = jr.PRNGKey(seed)\n", "mkey, xkey, xkey2 = jr.split(key, 3)\n", - "model = Model(mkey)\n", - "state = eqx.nn.State(model)\n", + "\n", + "model, state = eqx.nn.make_with_state(Model)(mkey)\n", + "\n", "xs = jr.normal(xkey, (dataset_size, 3))\n", "ys = jnp.sin(xs) + 1\n", "optim = optax.adam(learning_rate)\n", "opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))\n", "\n", - "# Full-batch gradient descent in this simple example.\n", "for _ in range(steps):\n", + " # Full-batch gradient descent in this simple example.\n", " model, state, opt_state = make_step(model, state, opt_state, xs, ys)" ] }, @@ -129,15 +157,19 @@ "id": "9c2135ba-07dc-4ce7-89c1-105bd699e78d", "metadata": {}, "source": [ - "Overall, we see that this should be relatively straightforward!\n", + "!!! Info \"What is `eqx.nn.make_with_state` doing?\"\n", "\n", - "When calling `state = eqx.nn.State(model)`, then the model PyTree is iterated over, and any stateful layers store their initial states in the resulting `state` object. The `state` object is itself also a PyTree, so it can just be passed around in the usual way.\n", + " Here we come to the only interesting bit about using stateful layers!\n", "\n", - "In this example, `state` will store the running statistics for `BatchNorm`, and U-V power iterations for `SpectralNorm`.\n", + " When we initialise the model -- e.g. if we were to call `Model(mkey)` directly -- then the model PyTree would be initialised containing both (a) the initial parameters, and (b) the initial state. So `make_with_state` simply calls this, and then separates these two things. The returned `model` is a PyTree holding all the initial parameters (just like any other model), and `state` is a PyTree holding the initial state.\n", "\n", - "Subsequently, we just need to thread the `state` object in-and-out of every call. Each time a new state object is returned. 
(And the old state object should not be reused.)\n", + "---\n", "\n", - "Finally, let's use our trained model to perform inference:" + "Finally, let's use our trained model to perform inference.\n", + "\n", + "Remember to set the inference flag! Some layers have different behaviour between training and inference, and `BatchNorm` is one of these. (This is a detail specific to layers like `BatchNorm` and [`equinox.nn.Dropout`][], and is unrelated to stateful operations in general.)\n", + "\n", + "We also fix the final state in the model, using [`equinox.Partial`][]. The resulting `inference_model` is a PyTree (specifically, an `equinox.Partial`) containing both `model` and `state`." ] }, { @@ -168,17 +200,15 @@ "id": "d22e4395-9006-465e-ba0b-a4b0d8131bd7", "metadata": {}, "source": [ - "Here, we don't need the updated state object that is output, so we just discard it.\n", - "\n", - "(Also, don't forget to set the `inference` flags.)" + "Here, we don't need the updated state object that is produced, so we just discard it." 
] } ], "metadata": { "kernelspec": { - "display_name": "jax38", + "display_name": "py311", "language": "python", - "name": "jax38" + "name": "py311" }, "language_info": { "codemirror_mode": { @@ -190,7 +220,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.11.3" } }, "nbformat": 4, diff --git a/tests/test_stateful.py b/tests/test_stateful.py new file mode 100644 index 00000000..acfca696 --- /dev/null +++ b/tests/test_stateful.py @@ -0,0 +1,18 @@ +import jax.tree_util as jtu +import pytest + +import equinox as eqx + + +def test_delete_init_state(): + model = eqx.nn.BatchNorm(3, "batch") + eqx.nn.State(model) + model2 = eqx.nn.delete_init_state(model) + + eqx.nn.State(model) + with pytest.raises(ValueError): + eqx.nn.State(model2) + + leaves = [x for x in jtu.tree_leaves(model) if eqx.is_array(x)] + leaves2 = [x for x in jtu.tree_leaves(model2) if eqx.is_array(x)] + assert len(leaves) == len(leaves2) + 3 From 61e47b89da04d55dfdd49ca52a450bbb124b24fa Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 11 Sep 2023 19:51:57 -0700 Subject: [PATCH 2/3] Stateful operations now support creating states multiple times, such that they are compatible with the original model. 
--- equinox/nn/_stateful.py | 42 +++++++++++++++++++++++++------ tests/test_stateful.py | 56 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 8 deletions(-) diff --git a/equinox/nn/_stateful.py b/equinox/nn/_stateful.py index f3990876..5a992d39 100644 --- a/equinox/nn/_stateful.py +++ b/equinox/nn/_stateful.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from typing import Any, Generic, TypeVar +from typing import Any, Generic, TYPE_CHECKING, TypeVar, Union from typing_extensions import ParamSpec import jax @@ -7,7 +7,7 @@ import jax.tree_util as jtu from jaxtyping import PyTree -from .._module import Module +from .._module import field, Module from .._pretty_print import bracketed, named_objs, text, tree_pformat from .._tree import tree_at @@ -42,7 +42,9 @@ def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]: [`equinox.nn.BatchNorm`][] for further reference. """ # noqa: E501 - marker: object + # Starts off as an `object` when initialised; later replaced with an `int` inside + # `make_with_state`. + marker: Union[object, int] = field(static=True) init: _Value def __init__(self, init: _Value): @@ -334,10 +336,34 @@ def __init__(self, foo, bar): ``` """ - def make_with_state_impl(*args: _P.args, **kwargs: _P.kwargs) -> tuple[_T, State]: - model = make_model(*args, **kwargs) - state = State(model) - model = delete_init_state(model) - return model, state + # _P.{args, kwargs} not supported by beartype + if TYPE_CHECKING: + + def make_with_state_impl( + *args: _P.args, **kwargs: _P.kwargs + ) -> tuple[_T, State]: + ... + + else: + + def make_with_state_impl(*args, **kwargs) -> tuple[_T, State]: + model = make_model(*args, **kwargs) + + # Replace all markers with `int`s. This is needed to ensure that two calls + # to `make_with_state` produce compatible models and states. 
+ leaves, treedef = jtu.tree_flatten(model, is_leaf=_is_index) + counter = 0 + new_leaves = [] + for leaf in leaves: + if _is_index(leaf): + leaf = StateIndex(leaf.init) + object.__setattr__(leaf, "marker", counter) + counter += 1 + new_leaves.append(leaf) + model = jtu.tree_unflatten(treedef, new_leaves) + + state = State(model) + model = delete_init_state(model) + return model, state return make_with_state_impl diff --git a/tests/test_stateful.py b/tests/test_stateful.py index acfca696..dc97c13e 100644 --- a/tests/test_stateful.py +++ b/tests/test_stateful.py @@ -1,3 +1,6 @@ +import jax +import jax.numpy as jnp +import jax.random as jr import jax.tree_util as jtu import pytest @@ -16,3 +19,56 @@ def test_delete_init_state(): leaves = [x for x in jtu.tree_leaves(model) if eqx.is_array(x)] leaves2 = [x for x in jtu.tree_leaves(model2) if eqx.is_array(x)] assert len(leaves) == len(leaves2) + 3 + + +def test_double_state(): + # From https://github.com/patrick-kidger/equinox/issues/450#issuecomment-1714501666 + + class Counter(eqx.Module): + index: eqx.nn.StateIndex + + def __init__(self): + init_state = jnp.array(0) + self.index = eqx.nn.StateIndex(init_state) + + def __call__(self, x, state): + value = state.get(self.index) + new_x = x + value + new_state = state.set(self.index, value + 1) + return new_x, new_state + + class Model(eqx.Module): + linear: eqx.nn.Linear + counter: Counter + v_counter: Counter + + def __init__(self, key): + # Not-stateful layer + self.linear = eqx.nn.Linear(2, 2, key=key) + # Stateful layer. + self.counter = Counter() + # Vmap'd stateful layer. (Whose initial state will include a batch + # dimension.) 
+ self.v_counter = eqx.filter_vmap(Counter, axis_size=2)() + + def __call__(self, x, state): + assert x.shape == (2,) + x = self.linear(x) + x, state = self.counter(x, state) + substate = state.substate(self.v_counter) + x, substate = eqx.filter_vmap(self.v_counter)(x, substate) + state = state.update(substate) + return x, state + + key = jr.PRNGKey(0) + model, state = eqx.nn.make_with_state(Model)(key) + x = jnp.array([5.0, -1.0]) + model(x, state) + + @jax.jit + def make_state(key): + _, state = eqx.nn.make_with_state(Model)(key) + return state + + new_state = make_state(jr.PRNGKey(1)) + model(x, new_state) From 3c8d4e86c9aa61a79836eb48a38f5a9079adca93 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Tue, 12 Sep 2023 06:34:56 -0700 Subject: [PATCH 3/3] Added deprecation message --- equinox/nn/_stateful.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/equinox/nn/_stateful.py b/equinox/nn/_stateful.py index 5a992d39..fc9b9ab0 100644 --- a/equinox/nn/_stateful.py +++ b/equinox/nn/_stateful.py @@ -1,3 +1,4 @@ +import types from collections.abc import Callable from typing import Any, Generic, TYPE_CHECKING, TypeVar, Union from typing_extensions import ParamSpec @@ -52,6 +53,15 @@ def __init__(self, init: _Value): - `init`: The initial value for the state. """ + if isinstance(init, types.FunctionType): + # Technically a function is valid here, since we could allow any pytree. + # In practice that's weird / kind of useless, so better to explicitly raise + # the deprecation error. + raise ValueError( + "As of Equinox v0.11.0, `eqx.nn.StateIndex` now accepts the value " + "of the initial state directly. (Not a function that creates the " + "initial state.)" + ) self.marker = object() self.init = init