Using BatchNorm
===============

In this guide we will go through the details of using ``BatchNorm`` in models;
along the way we will highlight some of the differences between code that uses
``BatchNorm`` and code that does not.

.. testsetup::

  import flax.linen as nn
  import jax.numpy as jnp
  import jax
  import optax
  from typing import Any
  from flax.core import FrozenDict

Defining the model
******************

``BatchNorm`` is a Module that has different runtime behavior between training and
inference. In other frameworks this behavior is specified via mutable state; in
Flax, however, we specify it explicitly via the ``use_running_average`` argument.
A common pattern is to accept a ``train`` argument in the parent Module and use it to define
``BatchNorm``'s ``use_running_average`` argument.

.. codediff::
  :title_left: regular code
  :title_right: with BatchNorm
  :sync:

  class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
      x = nn.Dense(features=4)(x)

      x = nn.relu(x)
      x = nn.Dense(features=1)(x)
      return x

  ---
  class MLP(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool): #!
      x = nn.Dense(features=4)(x)
      x = nn.BatchNorm(use_running_average=not train)(x) #!
      x = nn.relu(x)
      x = nn.Dense(features=1)(x)
      return x

Once the model is created, initialization is very similar; the additional
``train`` argument is passed to ``init`` to get back the ``variables`` structure.

The ``batch_stats`` collection
******************************

Apart from the ``params`` collection, ``BatchNorm``
adds an additional ``batch_stats`` collection that contains the running
average of the batch statistics. The ``batch_stats`` collection must be
extracted from the ``variables`` for later use:

.. codediff::
  :title_left: regular code
  :title_right: with BatchNorm
  :sync:

  mlp = MLP()
  x = jnp.ones((1, 3))
  variables = mlp.init(jax.random.PRNGKey(0), x)
  params = variables['params']


  jax.tree_util.tree_map(jnp.shape, variables)
  ---
  mlp = MLP()
  x = jnp.ones((1, 3))
  variables = mlp.init(jax.random.PRNGKey(0), x, train=False) #!
  params = variables['params']
  batch_stats = variables['batch_stats'] #!

  jax.tree_util.tree_map(jnp.shape, variables)

``BatchNorm`` adds a total of 4 variables: ``mean`` and ``var`` that live in the
``batch_stats`` collection, and ``scale`` and ``bias`` that live in the ``params``
collection.

.. codediff::
  :title_left: regular code
  :title_right: with BatchNorm
  :sync:

  FrozenDict({






      'params': {




          'Dense_0': {
              'bias': (4,),
              'kernel': (3, 4),
          },
          'Dense_1': {
              'bias': (1,),
              'kernel': (4, 1),
          },
      },
  })
  ---
  FrozenDict({
      'batch_stats': { #!
          'BatchNorm_0': { #!
              'mean': (4,), #!
              'var': (4,), #!
          }, #!
      }, #!
      'params': {
          'BatchNorm_0': { #!
              'bias': (4,), #!
              'scale': (4,), #!
          }, #!
          'Dense_0': {
              'bias': (4,),
              'kernel': (3, 4),
          },
          'Dense_1': {
              'bias': (1,),
              'kernel': (4, 1),
          },
      },
  })

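As a quick sanity check (not part of the guide's training code), these shapes can
be inspected directly from the ``variables`` structure returned by ``init``:

.. code-block:: python

  # Running statistics created by BatchNorm for the 4 features of Dense_0's output.
  print(variables['batch_stats']['BatchNorm_0']['mean'].shape)  # (4,)
  print(variables['batch_stats']['BatchNorm_0']['var'].shape)   # (4,)

  # Learnable affine parameters added by BatchNorm live under 'params'.
  print(variables['params']['BatchNorm_0']['scale'].shape)      # (4,)
  print(variables['params']['BatchNorm_0']['bias'].shape)       # (4,)
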
Calling ``apply``
*****************

When using ``apply`` to run your model, a couple of things must
be taken into consideration:

- ``batch_stats`` must be passed as an input variable.
- The ``train`` argument must be defined.
- During training, the ``batch_stats`` collection must be marked as
  mutable by setting ``mutable=['batch_stats']``.
- When there are mutable collections, updates to these are returned as a
  second output. The updated ``batch_stats`` must be extracted from here.

.. codediff::
  :title_left: regular code
  :title_right: with BatchNorm
  :sync:

  y = mlp.apply(
    {'params': params},
    x,

  )
  ...

  ---
  y, updates = mlp.apply( #!
    {'params': params, 'batch_stats': batch_stats}, #!
    x,
    train=True, mutable=['batch_stats'] #!
  )
  batch_stats = updates['batch_stats'] #!

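At inference time, by contrast, no collection needs to be mutable, so only the
output is returned. A minimal sketch of the same call with ``train=False``
(reusing the ``params`` and ``batch_stats`` extracted above) looks like this:

.. code-block:: python

  # No `mutable` argument: batch_stats is read-only and only `y` is returned.
  y = mlp.apply(
    {'params': params, 'batch_stats': batch_stats},
    x,
    train=False,
  )
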
Training and Evaluation
***********************

To actually integrate this model into a training loop that uses
``TrainState``, a new ``batch_stats`` field has to be added by subclassing
``TrainState`` and passing the ``batch_stats`` values to the ``create`` method:

.. codediff::
  :title_left: regular code
  :title_right: with BatchNorm
  :sync:

  from flax.training import train_state




  state = train_state.TrainState.create(
    apply_fn=mlp.apply,
    params=params,

    tx=optax.adam(1e-3),
  )
  ---
  from flax.training import train_state

  class TrainState(train_state.TrainState): #!
    batch_stats: Any #!

  state = TrainState.create( #!
    apply_fn=mlp.apply,
    params=params,
    batch_stats=batch_stats, #!
    tx=optax.adam(1e-3),
  )

Finally, the ``train_step`` must be updated to reflect these changes. The main
differences are:

- All new parameters to ``apply`` must be passed (as discussed previously).
- The ``updates`` to the ``batch_stats`` must be propagated out of the ``loss_fn``.
- The ``batch_stats`` field from the ``TrainState`` must be updated.

.. codediff::
  :title_left: regular code
  :title_right: with BatchNorm
  :sync:

  @jax.jit
  def train_step(state: TrainState, batch):
    """Train for a single step."""
    def loss_fn(params):
      logits = state.apply_fn(
        {'params': params},
        x=batch['image'])
      loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label'])
      return loss, logits
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)

    metrics = {
      'loss': loss,
      'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
    }
    return state, metrics
  ---
  @jax.jit
  def train_step(state: TrainState, batch):
    """Train for a single step."""
    def loss_fn(params):
      logits, updates = state.apply_fn( #!
        {'params': params, 'batch_stats': state.batch_stats}, #!
        x=batch['image'], train=True, mutable=['batch_stats']) #!
      loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label'])
      return loss, (logits, updates) #!
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (logits, updates)), grads = grad_fn(state.params) #!
    state = state.apply_gradients(grads=grads)
    state = state.replace(batch_stats=updates['batch_stats']) #!
    metrics = {
      'loss': loss,
      'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
    }
    return state, metrics

The ``eval_step`` is much simpler: since ``batch_stats`` is not mutable, no
updates need to be propagated. The only differences are that ``batch_stats``
must be passed to ``apply``, and the ``train`` argument must be set to ``False``:

.. codediff::
  :title_left: regular code
  :title_right: with BatchNorm
  :sync:

  @jax.jit
  def eval_step(state: TrainState, batch):
    """Evaluate for a single step."""
    logits = state.apply_fn(
      {'params': state.params},
      x=batch['image'])
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label'])
    metrics = {
      'loss': loss,
      'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
    }
    return state, metrics
  ---
  @jax.jit
  def eval_step(state: TrainState, batch):
    """Evaluate for a single step."""
    logits = state.apply_fn(
      {'params': state.params, 'batch_stats': state.batch_stats}, #!
      x=batch['image'], train=False) #!
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label'])
    metrics = {
      'loss': loss,
      'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
    }
    return state, metrics

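Putting everything together, a minimal sketch of a loop driving ``train_step``
and ``eval_step`` could look as follows; ``train_ds``, ``eval_ds``, and
``num_epochs`` are hypothetical placeholders for your own input pipeline and
schedule, not something defined in this guide:

.. code-block:: python

  num_epochs = 10  # hypothetical; pick whatever fits your problem

  for epoch in range(num_epochs):
    for batch in train_ds:
      # train_step returns a new state, including the updated batch_stats.
      state, train_metrics = train_step(state, batch)
    for batch in eval_ds:
      # eval_step does not modify the state; only the metrics are of interest.
      _, eval_metrics = eval_step(state, batch)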