Support forward mode differentiation for SVI #1731

Merged: 16 commits, Feb 8, 2024
7 changes: 7 additions & 0 deletions numpyro/infer/hmc.py
@@ -265,6 +265,13 @@ def init_kernel(
`d2` is the max tree depth during post warmup phase.
:param bool find_heuristic_step_size: whether to use a heuristic function to adjust the
step size at the beginning of each adaptation window. Defaults to False.
:param bool forward_mode_differentiation: whether to use forward-mode differentiation
or reverse-mode differentiation. By default, we use reverse mode, but forward
mode can be useful in some cases to improve performance. In addition, some
control flow utilities in JAX, such as `jax.lax.while_loop` or `jax.lax.fori_loop`,
only support forward-mode differentiation. See
`JAX's Autodiff Cookbook <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html>`_
for more information.
:param bool regularize_mass_matrix: whether or not to regularize the estimated mass
matrix for numerical stability during warmup phase. Defaults to True. This flag
does not take effect if ``adapt_mass_matrix == False``.
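For context on the control-flow limitation mentioned in the new docstring above: reverse mode cannot differentiate through `jax.lax.while_loop`, while forward mode can. A minimal standalone sketch (my own illustration, not part of this PR):

import jax
import jax.numpy as jnp
from jax import lax

def f(x):
    # keep adding 1.0 until the running value exceeds 10; the trip count
    # depends on x, so this is a genuinely dynamic loop
    return lax.while_loop(lambda v: v < 10.0, lambda v: v + 1.0, x)

x = jnp.array(0.3)
print(jax.jacfwd(f)(x))  # forward mode works: 1.0

try:
    jax.grad(f)(x)  # reverse mode raises: while_loop has no reverse-mode rule
except Exception as e:
    print("reverse mode failed:", e)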
8 changes: 6 additions & 2 deletions numpyro/infer/hmc_util.py
@@ -241,9 +241,13 @@ def final_fn(state, regularize=False):

def _value_and_grad(f, x, forward_mode_differentiation=False):
if forward_mode_differentiation:
return f(x), jacfwd(f)(x)
def _wrapper(x):
out = f(x)
return out, out
grads, out = jacfwd(_wrapper, has_aux=True)(x)
return out, grads
else:
return value_and_grad(f)(x)
return value_and_grad(f, has_aux=False)(x)


def _kinetic_grad(kinetic_fn, inverse_mass_matrix, r):
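The wrapper in the new `_value_and_grad` above relies on the `has_aux` option of `jax.jacfwd`: the function returns its output a second time as an auxiliary value, so a single forward pass yields both the gradient and the loss. A small self-contained sketch of the same trick (illustrative only, not taken from the PR):

import jax.numpy as jnp
from jax import jacfwd

def value_and_grad_fwd(f, x):
    def _wrapper(x):
        out = f(x)
        # the first element is differentiated, the second is carried through as aux
        return out, out
    grads, out = jacfwd(_wrapper, has_aux=True)(x)
    return out, grads

loss = lambda x: jnp.sum(x ** 2)
x = jnp.array([1.0, 2.0, 3.0])
value, grads = value_and_grad_fwd(loss, x)
print(value)  # 14.0
print(grads)  # [2. 4. 6.]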
28 changes: 22 additions & 6 deletions numpyro/infer/svi.py
@@ -256,14 +256,16 @@ def get_params(self, svi_state):
params = self.constrain_fn(self.optim.get_params(svi_state.optim_state))
return params

def update(self, svi_state, *args, **kwargs):
def update(self, svi_state, *args, forward_mode_differentiation=False, **kwargs):
"""
Take a single step of SVI (possibly on a batch / minibatch of data),
using the optimizer.

:param svi_state: current state of SVI.
:param args: arguments to the model / guide (these can possibly vary during
the course of fitting).
:param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
Defaults to False.
:param kwargs: keyword arguments to the model / guide (these can possibly vary
during the course of fitting).
:return: tuple of `(svi_state, loss)`.
@@ -281,18 +283,20 @@ def update(self, svi_state, *args, **kwargs):
mutable_state=svi_state.mutable_state,
)
(loss_val, mutable_state), optim_state = self.optim.eval_and_update(
loss_fn, svi_state.optim_state
loss_fn, svi_state.optim_state, forward_mode_differentiation=forward_mode_differentiation
)
return SVIState(optim_state, mutable_state, rng_key), loss_val

def stable_update(self, svi_state, *args, **kwargs):
def stable_update(self, svi_state, *args, forward_mode_differentiation=False, **kwargs):
"""
Similar to :meth:`update` but returns the current state if the
loss or the new state contains invalid values.

:param svi_state: current state of SVI.
:param args: arguments to the model / guide (these can possibly vary during
the course of fitting).
:param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
Defaults to False.
:param kwargs: keyword arguments to the model / guide (these can possibly vary
during the course of fitting).
:return: tuple of `(svi_state, loss)`.
@@ -310,7 +314,7 @@ def stable_update(self, svi_state, *args, **kwargs):
mutable_state=svi_state.mutable_state,
)
(loss_val, mutable_state), optim_state = self.optim.eval_and_stable_update(
loss_fn, svi_state.optim_state
loss_fn, svi_state.optim_state, forward_mode_differentiation=forward_mode_differentiation
)
return SVIState(optim_state, mutable_state, rng_key), loss_val

@@ -321,6 +325,7 @@ def run(
*args,
progress_bar=True,
stable_update=False,
forward_mode_differentiation=False,
init_state=None,
init_params=None,
**kwargs,
@@ -342,6 +347,13 @@
``True``.
:param bool stable_update: whether to use :meth:`stable_update` to update
the state. Defaults to False.
:param bool forward_mode_differentiation: whether to use forward-mode differentiation
or reverse-mode differentiation. By default, we use reverse mode, but forward
mode can be useful in some cases to improve performance. In addition, some
control flow utilities in JAX, such as `jax.lax.while_loop` or `jax.lax.fori_loop`,
only support forward-mode differentiation. See
`JAX's Autodiff Cookbook <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html>`_
for more information.
:param SVIState init_state: if not None, begin SVI from the
final state of previous SVI run. Usage::

@@ -365,9 +377,13 @@

def body_fn(svi_state, _):
if stable_update:
svi_state, loss = self.stable_update(svi_state, *args, **kwargs)
svi_state, loss = self.stable_update(
svi_state, *args, forward_mode_differentiation=forward_mode_differentiation, **kwargs
)
else:
svi_state, loss = self.update(svi_state, *args, **kwargs)
svi_state, loss = self.update(
svi_state, *args, forward_mode_differentiation=forward_mode_differentiation, **kwargs
)
return svi_state, loss

if init_state is None:
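Taken together, these changes thread `forward_mode_differentiation` from `SVI.run` through `update` / `stable_update` down to the optimizer. As a usage sketch, a hand-rolled training loop can pass the flag to `update` directly; the model, guide, and data below are placeholders of my own, not part of the PR:

from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO

def model(data):
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=data)

def guide(data):
    loc = numpyro.param("loc", 0.0)
    scale = numpyro.param("scale", 1.0, constraint=dist.constraints.positive)
    numpyro.sample("mu", dist.Normal(loc, scale))

data = random.normal(random.PRNGKey(1), (100,)) + 3.0
svi = SVI(model, guide, numpyro.optim.Adam(step_size=0.01), loss=Trace_ELBO())

svi_state = svi.init(random.PRNGKey(0), data)
for _ in range(500):
    # ELBO gradients are computed with jacfwd instead of reverse mode
    svi_state, loss = svi.update(svi_state, data, forward_mode_differentiation=True)
print(svi.get_params(svi_state)["loc"])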
33 changes: 27 additions & 6 deletions numpyro/optim.py
@@ -11,7 +11,7 @@
from collections.abc import Callable
from typing import Any, TypeVar

from jax import lax, value_and_grad
from jax import jacfwd, lax, value_and_grad
from jax.example_libraries import optimizers
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
@@ -34,6 +34,15 @@
_OptState = TypeVar("_OptState")
_IterOptState = tuple[int, _OptState]

def _value_and_grad(f, x, forward_mode_differentiation=False):
if forward_mode_differentiation:
def _wrapper(x):
out, aux = f(x)
return out, (out, aux)
grads, (out, aux) = jacfwd(_wrapper, has_aux=True)(x)
return (out, aux), grads
else:
return value_and_grad(f, has_aux=True)(x)

class _NumPyroOptim(object):
def __init__(self, optim_fn: Callable, *args, **kwargs) -> None:
@@ -61,7 +70,9 @@ def update(self, g: _Params, state: _IterOptState) -> _IterOptState:
opt_state = self.update_fn(i, g, opt_state)
return i + 1, opt_state

def eval_and_update(self, fn: Callable[[Any], tuple], state: _IterOptState):
def eval_and_update(
self, fn: Callable[[Any], tuple], state: _IterOptState, forward_mode_differentiation: bool = False
):
"""
Performs an optimization step for the objective function `fn`.
For most optimizers, the update is performed based on the gradient
@@ -74,24 +85,32 @@ def eval_and_update(self, fn: Callable[[Any], tuple], state: _IterOptState):
is a scalar loss function to be differentiated and the second item
is an auxiliary output.
:param state: current optimizer state.
:param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
:return: a pair of the output of objective function and the new optimizer state.
"""
params = self.get_params(state)
(out, aux), grads = value_and_grad(fn, has_aux=True)(params)
(out, aux), grads = _value_and_grad(
fn, x=params, forward_mode_differentiation=forward_mode_differentiation
)
return (out, aux), self.update(grads, state)

def eval_and_stable_update(self, fn: Callable[[Any], tuple], state: _IterOptState):
def eval_and_stable_update(
self, fn: Callable[[Any], tuple], state: _IterOptState, forward_mode_differentiation: bool = False
):
"""
Like :meth:`eval_and_update` but when the value of the objective function
or the gradients are not finite, we will not update the input `state`
and will set the objective output to `nan`.

:param fn: objective function.
:param state: current optimizer state.
:param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
:return: a pair of the output of objective function and the new optimizer state.
"""
params = self.get_params(state)
(out, aux), grads = value_and_grad(fn, has_aux=True)(params)
(out, aux), grads = _value_and_grad(
fn, x=params, forward_mode_differentiation=forward_mode_differentiation
)
out, state = lax.cond(
jnp.isfinite(out) & jnp.isfinite(ravel_pytree(grads)[0]).all(),
lambda _: (out, self.update(grads, state)),
@@ -266,7 +285,9 @@ def __init__(self, method="BFGS", **kwargs):
self._method = method
self._kwargs = kwargs

def eval_and_update(self, fn: Callable[[Any], tuple], state: _IterOptState):
def eval_and_update(
self, fn: Callable[[Any], tuple], state: _IterOptState, forward_mode_differentiation=False
):
i, (flat_params, unravel_fn) = state

def loss_fn(x):
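The finiteness guard visible at the end of `eval_and_stable_update` follows a common JAX pattern: flatten the gradient pytree, check that the loss and every gradient entry are finite, and use `lax.cond` to either apply the update or keep the parameters unchanged. A self-contained sketch of that pattern (the names and the toy SGD update are mine, not NumPyro's API):

import jax.numpy as jnp
from jax import lax
from jax.flatten_util import ravel_pytree

def guarded_step(loss, grads, params, step_fn):
    # apply step_fn only when the loss and all gradient entries are finite;
    # otherwise keep params unchanged and report a NaN loss
    flat_grads, _ = ravel_pytree(grads)
    ok = jnp.isfinite(loss) & jnp.isfinite(flat_grads).all()
    return lax.cond(
        ok,
        lambda _: (loss, step_fn(grads, params)),
        lambda _: (jnp.full_like(loss, jnp.nan), params),
        None,
    )

params = {"w": jnp.ones(3)}
grads = {"w": jnp.array([0.1, jnp.inf, 0.2])}  # one non-finite gradient entry
sgd = lambda g, p: {"w": p["w"] - 0.1 * g["w"]}
loss = jnp.asarray(1.5, dtype=jnp.float32)
new_loss, new_params = guarded_step(loss, grads, params, sgd)
print(new_loss)          # nan  -> the step was rejected
print(new_params["w"])   # [1. 1. 1.]  -> unchanged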
19 changes: 18 additions & 1 deletion test/infer/test_svi.py
@@ -8,7 +8,7 @@
import pytest

import jax
from jax import jit, random, value_and_grad
from jax import jit, lax, random, value_and_grad
from jax.example_libraries import optimizers
import jax.numpy as jnp
from jax.tree_util import tree_all, tree_map
@@ -757,3 +757,20 @@ def guide():
params = svi_results.params
assert_allclose(params["loc"], actual_loc, rtol=0.1)
assert_allclose(params["scale"], actual_scale, rtol=0.1)


def test_forward_mode_differentiation():
def model():
x = numpyro.sample("x", dist.Normal(0, 1))
y = lax.while_loop(lambda x: x < 10, lambda x: x + 1, x)
numpyro.sample("obs", dist.Normal(y, 1), obs=1.0)

def guide():
loc = numpyro.param("loc", 0.)
scale = numpyro.param("scale", 1., constraint=dist.constraints.positive)
numpyro.sample("x", dist.Normal(loc, scale))

# reverse-mode differentiation fails here because the model contains lax.while_loop
optimizer = numpyro.optim.Adam(step_size=0.01)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi.run(random.PRNGKey(0), 1000, forward_mode_differentiation=True)