Commit 628fb48
Renamed eqx.tree_inference -> eqx.nn.inference_mode, as it's really an eqx.nn thing, not an eqx thing.

patrick-kidger committed Aug 31, 2023
1 parent a234840 commit 628fb48
Showing 15 changed files with 94 additions and 88 deletions.
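
In short: the function previously exposed as `eqx.tree_inference` now lives at `eqx.nn.inference_mode`, and its `value` argument now defaults to `True`. A minimal sketch of the before/after usage (the `Dropout` layer here is just an arbitrary module carrying an `inference` flag):

```python
import equinox as eqx

model = eqx.nn.Dropout(0.4)

# Old spelling; still works via the backward-compatibility alias added below.
inference_model = eqx.tree_inference(model, value=True)

# New spelling; `value` defaults to True, i.e. inference mode.
inference_model = eqx.nn.inference_mode(model)
training_model = eqx.nn.inference_mode(inference_model, value=False)
```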
4 changes: 0 additions & 4 deletions docs/api/manipulation.md
@@ -20,10 +20,6 @@

 ---
 
-::: equinox.tree_inference
-
----
-
 ::: equinox.tree_flatten_one_level
 
 ---
3 changes: 3 additions & 0 deletions docs/api/nn/inference.md
@@ -0,0 +1,3 @@
+# Training/Inference
+
+::: equinox.nn.inference_mode
2 changes: 1 addition & 1 deletion equinox/__init__.py
@@ -53,14 +53,14 @@
     tree_check as tree_check,
     tree_equal as tree_equal,
     tree_flatten_one_level as tree_flatten_one_level,
-    tree_inference as tree_inference,
 )
 from ._update import apply_updates as apply_updates
 from ._vmap_pmap import (
     filter_pmap as filter_pmap,
     filter_vmap as filter_vmap,
     if_array as if_array,
 )
+from .nn import inference_mode as tree_inference  # noqa: F401 - backward compatibility
 
 
 __version__ = importlib.metadata.version("equinox")
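
The final added line above keeps the old name importable. A small sketch of what the alias means for downstream code (illustrative, not part of the diff):

```python
import equinox as eqx

# The alias is the same function object under both names...
assert eqx.tree_inference is eqx.nn.inference_mode

# ...so existing call sites keep working unchanged.
dropout = eqx.nn.Dropout(0.5)
old_style = eqx.tree_inference(dropout, value=True)
new_style = eqx.nn.inference_mode(dropout)
assert old_style.inference and new_style.inference
```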
69 changes: 0 additions & 69 deletions equinox/_tree.py
@@ -270,75 +270,6 @@ def tree_equal(*pytrees: PyTree) -> Union[bool, np.bool_, Bool[Array, ""]]:
     return out
 
 
-def _inferences(pytree):
-    is_leaf = lambda x: hasattr(x, "inference") and x is not pytree
-
-    out = [pytree.inference] if hasattr(pytree, "inference") else []
-
-    leaves = [x for x in jtu.tree_leaves(pytree, is_leaf=is_leaf) if is_leaf(x)]
-    # Nodes with an inference flag might have sub-nodes with an inference flag.
-    for x in leaves:
-        out.extend(_inferences(x))
-    return out
-
-
-def tree_inference(pytree: PyTree, value: bool) -> PyTree:
-    """Convenience function for setting all `inference` attributes on a PyTree.
-
-    `inference` flags are used to toggle the behaviour of a number of the pre-built
-    neural network layers, such as [`equinox.nn.Dropout`][] or
-    [`equinox.nn.BatchNorm`][].
-
-    !!! Example
-
-        ```python
-        class Model(eqx.Module):
-            norm: eqx.nn.BatchNorm
-            dropout: eqx.nn.Dropout
-            linear: eqx.nn.Linear
-
-            def __init__(self, key):
-                key1, key2 = jax.random.split(key)
-                self.norm = eqx.nn.BatchNorm(3, "batch", key=key1)
-                self.dropout = eqx.nn.Dropout(0.4)
-                self.linear = eqx.nn.Linear(3, 1, key=key2)
-
-            def __call__(self, x, ctx, *, key):
-                x, ctx = self.norm(x, ctx)
-                x = self.dropout(x, key=key)
-                x = self.linear(x)
-                return x, ctx
-
-        training_model = Model(jax.random.PRNGKey(0))
-        inference_model = eqx.tree_inference(training_model, value=True)
-        training_model_again = eqx.tree_inference(inference_model, value=False)
-        ```
-
-    Equivalent to:
-    ```python
-    has_inference = lambda leaf: hasattr(leaf, "inference")
-
-    def where(pytree):
-        return tuple(x.inference
-                     for x in jtu.tree_leaves(pytree, is_leaf=has_inference)
-                     if has_inference(x))
-
-    equinox.tree_at(where, pytree, replace_fn=lambda _: value)
-    ```
-
-    **Arguments:**
-
-    - `pytree`: the PyTree to modify.
-    - `value`: the value to set all `inference` attributes to.
-
-    **Returns:**
-
-    A copy of `pytree` with all `inference` flags set to `value`.
-    """
-    return tree_at(_inferences, pytree, replace_fn=lambda _: value)
 
 
 def tree_flatten_one_level(
     pytree: PyTree,
 ) -> tuple[list[PyTree], PyTreeDef]:  # pyright: ignore
1 change: 1 addition & 0 deletions equinox/nn/__init__.py
@@ -13,6 +13,7 @@
 )
 from ._dropout import Dropout as Dropout
 from ._embedding import Embedding as Embedding
+from ._inference import inference_mode as inference_mode
 from ._linear import Identity as Identity, Linear as Linear
 from ._mlp import MLP as MLP
 from ._normalisation import GroupNorm as GroupNorm, LayerNorm as LayerNorm
2 changes: 1 addition & 1 deletion equinox/nn/_attention.py
@@ -168,7 +168,7 @@ def __init__(
         - `dropout_p`: Dropout probability on attention weights.
         - `inference`: Whether to actually apply dropout at all. If `True` then dropout
             is not applied. If `False` then dropout is applied. This may be toggled
-            with [`equinox.tree_inference`][] or overridden during
+            with [`equinox.nn.inference_mode`][] or overridden during
             [`equinox.nn.MultiheadAttention.__call__`][].
         - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
             initialisation. (Keyword only argument.)
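
As the docstring says, the flag can be set module-wide or overridden per call. A hedged sketch of both (shapes chosen arbitrarily; illustrative rather than part of the diff):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
attn = eqx.nn.MultiheadAttention(num_heads=2, query_size=4, dropout_p=0.1, key=key)
x = jnp.ones((5, 4))  # (sequence length, query size)

# Toggle for the whole module...
out = eqx.nn.inference_mode(attn)(x, x, x)

# ...or override attention-weight dropout for a single call.
out = attn(x, x, x, inference=True)
```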
4 changes: 2 additions & 2 deletions equinox/nn/_batch_norm.py
@@ -38,7 +38,7 @@ class BatchNorm(StatefulLayer):
     training then statistics are computed using the input data, and the running
     statistics updated. During inference then just the running statistics are used.
     Whether the model is in training or inference mode should be toggled using
-    [`equinox.tree_inference`][].
+    [`equinox.nn.inference_mode`][].
     """  # noqa: E501
 
     weight: Optional[Float[Array, "input_size"]]
@@ -79,7 +79,7 @@ def __init__(
         - `inference`: If `False` then the batch means and variances will be calculated
             and used to update the running statistics. If `True` then the running
             statistics are directly used for normalisation. This may be toggled with
-            [`equinox.tree_inference`][] or overridden during
+            [`equinox.nn.inference_mode`][] or overridden during
             [`equinox.nn.BatchNorm.__call__`][].
         - `dtype`: The dtype of the input array.
         """
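
For context, a sketch of the toggle this docstring describes, following the vmap pattern used in the updated tests below (and assuming the stateful API of this era, where `eqx.nn.State(module)` builds the initial state):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

bn = eqx.nn.BatchNorm(input_size=3, axis_name="batch")
state = eqx.nn.State(bn)  # assumption: pre-0.11 way to build the initial state
x = jnp.ones((8, 3))

# Training: batch statistics are computed and the running statistics updated.
vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
out, state = vbn(x, state)

# Inference: the running statistics are used directly and no longer update.
ibn = eqx.nn.inference_mode(bn)
vibn = jax.vmap(ibn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
out, _ = vibn(x, state)
```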
4 changes: 2 additions & 2 deletions equinox/nn/_dropout.py
@@ -16,7 +16,7 @@ class Dropout(Module):
     Note that this layer behaves differently during training and inference. During
     training then dropout is randomly applied; during inference this layer does nothing.
     Whether the model is in training or inference mode should be toggled using
-    [`equinox.tree_inference`][].
+    [`equinox.nn.inference_mode`][].
     """
 
     # Not static fields as it makes sense to want to modify them via equinox.tree_at.
@@ -35,7 +35,7 @@ def __init__(
         - `p`: The fraction of entries to set to zero. (On average.)
         - `inference`: Whether to actually apply dropout at all. If `True` then dropout
             is *not* applied. If `False` then dropout is applied. This may be toggled
-            with [`equinox.tree_inference`][] or overridden during
+            with [`equinox.nn.inference_mode`][] or overridden during
             [`equinox.nn.Dropout.__call__`][].
         - `deterministic`: Deprecated alternative to `inference`.
         """
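
Concretely, the toggle described above looks like this (a minimal illustration):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

dropout = eqx.nn.Dropout(p=0.4)
x = jnp.ones(10)

y_train = dropout(x, key=jax.random.PRNGKey(0))  # entries randomly zeroed
y_infer = eqx.nn.inference_mode(dropout)(x)      # identity; no key needed
```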
74 changes: 74 additions & 0 deletions equinox/nn/_inference.py
@@ -0,0 +1,74 @@
+import jax.tree_util as jtu
+from jaxtyping import PyTree
+
+from .._tree import tree_at
+
+
+def _inferences(pytree):
+    is_leaf = lambda x: hasattr(x, "inference") and x is not pytree
+
+    out = [pytree.inference] if hasattr(pytree, "inference") else []
+
+    leaves = [x for x in jtu.tree_leaves(pytree, is_leaf=is_leaf) if is_leaf(x)]
+    # Nodes with an inference flag might have sub-nodes with an inference flag.
+    for x in leaves:
+        out.extend(_inferences(x))
+    return out
+
+
+def inference_mode(pytree: PyTree, value: bool = True) -> PyTree:
+    """Convenience function for setting all `inference` attributes.
+
+    `inference` flags are used to toggle the behaviour of a number of the pre-built
+    neural network layers, such as [`equinox.nn.Dropout`][] or
+    [`equinox.nn.BatchNorm`][].
+
+    !!! Example
+
+        ```python
+        class Model(eqx.Module):
+            norm: eqx.nn.BatchNorm
+            dropout: eqx.nn.Dropout
+            linear: eqx.nn.Linear
+
+            def __init__(self, key):
+                key1, key2 = jax.random.split(key)
+                self.norm = eqx.nn.BatchNorm(3, "batch", key=key1)
+                self.dropout = eqx.nn.Dropout(0.4)
+                self.linear = eqx.nn.Linear(3, 1, key=key2)
+
+            def __call__(self, x, ctx, *, key):
+                x, ctx = self.norm(x, ctx)
+                x = self.dropout(x, key=key)
+                x = self.linear(x)
+                return x, ctx
+
+        training_model = Model(jax.random.PRNGKey(0))
+        inference_model = eqx.nn.inference_mode(training_model)
+        training_model_again = eqx.nn.inference_mode(inference_model, value=False)
+        ```
+
+    This function is essentially equivalent to:
+    ```python
+    has_inference = lambda leaf: hasattr(leaf, "inference")
+
+    def where(pytree):
+        return tuple(x.inference
+                     for x in jtu.tree_leaves(pytree, is_leaf=has_inference)
+                     if has_inference(x))
+
+    inference_pytree = equinox.tree_at(where, pytree, replace_fn=lambda _: value)
+    ```
+
+    **Arguments:**
+
+    - `pytree`: the PyTree to modify.
+    - `value`: the value to set all `inference` attributes to. Defaults to `True`, i.e.
+      inference mode.
+
+    **Returns:**
+
+    A copy of `pytree` with all `inference` flags set to `value`.
+    """
+    return tree_at(_inferences, pytree, replace_fn=lambda _: value)
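
Note that `_inferences` recurses, so a module with its own `inference` flag also has the flags of its sub-modules set. Mirroring the updated test in tests/test_tree.py below (`MultiheadAttention` carries an `inference` flag and contains a `Dropout` carrying another):

```python
import equinox as eqx
import jax

attn = eqx.nn.MultiheadAttention(2, 4, key=jax.random.PRNGKey(0))
attn2 = eqx.nn.inference_mode(attn)

assert attn2.inference is True          # the outer module's own flag
assert attn2.dropout.inference is True  # the nested Dropout's flag
```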
4 changes: 2 additions & 2 deletions equinox/nn/_spectral_norm.py
@@ -49,7 +49,7 @@ class SpectralNorm(StatefulLayer, Generic[_Layer]):
     Note that this layer behaves differently during training and inference. During
     training then power iterations are updated; during inference they are fixed.
     Whether the model is in training or inference mode should be toggled using
-    [`equinox.tree_inference`][].
+    [`equinox.nn.inference_mode`][].
     """  # noqa: E501
 
     layer: _Layer
@@ -81,7 +81,7 @@ def __init__(
         - `eps`: Epsilon for numerical stability when calculating norms.
         - `inference`: Whether this is in inference mode, at which time no power
             iterations are performed. This may be toggled with
-            [`equinox.tree_inference`][].
+            [`equinox.nn.inference_mode`][].
         - `key`: A `jax.random.PRNGKey` used to provide randomness for initialisation.
             (Keyword only argument.)
         """
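
A brief sketch of the toggle for this layer (constructor arguments as I understand `eqx.nn.SpectralNorm`; illustrative only, not part of the diff):

```python
import equinox as eqx
import jax

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
spectral = eqx.nn.SpectralNorm(
    eqx.nn.Linear(3, 3, key=k1), weight_name="weight", key=k2
)
frozen = eqx.nn.inference_mode(spectral)  # power iterations now fixed
assert frozen.inference
```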
4 changes: 2 additions & 2 deletions examples/deep_convolutional_gan.ipynb
@@ -640,12 +640,12 @@
" return out\n",
"\n",
"\n",
"inference_gen = eqx.tree_inference(generator, value=True)\n",
"inference_gen = eqx.nn.inference_mode(generator)\n",
"inference_gen = eqx.Partial(inference_gen, state=generator_state)\n",
"\n",
"generated_images = evaluate(inference_gen, noise)\n",
"\n",
"inference_discriminator = eqx.tree_inference(discriminator, value=True)\n",
"inference_discriminator = eqx.nn.inference_mode(discriminator)\n",
"inference_discriminator = eqx.Partial(\n",
" inference_discriminator, state=discriminator_state\n",
")\n",
2 changes: 1 addition & 1 deletion examples/stateful.ipynb
@@ -147,7 +147,7 @@
"metadata": {},
"outputs": [],
"source": [
"inference_model = eqx.tree_inference(model, value=True)\n",
"inference_model = eqx.nn.inference_mode(model)\n",
"inference_model = eqx.Partial(inference_model, state=state)\n",
"\n",
"\n",
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -127,6 +127,7 @@ nav:
       - 'api/nn/embedding.md'
       - 'api/nn/mlp.md'
       - 'api/nn/sequential.md'
+      - 'api/nn/inference.md'
       - 'api/nn/stateful.md'
   - Filtering:
       - 'api/filtering/partition-combine.md'
4 changes: 2 additions & 2 deletions tests/test_nn.py
@@ -823,7 +823,7 @@ def test_batch_norm(getkey):

     # Test that the statistics don't update at inference
 
-    ibn = eqx.tree_inference(bn, value=True)
+    ibn = eqx.nn.inference_mode(bn, value=True)
     vibn = jax.vmap(ibn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
     out, state = vibn(4 * x1 + 20, state)
     running_mean3, running_var3 = state.get(bn.state_index)
@@ -869,7 +869,7 @@ def λ1():
     spectral = eqx.tree_at(
         lambda s: s.layer.weight, spectral, spectral.layer.weight + 1
     )
-    spectral = eqx.tree_inference(spectral, value=True)
+    spectral = eqx.nn.inference_mode(spectral, value=True)
     assert not jnp.allclose(λ1(), 1)
     for _ in range(100):
         _, state = spectral(x, state)
4 changes: 2 additions & 2 deletions tests/test_tree.py
@@ -155,10 +155,10 @@ def run3(x, y):
     assert not run3(a, 1)
 
 
-def test_tree_inference(getkey):
+def test_inference_mode(getkey):
     attention = eqx.nn.MultiheadAttention(2, 4, key=getkey())
     assert attention.dropout.inference is False
-    attention2 = eqx.tree_inference(attention, True)
+    attention2 = eqx.nn.inference_mode(attention)
     assert attention.dropout.inference is False
     assert attention2.dropout.inference is True

