diff --git a/numpyro/infer/hmc.py b/numpyro/infer/hmc.py
index f3b31adf0..aa2fcd802 100644
--- a/numpyro/infer/hmc.py
+++ b/numpyro/infer/hmc.py
@@ -265,6 +265,13 @@ def init_kernel(
             `d2` is the max tree depth during post warmup phase.
         :param bool find_heuristic_step_size: whether or not to use a heuristic function
             to adjust the step size at the beginning of each adaptation window. Defaults to False.
+        :param bool forward_mode_differentiation: whether to use forward-mode differentiation
+            or reverse-mode differentiation. By default, we use reverse mode, but forward
+            mode can be useful in some cases to improve performance. In addition, some
+            JAX control flow utilities such as `jax.lax.while_loop` or `jax.lax.fori_loop`
+            only support forward-mode differentiation. See
+            `JAX's The Autodiff Cookbook <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html>`_
+            for more information.
        :param bool regularize_mass_matrix: whether or not to regularize the estimated mass
            matrix for numerical stability during warmup phase. Defaults to True. This flag
            does not take effect if ``adapt_mass_matrix == False``.
diff --git a/numpyro/infer/hmc_util.py b/numpyro/infer/hmc_util.py
index f3b2d81ee..07f36a81a 100644
--- a/numpyro/infer/hmc_util.py
+++ b/numpyro/infer/hmc_util.py
@@ -241,9 +241,13 @@ def final_fn(state, regularize=False):

 def _value_and_grad(f, x, forward_mode_differentiation=False):
     if forward_mode_differentiation:
-        return f(x), jacfwd(f)(x)
+        def _wrapper(x):
+            out = f(x)
+            return out, out
+        grads, out = jacfwd(_wrapper, has_aux=True)(x)
+        return out, grads
     else:
-        return value_and_grad(f)(x)
+        return value_and_grad(f, has_aux=False)(x)


 def _kinetic_grad(kinetic_fn, inverse_mass_matrix, r):
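Note: the forward-mode branch of `_value_and_grad` above uses `jacfwd(..., has_aux=True)` to return the primal value together with the gradient in a single forward pass, mirroring what `value_and_grad` does in reverse mode. A minimal standalone sketch of that pattern (illustrative only, not part of the patch; `fwd_value_and_grad` is a hypothetical name):

import jax.numpy as jnp
from jax import jacfwd, value_and_grad


def fwd_value_and_grad(f, x):
    # wrap f so its output is also carried through as the `aux` value;
    # jacfwd(..., has_aux=True) then returns (gradient, value) in one pass
    def _wrapper(x):
        out = f(x)
        return out, out

    grads, out = jacfwd(_wrapper, has_aux=True)(x)
    return out, grads


f = lambda x: jnp.sum(x**2)
x = jnp.arange(3.0)
value_fwd, grad_fwd = fwd_value_and_grad(f, x)
value_rev, grad_rev = value_and_grad(f)(x)
print(jnp.allclose(value_fwd, value_rev), jnp.allclose(grad_fwd, grad_rev))  # True True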
diff --git a/numpyro/infer/svi.py b/numpyro/infer/svi.py
index 4b99302cc..77f05255f 100644
--- a/numpyro/infer/svi.py
+++ b/numpyro/infer/svi.py
@@ -256,7 +256,7 @@ def get_params(self, svi_state):
         params = self.constrain_fn(self.optim.get_params(svi_state.optim_state))
         return params

-    def update(self, svi_state, *args, **kwargs):
+    def update(self, svi_state, *args, forward_mode_differentiation=False, **kwargs):
         """
         Take a single step of SVI (possibly on a batch / minibatch of data),
         using the optimizer.
@@ -264,6 +264,8 @@ def update(self, svi_state, *args, **kwargs):
         :param svi_state: current state of SVI.
         :param args: arguments to the model / guide (these can possibly vary during
             the course of fitting).
+        :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
+            Defaults to False.
         :param kwargs: keyword arguments to the model / guide (these can possibly vary
             during the course of fitting).
         :return: tuple of `(svi_state, loss)`.
@@ -281,11 +283,11 @@ def update(self, svi_state, *args, **kwargs):
             mutable_state=svi_state.mutable_state,
         )
         (loss_val, mutable_state), optim_state = self.optim.eval_and_update(
-            loss_fn, svi_state.optim_state
+            loss_fn, svi_state.optim_state, forward_mode_differentiation=forward_mode_differentiation
         )
         return SVIState(optim_state, mutable_state, rng_key), loss_val

-    def stable_update(self, svi_state, *args, **kwargs):
+    def stable_update(self, svi_state, *args, forward_mode_differentiation=False, **kwargs):
         """
         Similar to :meth:`update` but returns the current state if the loss
         or the new state contains invalid values.
@@ -293,6 +295,8 @@ def stable_update(self, svi_state, *args, **kwargs):
         :param svi_state: current state of SVI.
         :param args: arguments to the model / guide (these can possibly vary during
             the course of fitting).
+        :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
+            Defaults to False.
         :param kwargs: keyword arguments to the model / guide (these can possibly vary
             during the course of fitting).
         :return: tuple of `(svi_state, loss)`.
@@ -310,7 +314,7 @@ def stable_update(self, svi_state, *args, **kwargs):
             mutable_state=svi_state.mutable_state,
         )
         (loss_val, mutable_state), optim_state = self.optim.eval_and_stable_update(
-            loss_fn, svi_state.optim_state
+            loss_fn, svi_state.optim_state, forward_mode_differentiation=forward_mode_differentiation
         )
         return SVIState(optim_state, mutable_state, rng_key), loss_val
@@ -321,6 +325,7 @@ def run(
         *args,
         progress_bar=True,
         stable_update=False,
+        forward_mode_differentiation=False,
         init_state=None,
         init_params=None,
         **kwargs,
@@ -342,6 +347,13 @@ def run(
             ``True``.
         :param bool stable_update: whether to use :meth:`stable_update` to update
             the state. Defaults to False.
+        :param bool forward_mode_differentiation: whether to use forward-mode differentiation
+            or reverse-mode differentiation. By default, we use reverse mode, but forward
+            mode can be useful in some cases to improve performance. In addition, some
+            JAX control flow utilities such as `jax.lax.while_loop` or `jax.lax.fori_loop`
+            only support forward-mode differentiation. See
+            `JAX's The Autodiff Cookbook <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html>`_
+            for more information.
         :param SVIState init_state: if not None, begin SVI from the final state
             of previous SVI run. Usage::
@@ -365,9 +377,13 @@ def run(

         def body_fn(svi_state, _):
             if stable_update:
-                svi_state, loss = self.stable_update(svi_state, *args, **kwargs)
+                svi_state, loss = self.stable_update(
+                    svi_state, *args, forward_mode_differentiation=forward_mode_differentiation, **kwargs
+                )
             else:
-                svi_state, loss = self.update(svi_state, *args, **kwargs)
+                svi_state, loss = self.update(
+                    svi_state, *args, forward_mode_differentiation=forward_mode_differentiation, **kwargs
+                )
             return svi_state, loss

         if init_state is None:
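The reason `SVI.update`, `SVI.stable_update`, and `SVI.run` gain a `forward_mode_differentiation` flag is that reverse-mode autodiff cannot differentiate through `jax.lax.while_loop`, while forward mode can. A small sketch of the underlying JAX behavior (illustrative only, not part of the patch):

from jax import grad, jacfwd, lax


def f(x):
    # keep adding 1.0 to the carry until it reaches 10.0
    return lax.while_loop(lambda v: v < 10.0, lambda v: v + 1.0, x)


try:
    grad(f)(1.0)  # reverse mode: JAX raises (no transpose rule for while_loop)
except Exception as e:  # a ValueError in current JAX versions
    print("reverse mode failed:", e)

print(jacfwd(f)(1.0))  # forward mode works; the derivative is 1.0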
diff --git a/numpyro/optim.py b/numpyro/optim.py
index 8a3b78149..5ad451bab 100644
--- a/numpyro/optim.py
+++ b/numpyro/optim.py
@@ -11,7 +11,7 @@
 from collections.abc import Callable
 from typing import Any, TypeVar

-from jax import lax, value_and_grad
+from jax import jacfwd, lax, value_and_grad
 from jax.example_libraries import optimizers
 from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
@@ -34,6 +34,15 @@
 _OptState = TypeVar("_OptState")
 _IterOptState = tuple[int, _OptState]

+def _value_and_grad(f, x, forward_mode_differentiation=False):
+    if forward_mode_differentiation:
+        def _wrapper(x):
+            out, aux = f(x)
+            return out, (out, aux)
+        grads, (out, aux) = jacfwd(_wrapper, has_aux=True)(x)
+        return (out, aux), grads
+    else:
+        return value_and_grad(f, has_aux=True)(x)

 class _NumPyroOptim(object):
     def __init__(self, optim_fn: Callable, *args, **kwargs) -> None:
@@ -61,7 +70,9 @@ def update(self, g: _Params, state: _IterOptState) -> _IterOptState:
         opt_state = self.update_fn(i, g, opt_state)
         return i + 1, opt_state

-    def eval_and_update(self, fn: Callable[[Any], tuple], state: _IterOptState):
+    def eval_and_update(
+        self, fn: Callable[[Any], tuple], state: _IterOptState, forward_mode_differentiation: bool = False
+    ):
         """
         Performs an optimization step for the objective function `fn`.
         For most optimizers, the update is performed based on the gradient
@@ -74,13 +85,18 @@ def eval_and_update(self, fn: Callable[[Any], tuple], state: _IterOptState):
             is a scalar loss function to be differentiated and the second item
             is an auxiliary output.
         :param state: current optimizer state.
+        :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
         :return: a pair of the output of objective function and the new optimizer state.
         """
         params = self.get_params(state)
-        (out, aux), grads = value_and_grad(fn, has_aux=True)(params)
+        (out, aux), grads = _value_and_grad(
+            fn, x=params, forward_mode_differentiation=forward_mode_differentiation
+        )
         return (out, aux), self.update(grads, state)

-    def eval_and_stable_update(self, fn: Callable[[Any], tuple], state: _IterOptState):
+    def eval_and_stable_update(
+        self, fn: Callable[[Any], tuple], state: _IterOptState, forward_mode_differentiation: bool = False
+    ):
         """
         Like :meth:`eval_and_update` but when the value of the objective function
         or the gradients are not finite, we will not update the input `state`
@@ -88,10 +104,13 @@ def eval_and_stable_update(self, fn: Callable[[Any], tuple], state: _IterOptStat

         :param fn: objective function.
         :param state: current optimizer state.
+        :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
         :return: a pair of the output of objective function and the new optimizer state.
         """
         params = self.get_params(state)
-        (out, aux), grads = value_and_grad(fn, has_aux=True)(params)
+        (out, aux), grads = _value_and_grad(
+            fn, x=params, forward_mode_differentiation=forward_mode_differentiation
+        )
         out, state = lax.cond(
             jnp.isfinite(out) & jnp.isfinite(ravel_pytree(grads)[0]).all(),
             lambda _: (out, self.update(grads, state)),
@@ -266,7 +285,9 @@ def __init__(self, method="BFGS", **kwargs):
         self._method = method
         self._kwargs = kwargs

-    def eval_and_update(self, fn: Callable[[Any], tuple], state: _IterOptState):
+    def eval_and_update(
+        self, fn: Callable[[Any], tuple], state: _IterOptState, forward_mode_differentiation=False
+    ):
         i, (flat_params, unravel_fn) = state

         def loss_fn(x):
diff --git a/test/infer/test_svi.py b/test/infer/test_svi.py
index 0e62bd395..65a1751ff 100644
--- a/test/infer/test_svi.py
+++ b/test/infer/test_svi.py
@@ -8,7 +8,7 @@
 import pytest

 import jax
-from jax import jit, random, value_and_grad
+from jax import jit, lax, random, value_and_grad
 from jax.example_libraries import optimizers
 import jax.numpy as jnp
 from jax.tree_util import tree_all, tree_map
@@ -757,3 +757,20 @@ def guide():
     params = svi_results.params
     assert_allclose(params["loc"], actual_loc, rtol=0.1)
     assert_allclose(params["scale"], actual_scale, rtol=0.1)
+
+
+def test_forward_mode_differentiation():
+    def model():
+        x = numpyro.sample("x", dist.Normal(0, 1))
+        y = lax.while_loop(lambda x: x < 10, lambda x: x + 1, x)
+        numpyro.sample("obs", dist.Normal(y, 1), obs=1.0)
+
+    def guide():
+        loc = numpyro.param("loc", 0.0)
+        scale = numpyro.param("scale", 1.0, constraint=dist.constraints.positive)
+        numpyro.sample("x", dist.Normal(loc, scale))
+
+    # this fails in reverse mode because lax.while_loop is not reverse-mode differentiable
+    optimizer = numpyro.optim.Adam(step_size=0.01)
+    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
+    svi.run(random.PRNGKey(0), 1000, forward_mode_differentiation=True)
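For completeness, the optimizer-level entry point added in numpyro/optim.py can also be driven directly. A sketch under the signature introduced by this patch (illustrative only; `loss_fn` and its shapes are made up, and `SVI` normally calls `eval_and_update` for you):

import jax.numpy as jnp
from jax import lax
from numpyro import optim


def loss_fn(params):
    # eval_and_update expects a (loss, aux) pair; the while_loop makes the loss
    # reverse-mode non-differentiable, so forward mode is required here
    shifted = lax.while_loop(lambda v: jnp.all(v < 10.0), lambda v: v + 1.0, params["x"])
    return jnp.sum(shifted**2), None


optimizer = optim.Adam(step_size=0.1)
state = optimizer.init({"x": jnp.zeros(3)})
(loss, _), state = optimizer.eval_and_update(
    loss_fn, state, forward_mode_differentiation=True
)
print(loss, optimizer.get_params(state))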