
Commit

add batch_norm guide
cgarciae committed Nov 4, 2022
1 parent 8b40930 commit 96c89de
Showing 2 changed files with 295 additions and 0 deletions.
docs/guides/batch_norm.rst (294 additions, 0 deletions)
@@ -0,0 +1,294 @@
Using BatchNorm
===============

In this guide we will go through the details of using ``BatchNorm`` in models.
Along the way we will highlight some of the differences between code that uses
``BatchNorm`` and code that does not.

.. testsetup::

  import flax.linen as nn
  import jax.numpy as jnp
  import jax
  import optax
  from typing import Any
  from flax.core import FrozenDict

Defining the model
******************

``BatchNorm`` is a Module that behaves differently during training and
inference. In other frameworks this behavior is specified via mutable state;
in Flax, we specify it explicitly via the ``use_running_average`` argument.
A common pattern is to accept a ``train`` argument in the parent Module and use
it to define ``BatchNorm``'s ``use_running_average`` argument.

.. codediff::
  :title_left: regular code
  :title_right: with BatchNorm
  :sync:

  class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
      x = nn.Dense(features=4)(x)

      x = nn.relu(x)
      x = nn.Dense(features=1)(x)
      return x

  ---
  class MLP(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool): #!
      x = nn.Dense(features=4)(x)
      x = nn.BatchNorm(use_running_average=not train)(x) #!
      x = nn.relu(x)
      x = nn.Dense(features=1)(x)
      return x

Once the model is created, initialization is very similar; the additional
``train`` argument is passed to ``init`` to get back the ``variables`` structure.

The ``batch_stats`` collection
******************************

Apart from the ``params`` collection, ``BatchNorm``
adds an additional ``batch_stats`` collection that contains the running
average of the batch statistics. The ``batch_stats`` collection must be
extracted from the ``variables`` for later use:

.. codediff::
  :title_left: regular code
  :title_right: with BatchNorm
  :sync:

  mlp = MLP()
  x = jnp.ones((1, 3))
  variables = mlp.init(jax.random.PRNGKey(0), x)
  params = variables['params']


  jax.tree_util.tree_map(jnp.shape, variables)
  ---
  mlp = MLP()
  x = jnp.ones((1, 3))
  variables = mlp.init(jax.random.PRNGKey(0), x, train=False) #!
  params = variables['params']
  batch_stats = variables['batch_stats'] #!

  jax.tree_util.tree_map(jnp.shape, variables)


``BatchNorm`` adds a total of 4 variables: ``mean`` and ``var`` that live in the
``batch_stats`` collection and ``scale`` and ``bias`` that live in the ``params``
collection.

.. codediff::
  :title_left: regular code
  :title_right: with BatchNorm
  :sync:

  FrozenDict({






    'params': {




      'Dense_0': {
        'bias': (4,),
        'kernel': (3, 4),
      },
      'Dense_1': {
        'bias': (1,),
        'kernel': (4, 1),
      },
    },
  })
  ---
  FrozenDict({
    'batch_stats': { #!
      'BatchNorm_0': { #!
        'mean': (4,), #!
        'var': (4,), #!
      }, #!
    }, #!
    'params': {
      'BatchNorm_0': { #!
        'bias': (4,), #!
        'scale': (4,), #!
      }, #!
      'Dense_0': {
        'bias': (4,),
        'kernel': (3, 4),
      },
      'Dense_1': {
        'bias': (1,),
        'kernel': (4, 1),
      },
    },
  })

Calling ``apply``
*****************

When using ``apply`` to run your model, a couple of things must
be taken into consideration:

- ``batch_stats`` must be passed as an input variable.
- The ``train`` argument must be provided.
- During training, the ``batch_stats`` collection must be marked as
  mutable by setting ``mutable=['batch_stats']``.
- When there are mutable collections, their updates are returned as a
  second output; the updated ``batch_stats`` must be extracted from there.

.. codediff::
  :title_left: regular code
  :title_right: with BatchNorm
  :sync:

  y = mlp.apply(
    {'params': params},
    x,

  )
  ...

  ---
  y, updates = mlp.apply( #!
    {'params': params, 'batch_stats': batch_stats}, #!
    x,
    train=True, mutable=['batch_stats'] #!
  )
  batch_stats = updates['batch_stats'] #!

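At inference time no collection needs to be mutated, so ``mutable`` can be
omitted and ``apply`` returns just the model output. As a minimal sketch of
the same call in inference mode:

.. code-block:: python

  # Inference sketch: with no mutable collections requested, ``apply``
  # returns only the output, using the stored running statistics.
  y = mlp.apply(
    {'params': params, 'batch_stats': batch_stats},
    x,
    train=False,
  )
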
Training and Evaluation
***********************

To actually integrate this model into a training loop that uses
``TrainState``, a new ``batch_stats`` field has to be added by subclassing
``TrainState`` and passing the ``batch_stats`` values to the ``create`` method:

.. codediff::
  :title_left: regular code
  :title_right: with BatchNorm
  :sync:

  from flax.training import train_state




  state = train_state.TrainState.create(
    apply_fn=mlp.apply,
    params=params,

    tx=optax.adam(1e-3),
  )
  ---
  from flax.training import train_state

  class TrainState(train_state.TrainState): #!
    batch_stats: Any #!

  state = TrainState.create( #!
    apply_fn=mlp.apply,
    params=params,
    batch_stats=batch_stats, #!
    tx=optax.adam(1e-3),
  )

Finally, the ``train_step`` must be updated to reflect these changes. The main
differences are:

- All new parameters to ``apply`` must be passed (as discussed previously).
- The ``updates`` to the ``batch_stats`` must be propagated out of the ``loss_fn``.
- The ``batch_stats`` field from the ``TrainState`` must be updated.

.. codediff::
  :title_left: regular code
  :title_right: with BatchNorm
  :sync:

  @jax.jit
  def train_step(state: TrainState, batch):
    """Train for a single step."""
    def loss_fn(params):
      logits = state.apply_fn(
        {'params': params},
        x=batch['image'])
      loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label']).mean()
      return loss, logits
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)

    metrics = {
      'loss': loss,
      'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
    }
    return state, metrics
  ---
  @jax.jit
  def train_step(state: TrainState, batch):
    """Train for a single step."""
    def loss_fn(params):
      logits, updates = state.apply_fn( #!
        {'params': params, 'batch_stats': state.batch_stats}, #!
        x=batch['image'], train=True, mutable=['batch_stats']) #!
      loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label']).mean()
      return loss, (logits, updates) #!
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (logits, updates)), grads = grad_fn(state.params) #!
    state = state.apply_gradients(grads=grads)
    state = state.replace(batch_stats=updates['batch_stats']) #!
    metrics = {
      'loss': loss,
      'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
    }
    return state, metrics

The ``eval_step`` is much simpler: since ``batch_stats`` is not mutable, no
updates need to be propagated. The only differences are that ``batch_stats`` must
be passed to ``apply``, and that the ``train`` argument must be set to ``False``:

.. codediff::
  :title_left: regular code
  :title_right: with BatchNorm
  :sync:

  @jax.jit
  def eval_step(state: TrainState, batch):
    """Evaluate for a single step."""
    logits = state.apply_fn(
      {'params': state.params},
      x=batch['image'])
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label']).mean()
    metrics = {
      'loss': loss,
      'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
    }
    return state, metrics
  ---
  @jax.jit
  def eval_step(state: TrainState, batch):
    """Evaluate for a single step."""
    logits = state.apply_fn(
      {'params': state.params, 'batch_stats': state.batch_stats}, #!
      x=batch['image'], train=False) #!
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label']).mean()
    metrics = {
      'loss': loss,
      'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
    }
    return state, metrics
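
As a final usage sketch (assuming hypothetical ``train_ds`` and ``test_ds``
iterables that yield ``{'image': ..., 'label': ...}`` batches, which are not
defined in this guide), the two functions can be combined into a simple loop:

.. code-block:: python

  # Hypothetical loop; `train_ds` and `test_ds` are placeholder dataset
  # iterables and the epoch count is arbitrary.
  for epoch in range(10):
    for batch in train_ds:
      state, train_metrics = train_step(state, batch)
    for batch in test_ds:
      _, eval_metrics = eval_step(state, batch)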
docs/guides/index.rst (1 addition, 0 deletions)
@@ -8,6 +8,7 @@ Guides
flax_basics
state_params
setup_or_nncompact
batch_norm
model_surgery
transfer_learning
extracting_intermediates
