diff --git a/Makefile b/Makefile index fcdad2f4b..468e9c391 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ install: FORCE pip install -e .[dev,doc,test,examples] doctest: FORCE - $(MAKE) -C docs doctest + JAX_PLATFORM_NAME=cpu $(MAKE) -C docs doctest test: lint FORCE pytest -v test diff --git a/examples/covtype.py b/examples/covtype.py index f00c1b0b1..e62867ea1 100644 --- a/examples/covtype.py +++ b/examples/covtype.py @@ -174,7 +174,8 @@ def benchmark_hmc(args, features, labels): subsample_size = 1000 guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8]) svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO()) - params, losses = svi.run(random.PRNGKey(2), 2000, features, labels) + svi_result = svi.run(random.PRNGKey(2), 2000, features, labels) + params, losses = svi_result.params, svi_result.losses plt.plot(losses) plt.show() diff --git a/examples/hmcecs.py b/examples/hmcecs.py index c6abdc4ed..010898009 100644 --- a/examples/hmcecs.py +++ b/examples/hmcecs.py @@ -50,9 +50,8 @@ def run_hmcecs(hmcecs_key, args, data, obs, inner_kernel): optimizer = numpyro.optim.Adam(step_size=1e-3) guide = autoguide.AutoDelta(model) svi = SVI(model, guide, optimizer, loss=Trace_ELBO()) - params, losses = svi.run( - svi_key, args.num_svi_steps, data, obs, args.subsample_size - ) + svi_result = svi.run(svi_key, args.num_svi_steps, data, obs, args.subsample_size) + params, losses = svi_result.params, svi_result.losses ref_params = {"theta": params["theta_auto_loc"]} # taylor proxy estimates log likelihood (ll) by diff --git a/numpyro/contrib/control_flow/cond.py b/numpyro/contrib/control_flow/cond.py index 54c0b20d1..338f90309 100644 --- a/numpyro/contrib/control_flow/cond.py +++ b/numpyro/contrib/control_flow/cond.py @@ -111,7 +111,7 @@ def cond(pred, true_fun, false_fun, operand): ... return cond(cluster > 0, true_fun, false_fun, None) >>> >>> svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO(num_particles=100)) - >>> params, losses = svi.run(random.PRNGKey(0), num_steps=2500) + >>> svi_result = svi.run(random.PRNGKey(0), num_steps=2500) .. warning:: This is an experimental utility function that allows users to use JAX control flow with NumPyro's effect handlers. Currently, `sample` and diff --git a/numpyro/contrib/module.py b/numpyro/contrib/module.py index ffdad805f..b6f5114bc 100644 --- a/numpyro/contrib/module.py +++ b/numpyro/contrib/module.py @@ -5,10 +5,12 @@ from copy import deepcopy from functools import partial -from jax import numpy as jnp +from jax import random +import jax.numpy as jnp from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten import numpyro +from numpyro.primitives import mutable as numpyro_mutable __all__ = [ "flax_module", @@ -18,16 +20,40 @@ ] -def flax_module(name, nn_module, *, input_shape=None, **kwargs): +def flax_module( + name, nn_module, *, input_shape=None, apply_rng=None, mutable=None, **kwargs +): """ Declare a :mod:`~flax` style neural network inside a model so that its parameters are registered for optimization via :func:`~numpyro.primitives.param` statements. + Given a flax ``nn_module``, in flax to evaluate the module with + a given set of parameters, we use: ``nn_module.apply(params, x)``. + In a NumPyro model, the pattern will be:: + + net = flax_module("net", nn_module) + y = net(x) + + or with dropout layers:: + + net = flax_module("net", nn_module, apply_rng=["dropout"]) + rng_key = numpyro.prng_key() + y = net(x, rngs={"dropout": rng_key}) + :param str name: name of the module to be registered. 
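Note on the covtype.py and hmcecs.py hunks above: `svi.run` now returns an `SVIRunResult` namedtuple instead of a `(params, losses)` pair, so call sites unpack the fields explicitly. A minimal migration sketch with an illustrative toy model::

    from jax import random
    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import SVI, Trace_ELBO, autoguide

    def model(data):
        loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
        numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

    guide = autoguide.AutoNormal(model)
    svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(0), 1000, jnp.ones(10))
    # before this change: params, losses = svi.run(...)
    params, losses = svi_result.params, svi_result.losses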
- :param flax.nn.Module nn_module: a `flax` Module which has .init and .apply methods + :param flax.linen.Module nn_module: a `flax` Module which has .init and .apply methods :param tuple input_shape: shape of the input taken by the neural network. + :param list apply_rng: A list to indicate which extra rng _kinds_ are needed for + ``nn_module``. For example, when ``nn_module`` includes dropout layers, we + need to set ``apply_rng=["dropout"]``. Defaults to None, which means no extra + rng key is needed. Please see + `Flax Linen Intro `_ + for more information in how Flax deals with stochastic layers like dropout. + :param list mutable: A list to indicate mutable states of ``nn_module``. For example, + if your module has BatchNorm layer, we will need to define ``mutable=["batch_stats"]``. + See the above `Flax Linen Intro` tutorial for more information. :param kwargs: optional keyword arguments to initialize flax neural network as an alternative to `input_shape` :return: a callable with bound parameters that takes an array @@ -45,28 +71,84 @@ def flax_module(name, nn_module, *, input_shape=None, **kwargs): ) from e module_key = name + "$params" nn_params = numpyro.param(module_key) + + if mutable: + nn_state = numpyro_mutable(name + "$state") + assert nn_state is None or isinstance(nn_state, dict) + assert (nn_state is None) == (nn_params is None) + if nn_params is None: - args = (jnp.ones(input_shape),) if input_shape is not None else () # feed in dummy data to init params + args = (jnp.ones(input_shape),) if input_shape is not None else () rng_key = numpyro.prng_key() - _, nn_params = nn_module.init(rng_key, *args, **kwargs) + # split rng_key into a dict of rng_kind: rng_key + rngs = {} + if apply_rng: + assert isinstance(apply_rng, list) + for kind in apply_rng: + rng_key, subkey = random.split(rng_key) + rngs[kind] = subkey + rngs["params"] = rng_key + + nn_vars = flax.core.unfreeze(nn_module.init(rngs, *args, **kwargs)) + if "params" not in nn_vars: + raise ValueError( + "Your nn_module does not have any parameter. Currently, it is not" + " supported in NumPyro. Please make a github issue if you need" + " that feature." + ) + nn_params = nn_vars["params"] + if mutable: + nn_state = {k: v for k, v in nn_vars.items() if k != "params"} + assert set(mutable) == set(nn_state) + numpyro_mutable(name + "$state", nn_state) # make sure that nn_params keep the same order after unflatten params_flat, tree_def = tree_flatten(nn_params) nn_params = tree_unflatten(tree_def, params_flat) numpyro.param(module_key, nn_params) - return partial(nn_module.call, nn_params) + def apply_with_state(params, *args, **kwargs): + params = {"params": params, **nn_state} + out, new_state = nn_module.apply(params, mutable=mutable, *args, **kwargs) + nn_state.update(**new_state) + return out + + def apply_without_state(params, *args, **kwargs): + return nn_module.apply({"params": params}, *args, **kwargs) + + apply_fn = apply_with_state if mutable else apply_without_state + return partial(apply_fn, nn_params) -def haiku_module(name, nn_module, *, input_shape=None, **kwargs): + +def haiku_module(name, nn_module, *, input_shape=None, apply_rng=False, **kwargs): """ Declare a :mod:`~haiku` style neural network inside a model so that its parameters are registered for optimization via :func:`~numpyro.primitives.param` statements. + Given a haiku ``nn_module``, in haiku to evaluate the module with + a given set of parameters, we use: ``nn_module.apply(params, None, x)``. 
+ In a NumPyro model, the pattern will be:: + + net = haiku_module("net", nn_module) + y = net(x) # or y = net(rng_key, x) + + or with dropout layers:: + + net = haiku_module("net", nn_module, apply_rng=True) + rng_key = numpyro.prng_key() + y = net(rng_key, x) + :param str name: name of the module to be registered. - :param haiku.Module nn_module: a `haiku` Module which has .init and .apply methods + :param nn_module: a `haiku` Module which has .init and .apply methods + :type nn_module: haiku.Transformed or haiku.TransformedWithState :param tuple input_shape: shape of the input taken by the neural network. + :param bool apply_rng: A flag to indicate if the returned callable requires + an rng argument (e.g. when ``nn_module`` includes dropout layers). Defaults + to False, which means no rng argument is needed. If this is True, the signature + of the returned callable ``nn = haiku_module(..., apply_rng=True)`` will be + ``nn(rng_key, x)`` (rather than ``nn(x)``). :param kwargs: optional keyword arguments to initialize flax neural network as an alternative to `input_shape` :return: a callable with bound parameters that takes an array @@ -74,7 +156,7 @@ def haiku_module(name, nn_module, *, input_shape=None, **kwargs): array. """ try: - import haiku # noqa: F401 + import haiku as hk # noqa: F401 except ImportError as e: raise ImportError( "Looking like you want to use haiku to declare " @@ -83,21 +165,42 @@ def haiku_module(name, nn_module, *, input_shape=None, **kwargs): "It can be installed with `pip install dm-haiku`." ) from e + if not apply_rng: + nn_module = hk.without_apply_rng(nn_module) + module_key = name + "$params" nn_params = numpyro.param(module_key) + with_state = isinstance(nn_module, hk.TransformedWithState) + if with_state: + nn_state = numpyro_mutable(name + "$state") + assert nn_state is None or isinstance(nn_state, dict) + assert (nn_state is None) == (nn_params is None) + if nn_params is None: args = (jnp.ones(input_shape),) if input_shape is not None else () # feed in dummy data to init params rng_key = numpyro.prng_key() - nn_params = nn_module.init(rng_key, *args, **kwargs) + if with_state: + nn_params, nn_state = nn_module.init(rng_key, *args, **kwargs) + nn_state = dict(nn_state) + numpyro_mutable(name + "$state", nn_state) + else: + nn_params = nn_module.init(rng_key, *args, **kwargs) # haiku init returns an immutable dict - nn_params = haiku.data_structures.to_mutable_dict(nn_params) + nn_params = hk.data_structures.to_mutable_dict(nn_params) # we cast it to a mutable one to be able to set priors for parameters # make sure that nn_params keep the same order after unflatten params_flat, tree_def = tree_flatten(nn_params) nn_params = tree_unflatten(tree_def, params_flat) numpyro.param(module_key, nn_params) - return partial(nn_module.apply, nn_params, None) + + def apply_with_state(params, *args, **kwargs): + out, new_state = nn_module.apply(params, nn_state, *args, **kwargs) + nn_state.update(**new_state) + return out + + apply_fn = apply_with_state if with_state else nn_module.apply + return partial(apply_fn, nn_params) # register an "empty" parameter which only stores its shape @@ -133,7 +236,9 @@ def _update_params(params, new_params, prior, prefix=""): ) -def random_flax_module(name, nn_module, prior, *, input_shape=None, **kwargs): +def random_flax_module( + name, nn_module, prior, *, input_shape=None, apply_rng=None, mutable=None, **kwargs +): """ A primitive to place a prior over the parameters of the Flax module `nn_module`. 
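A minimal sketch exercising both new options on `flax_module` described above; the Linen module here is illustrative, combining a dropout layer (extra rng kind) and a batch-norm layer (mutable collection)::

    import flax.linen as nn
    import numpyro
    from numpyro.contrib.module import flax_module

    class Net(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = nn.Dense(4)(x)
            x = nn.Dropout(0.5, deterministic=False)(x)
            return nn.BatchNorm(use_running_average=False, momentum=0.999)(x)

    def model(x):
        net = flax_module(
            "net",
            Net(),
            apply_rng=["dropout"],    # dropout draws from an extra rng kind
            mutable=["batch_stats"],  # running stats live at the "net$state" mutable site
            input_shape=(3, 2),
        )
        y = net(x, rngs={"dropout": numpyro.prng_key()})
        return numpyro.deterministic("y", y)

The haiku counterpart is analogous: pass `apply_rng=True` for stochastic layers and wrap the network with `hk.transform_with_state` when it carries state.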
@@ -141,30 +246,41 @@ def random_flax_module(name, nn_module, prior, *, input_shape=None, **kwargs): Parameters of a Flax module are stored in a nested dict. For example, the module `B` defined as follows:: - class A(nn.Module): - def apply(self, x): - return nn.Dense(x, 1, bias=False, name='dense') + class A(flax.linen.Module): + @flax.linen.compact + def __call__(self, x): + return nn.Dense(1, use_bias=False, name='dense')(x) - class B(nn.Module): - def apply(self, x): - return A(x, name='inner') + class B(flax.linen.Module): + @flax.linen.compact + def __call__(self, x): + return A(name='inner')(x) has parameters `{'inner': {'dense': {'kernel': param_value}}}`. In the argument `prior`, to specify `kernel` parameter, we join the path to it using dots: `prior={"inner.dense.kernel": param_prior}`. :param str name: name of NumPyro module - :param flax.nn.Module: the module to be registered with NumPyro + :param flax.linen.Module: the module to be registered with NumPyro :param prior: a NumPyro distribution or a Python dict with parameter names as keys and respective distributions as values. For example:: net = random_flax_module("net", - flax.nn.Dense.partial(features=1), + flax.linen.Dense(features=1), prior={"bias": dist.Cauchy(), "kernel": dist.Normal()}, input_shape=(4,)) - :type param: dict or ~numpyro.distributions.Distribution + :type prior: dict or ~numpyro.distributions.Distribution :param tuple input_shape: shape of the input taken by the neural network. + :param list apply_rng: A list to indicate which extra rng _kinds_ are needed for + ``nn_module``. For example, when ``nn_module`` includes dropout layers, we + need to set ``apply_rng=["dropout"]``. Defaults to None, which means no extra + rng key is needed. Please see + `Flax Linen Intro `_ + for more information in how Flax deals with stochastic layers like dropout. + :param list mutable: A list to indicate mutable states of ``nn_module``. For example, + if your module has BatchNorm layer, we will need to define ``mutable=["batch_stats"]``. + See the above `Flax Linen Intro` tutorial for more information. :param kwargs: optional keyword arguments to initialize flax neural network as an alternative to `input_shape` :returns: a sampled module @@ -176,30 +292,33 @@ def apply(self, x): # NB: this example is ported from https://github.com/ctallec/pyvarinf/blob/master/main_regression.ipynb >>> import numpy as np; np.random.seed(0) >>> import tqdm - >>> from flax import nn + >>> from flax import linen as nn >>> from jax import jit, random >>> import numpyro >>> import numpyro.distributions as dist >>> from numpyro.contrib.module import random_flax_module >>> from numpyro.infer import Predictive, SVI, TraceMeanField_ELBO, autoguide, init_to_feasible - >>> + ... >>> class Net(nn.Module): - ... def apply(self, x, n_units): - ... x = nn.Dense(x[..., None], features=n_units) + ... n_units: int + ... + ... @nn.compact + ... def __call__(self, x): + ... x = nn.Dense(self.n_units)(x[..., None]) ... x = nn.relu(x) - ... x = nn.Dense(x, features=n_units) + ... x = nn.Dense(self.n_units)(x) ... x = nn.relu(x) - ... mean = nn.Dense(x, features=1) - ... rho = nn.Dense(x, features=1) + ... mean = nn.Dense(1)(x) + ... rho = nn.Dense(1)(x) ... return mean.squeeze(), rho.squeeze() - >>> + ... >>> def generate_data(n_samples): ... x = np.random.normal(size=n_samples) ... y = np.cos(x * 3) + np.random.normal(size=n_samples) * np.abs(x) / 2 ... return x, y - >>> + ... >>> def model(x, y=None, batch_size=None): - ... 
module = Net.partial(n_units=32) + ... module = Net(n_units=32) ... net = random_flax_module("nn", module, dist.Normal(0, 0.1), input_shape=()) ... with numpyro.plate("batch", x.shape[0], subsample_size=batch_size): ... batch_x = numpyro.subsample(x, event_dim=0) @@ -207,14 +326,14 @@ def apply(self, x): ... mean, rho = net(batch_x) ... sigma = nn.softplus(rho) ... numpyro.sample("obs", dist.Normal(mean, sigma), obs=batch_y) - >>> + ... >>> n_train_data = 5000 >>> x_train, y_train = generate_data(n_train_data) >>> guide = autoguide.AutoNormal(model, init_loc_fn=init_to_feasible) >>> svi = SVI(model, guide, numpyro.optim.Adam(5e-3), TraceMeanField_ELBO()) - >>> >>> n_iterations = 3000 - >>> params, losses = svi.run(random.PRNGKey(0), n_iterations, x_train, y_train, batch_size=256) + >>> svi_result = svi.run(random.PRNGKey(0), n_iterations, x_train, y_train, batch_size=256) + >>> params, losses = svi_result.params, svi_result.losses >>> n_test_data = 100 >>> x_test, y_test = generate_data(n_test_data) >>> predictive = Predictive(model, guide=guide, params=params, num_samples=1000) @@ -222,7 +341,14 @@ def apply(self, x): >>> assert losses[-1] < 3000 >>> assert np.sqrt(np.mean(np.square(y_test - y_pred))) < 1 """ - nn = flax_module(name, nn_module, input_shape=input_shape, **kwargs) + nn = flax_module( + name, + nn_module, + input_shape=input_shape, + apply_rng=apply_rng, + mutable=mutable, + **kwargs + ) params = nn.args[0] new_params = deepcopy(params) with numpyro.handlers.scope(prefix=name): @@ -231,12 +357,15 @@ def apply(self, x): return nn_new -def random_haiku_module(name, nn_module, prior, *, input_shape=None, **kwargs): +def random_haiku_module( + name, nn_module, prior, *, input_shape=None, apply_rng=False, **kwargs +): """ A primitive to place a prior over the parameters of the Haiku module `nn_module`. :param str name: name of NumPyro module - :param haiku.Module: the module to be registered with NumPyro + :param nn_module: the module to be registered with NumPyro + :type nn_module: haiku.Transformed or haiku.TransformedWithState :param prior: a NumPyro distribution or a Python dict with parameter names as keys and respective distributions as values. For example:: @@ -245,11 +374,20 @@ def random_haiku_module(name, nn_module, prior, *, input_shape=None, **kwargs): prior={"linear.b": dist.Cauchy(), "linear.w": dist.Normal()}, input_shape=(4,)) - :type param: dict or ~numpyro.distributions.Distribution + :type prior: dict or ~numpyro.distributions.Distribution :param tuple input_shape: shape of the input taken by the neural network. + :param bool apply_rng: A flag to indicate if the returned callable requires + an rng argument (e.g. when ``nn_module`` includes dropout layers). Defaults + to False, which means no rng argument is needed. If this is True, the signature + of the returned callable ``nn = haiku_module(..., apply_rng=True)`` will be + ``nn(rng_key, x)`` (rather than ``nn(x)``). 
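A minimal sketch of the haiku side: a regression model with priors over the network weights, using the dotted `module.parameter` prior keys shown above; the model and names are illustrative::

    import haiku as hk
    import numpyro
    import numpyro.distributions as dist
    from numpyro.contrib.module import random_haiku_module

    def model(x, y=None):
        net = random_haiku_module(
            "net",
            hk.transform(lambda x: hk.Linear(1, name="linear")(x)),
            prior={"linear.b": dist.Cauchy(), "linear.w": dist.Normal()},
            input_shape=(4,),
        )
        mean = net(x).squeeze(-1)
        numpyro.sample("y", dist.Normal(mean, 0.1), obs=y)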
+ :param kwargs: optional keyword arguments to initialize flax neural network + as an alternative to `input_shape` :returns: a sampled module """ - nn = haiku_module(name, nn_module, input_shape=input_shape, **kwargs) + nn = haiku_module( + name, nn_module, input_shape=input_shape, apply_rng=apply_rng, **kwargs + ) params = nn.args[0] new_params = deepcopy(params) with numpyro.handlers.scope(prefix=name): diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 53be5116d..8f1386b9e 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -600,7 +600,7 @@ class scope(Messenger): """ This handler prepend a prefix followed by a divider to the name of sample sites. - Example:: + **Example** .. doctest:: @@ -745,7 +745,7 @@ def __init__(self, fn=None, data=None, substitute_fn=None): super(substitute, self).__init__(fn) def process_message(self, msg): - if (msg["type"] not in ("sample", "param", "plate")) or msg.get( + if (msg["type"] not in ("sample", "param", "mutable", "plate")) or msg.get( "_control_flow_done", False ): if msg["type"] == "control_flow": diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py index a816847c2..8f3e58585 100644 --- a/numpyro/infer/elbo.py +++ b/numpyro/infer/elbo.py @@ -15,7 +15,60 @@ from numpyro.infer.util import get_importance_trace, log_density -class Trace_ELBO: +class ELBO: + """ + Base class for all ELBO objectives. + + Subclasses should implement either :meth:`loss` or :meth:`loss_with_mutable_state`. + + :param num_particles: The number of particles/samples used to form the ELBO + (gradient) estimators. + """ + + def __init__(self, num_particles=1): + self.num_particles = num_particles + + def loss(self, rng_key, param_map, model, guide, *args, **kwargs): + """ + Evaluates the ELBO with an estimator that uses num_particles many samples/particles. + + :param jax.random.PRNGKey rng_key: random number generator seed. + :param dict param_map: dictionary of current parameter values keyed by site + name. + :param model: Python callable with NumPyro primitives for the model. + :param guide: Python callable with NumPyro primitives for the guide. + :param args: arguments to the model / guide (these can possibly vary during + the course of fitting). + :param kwargs: keyword arguments to the model / guide (these can possibly vary + during the course of fitting). + :return: negative of the Evidence Lower Bound (ELBO) to be minimized. + """ + return self.loss_with_mutable_state( + rng_key, param_map, model, guide, *args, **kwargs + )["loss"] + + def loss_with_mutable_state( + self, rng_key, param_map, model, guide, *args, **kwargs + ): + """ + Likes :meth:`loss` but also update and return the mutable state, which stores the + values at :func:`~numpyro.mutable` sites. + + :param jax.random.PRNGKey rng_key: random number generator seed. + :param dict param_map: dictionary of current parameter values keyed by site + name. + :param model: Python callable with NumPyro primitives for the model. + :param guide: Python callable with NumPyro primitives for the guide. + :param args: arguments to the model / guide (these can possibly vary during + the course of fitting). + :param kwargs: keyword arguments to the model / guide (these can possibly vary + during the course of fitting). + :return: a tuple of ELBO loss and the mutable state + """ + raise NotImplementedError("This ELBO objective does not support mutable state.") + + +class Trace_ELBO(ELBO): """ A trace implementation of ELBO-based SVI. 
The estimator is constructed along the lines of references [1] and [2]. There are no restrictions on the @@ -43,52 +96,56 @@ class Trace_ELBO: def __init__(self, num_particles=1): self.num_particles = num_particles - def loss(self, rng_key, param_map, model, guide, *args, **kwargs): - """ - Evaluates the ELBO with an estimator that uses num_particles many samples/particles. - - :param jax.random.PRNGKey rng_key: random number generator seed. - :param dict param_map: dictionary of current parameter values keyed by site - name. - :param model: Python callable with NumPyro primitives for the model. - :param guide: Python callable with NumPyro primitives for the guide. - :param args: arguments to the model / guide (these can possibly vary during - the course of fitting). - :param kwargs: keyword arguments to the model / guide (these can possibly vary - during the course of fitting). - :return: negative of the Evidence Lower Bound (ELBO) to be minimized. - """ - + def loss_with_mutable_state( + self, rng_key, param_map, model, guide, *args, **kwargs + ): def single_particle_elbo(rng_key): + params = param_map.copy() model_seed, guide_seed = random.split(rng_key) seeded_model = seed(model, model_seed) seeded_guide = seed(guide, guide_seed) guide_log_density, guide_trace = log_density( seeded_guide, args, kwargs, param_map ) + mutable_params = { + name: site["value"] + for name, site in guide_trace.items() + if site["type"] == "mutable" + } + params.update(mutable_params) seeded_model = replay(seeded_model, guide_trace) - model_log_density, _ = log_density(seeded_model, args, kwargs, param_map) + model_log_density, model_trace = log_density( + seeded_model, args, kwargs, params + ) + mutable_params.update( + { + name: site["value"] + for name, site in model_trace.items() + if site["type"] == "mutable" + } + ) # log p(z) - log q(z) - elbo = model_log_density - guide_log_density - return elbo + elbo_particle = model_log_density - guide_log_density + if mutable_params: + if self.num_particles == 1: + return elbo_particle, mutable_params + else: + raise ValueError( + "Currently, we only support mutable states with num_particles=1." + ) + else: + return elbo_particle, None # Return (-elbo) since by convention we do gradient descent on a loss and # the ELBO is a lower bound that needs to be maximized. if self.num_particles == 1: - return -single_particle_elbo(rng_key) + elbo, mutable_state = single_particle_elbo(rng_key) + return {"loss": -elbo, "mutable_state": mutable_state} else: rng_keys = random.split(rng_key, self.num_particles) - return -jnp.mean(vmap(single_particle_elbo)(rng_keys)) - - -class ELBO(Trace_ELBO): - def __init__(self, num_particles=1): - warnings.warn( - "Using ELBO directly in SVI is deprecated. Please use Trace_ELBO class instead.", - FutureWarning, - ) - super().__init__(num_particles=num_particles) + elbos, mutable_state = vmap(single_particle_elbo)(rng_keys) + return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state} def _get_log_prob_sum(site): @@ -128,7 +185,7 @@ def _check_mean_field_requirement(model_trace, guide_trace): ) -class TraceMeanField_ELBO(Trace_ELBO): +class TraceMeanField_ELBO(ELBO): """ A trace implementation of ELBO-based SVI. This is currently the only ELBO estimator in NumPyro that uses analytic KL divergences when those @@ -146,30 +203,31 @@ class TraceMeanField_ELBO(Trace_ELBO): dependency structures. 
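To make the new return contract concrete, a small sketch calling the dict-returning objective directly (outside `SVI`); the toy model, guide, and site names are illustrative::

    from jax import random
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import Trace_ELBO
    from numpyro.primitives import mutable as numpyro_mutable

    def model(data):
        loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
        numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

    def guide(data):
        loc = numpyro.param("loc", 0.0)
        numpyro_mutable("loc_tracker", {"value": loc})  # recorded as a "mutable" site
        numpyro.sample("loc", dist.Normal(loc, 0.1))

    out = Trace_ELBO().loss_with_mutable_state(
        random.PRNGKey(0), {"loc": 0.0}, model, guide, 1.0
    )
    loss, mutable_state = out["loss"], out["mutable_state"]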
""" - def loss(self, rng_key, param_map, model, guide, *args, **kwargs): - """ - Evaluates the ELBO with an estimator that uses num_particles many samples/particles. - - :param jax.random.PRNGKey rng_key: random number generator seed. - :param dict param_map: dictionary of current parameter values keyed by site - name. - :param model: Python callable with NumPyro primitives for the model. - :param guide: Python callable with NumPyro primitives for the guide. - :param args: arguments to the model / guide (these can possibly vary during - the course of fitting). - :param kwargs: keyword arguments to the model / guide (these can possibly vary - during the course of fitting). - :return: negative of the Evidence Lower Bound (ELBO) to be minimized. - """ - + def loss_with_mutable_state( + self, rng_key, param_map, model, guide, *args, **kwargs + ): def single_particle_elbo(rng_key): + params = param_map.copy() model_seed, guide_seed = random.split(rng_key) seeded_model = seed(model, model_seed) seeded_guide = seed(guide, guide_seed) subs_guide = substitute(seeded_guide, data=param_map) guide_trace = trace(subs_guide).get_trace(*args, **kwargs) - subs_model = substitute(replay(seeded_model, guide_trace), data=param_map) + mutable_params = { + name: site["value"] + for name, site in guide_trace.items() + if site["type"] == "mutable" + } + params.update(mutable_params) + subs_model = substitute(replay(seeded_model, guide_trace), data=params) model_trace = trace(subs_model).get_trace(*args, **kwargs) + mutable_params.update( + { + name: site["value"] + for name, site in model_trace.items() + if site["type"] == "mutable" + } + ) _check_mean_field_requirement(model_trace, guide_trace) elbo_particle = 0 @@ -196,16 +254,26 @@ def single_particle_elbo(rng_key): assert site["infer"].get("is_auxiliary") elbo_particle = elbo_particle - _get_log_prob_sum(site) - return elbo_particle + if mutable_params: + if self.num_particles == 1: + return elbo_particle, mutable_params + else: + raise ValueError( + "Currently, we only support mutable states with num_particles=1." + ) + else: + return elbo_particle, None if self.num_particles == 1: - return -single_particle_elbo(rng_key) + elbo, mutable_state = single_particle_elbo(rng_key) + return {"loss": -elbo, "mutable_state": mutable_state} else: rng_keys = random.split(rng_key, self.num_particles) - return -jnp.mean(vmap(single_particle_elbo)(rng_keys)) + elbos, mutable_state = vmap(single_particle_elbo)(rng_keys) + return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state} -class RenyiELBO(Trace_ELBO): +class RenyiELBO(ELBO): r""" An implementation of Renyi's :math:`\alpha`-divergence variational inference following reference [1]. @@ -235,24 +303,9 @@ def __init__(self, alpha=0, num_particles=2): "for the case alpha = 1." ) self.alpha = alpha - super(RenyiELBO, self).__init__(num_particles=num_particles) + super().__init__(num_particles=num_particles) def loss(self, rng_key, param_map, model, guide, *args, **kwargs): - """ - Evaluates the Renyi ELBO with an estimator that uses num_particles many samples/particles. - - :param jax.random.PRNGKey rng_key: random number generator seed. - :param dict param_map: dictionary of current parameter values keyed by site - name. - :param model: Python callable with NumPyro primitives for the model. - :param guide: Python callable with NumPyro primitives for the guide. - :param args: arguments to the model / guide (these can possibly vary during - the course of fitting). 
- :param kwargs: keyword arguments to the model / guide (these can possibly vary - during the course of fitting). - :returns: negative of the Renyi Evidence Lower Bound (ELBO) to be minimized. - """ - def single_particle_elbo(rng_key): model_seed, guide_seed = random.split(rng_key) seeded_model = seed(model, model_seed) @@ -458,7 +511,7 @@ def _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes): return downstream_costs, downstream_guide_cost_nodes -class TraceGraph_ELBO: +class TraceGraph_ELBO(ELBO): """ A TraceGraph implementation of ELBO-based SVI. The gradient estimator is constructed along the lines of reference [1] specialized to the case @@ -479,7 +532,7 @@ class TraceGraph_ELBO: """ def __init__(self, num_particles=1): - self.num_particles = num_particles + super().__init__(num_particles=num_particles) def loss(self, rng_key, param_map, model, guide, *args, **kwargs): """ diff --git a/numpyro/infer/svi.py b/numpyro/infer/svi.py index ddf120644..7df085062 100644 --- a/numpyro/infer/svi.py +++ b/numpyro/infer/svi.py @@ -17,28 +17,52 @@ from numpyro.infer.util import helpful_support_errors, transform_fn from numpyro.optim import _NumPyroOptim -SVIState = namedtuple("SVIState", ["optim_state", "rng_key"]) +SVIState = namedtuple("SVIState", ["optim_state", "mutable_state", "rng_key"]) """ A :func:`~collections.namedtuple` consisting of the following fields: - **optim_state** - current optimizer's state. + - **mutable_state** - extra state to store values of `"mutable"` sites - **rng_key** - random number generator seed used for the iteration. """ -SVIRunResult = namedtuple("SVIRunResult", ["params", "losses"]) +SVIRunResult = namedtuple("SVIRunResult", ["params", "state", "losses"]) """ A :func:`~collections.namedtuple` consisting of the following fields: - **params** - the optimized parameters. + - **state** - the last :class:`SVIState` - **losses** - the losses collected at every step. 
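A usage sketch of the extra `state` field, reusing the toy `svi` object from the earlier migration sketch::

    svi_result = svi.run(random.PRNGKey(0), 1000, jnp.ones(10))
    svi_result.params                    # optimized parameters, as before
    svi_result.losses                    # per-step losses, shape (1000,)
    last_state = svi_result.state        # SVIState(optim_state, mutable_state, rng_key)
    last_state.mutable_state             # dict of values at mutable sites, or None
    params = svi.get_params(last_state)  # same as svi_result.params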
""" -def _apply_loss_fn( - loss_fn, rng_key, constrain_fn, model, guide, args, kwargs, static_kwargs, params +def _make_loss_fn( + elbo, + rng_key, + constrain_fn, + model, + guide, + args, + kwargs, + static_kwargs, + mutable_state=None, ): - return loss_fn( - rng_key, constrain_fn(params), model, guide, *args, **kwargs, **static_kwargs - ) + def loss_fn(params): + params = constrain_fn(params) + if mutable_state is not None: + params.update(mutable_state) + result = elbo.loss_with_mutable_state( + rng_key, params, model, guide, *args, **kwargs, **static_kwargs + ) + return result["loss"], result["mutable_state"] + else: + return ( + elbo.loss( + rng_key, params, model, guide, *args, **kwargs, **static_kwargs + ), + None, + ) + + return loss_fn class SVI(object): @@ -150,6 +174,7 @@ def init(self, rng_key, *args, **kwargs): ) params = {} inv_transforms = {} + mutable_state = {} # NB: params in model_trace will be overwritten by params in guide_trace for site in list(model_trace.values()) + list(guide_trace.values()): if site["type"] == "param": @@ -158,6 +183,8 @@ def init(self, rng_key, *args, **kwargs): transform = biject_to(constraint) inv_transforms[site["name"]] = transform params[site["name"]] = transform.inv(site["value"]) + elif site["type"] == "mutable": + mutable_state[site["name"]] = site["value"] elif ( site["type"] == "sample" and (not site["is_observed"]) @@ -167,13 +194,16 @@ def init(self, rng_key, *args, **kwargs): "Currently, SVI does not support models with discrete latent variables" ) + if not mutable_state: + mutable_state = None self.constrain_fn = partial(transform_fn, inv_transforms) # we convert weak types like float to float32/float64 # to avoid recompiling body_fn in svi.run - params = tree_map( - lambda x: lax.convert_element_type(x, jnp.result_type(x)), params + params, mutable_state = tree_map( + lambda x: lax.convert_element_type(x, jnp.result_type(x)), + (params, mutable_state), ) - return SVIState(self.optim.init(params), rng_key) + return SVIState(self.optim.init(params), mutable_state, rng_key) def get_params(self, svi_state): """ @@ -198,9 +228,8 @@ def update(self, svi_state, *args, **kwargs): :return: tuple of `(svi_state, loss)`. """ rng_key, rng_key_step = random.split(svi_state.rng_key) - loss_fn = partial( - _apply_loss_fn, - self.loss.loss, + loss_fn = _make_loss_fn( + self.loss, rng_key_step, self.constrain_fn, self.model, @@ -208,11 +237,12 @@ def update(self, svi_state, *args, **kwargs): args, kwargs, self.static_kwargs, + mutable_state=svi_state.mutable_state, ) - loss_val, optim_state = self.optim.eval_and_update( + (loss_val, mutable_state), optim_state = self.optim.eval_and_update( loss_fn, svi_state.optim_state ) - return SVIState(optim_state, rng_key), loss_val + return SVIState(optim_state, mutable_state, rng_key), loss_val def stable_update(self, svi_state, *args, **kwargs): """ @@ -227,9 +257,8 @@ def stable_update(self, svi_state, *args, **kwargs): :return: tuple of `(svi_state, loss)`. 
""" rng_key, rng_key_step = random.split(svi_state.rng_key) - loss_fn = partial( - _apply_loss_fn, - self.loss.loss, + loss_fn = _make_loss_fn( + self.loss, rng_key_step, self.constrain_fn, self.model, @@ -237,11 +266,12 @@ def stable_update(self, svi_state, *args, **kwargs): args, kwargs, self.static_kwargs, + mutable_state=svi_state.mutable_state, ) - loss_val, optim_state = self.optim.eval_and_stable_update( + (loss_val, mutable_state), optim_state = self.optim.eval_and_stable_update( loss_fn, svi_state.optim_state ) - return SVIState(optim_state, rng_key), loss_val + return SVIState(optim_state, mutable_state, rng_key), loss_val def run( self, @@ -311,7 +341,9 @@ def body_fn(svi_state, _): else: svi_state, losses = lax.scan(body_fn, svi_state, None, length=num_steps) - return SVIRunResult(self.get_params(svi_state), losses) + # XXX: we also return the last svi_state for further inspection of both + # optimizer's state and mutable state. + return SVIRunResult(self.get_params(svi_state), svi_state, losses) def evaluate(self, svi_state, *args, **kwargs): """ diff --git a/numpyro/optim.py b/numpyro/optim.py index bf205e860..bc33f36b3 100644 --- a/numpyro/optim.py +++ b/numpyro/optim.py @@ -8,7 +8,7 @@ """ from collections import namedtuple -from typing import Callable, Tuple, TypeVar +from typing import Any, Callable, Tuple, TypeVar from jax import lax, value_and_grad from jax.experimental import optimizers @@ -60,7 +60,7 @@ def update(self, g: _Params, state: _IterOptState) -> _IterOptState: opt_state = self.update_fn(i, g, opt_state) return i + 1, opt_state - def eval_and_update(self, fn: Callable, state: _IterOptState) -> _IterOptState: + def eval_and_update(self, fn: Callable[[Any], Tuple], state: _IterOptState): """ Performs an optimization step for the objective function `fn`. For most optimizers, the update is performed based on the gradient @@ -69,17 +69,17 @@ def eval_and_update(self, fn: Callable, state: _IterOptState) -> _IterOptState: by reevaluating the function multiple times to get optimal parameters. - :param fn: objective function. + :param fn: an objective function returning a pair where the first item + is a scalar loss function to be differentiated and the second item + is an auxiliary output. :param state: current optimizer state. :return: a pair of the output of objective function and the new optimizer state. """ params = self.get_params(state) - out, grads = value_and_grad(fn)(params) - return out, self.update(grads, state) + (out, aux), grads = value_and_grad(fn, has_aux=True)(params) + return (out, aux), self.update(grads, state) - def eval_and_stable_update( - self, fn: Callable, state: _IterOptState - ) -> _IterOptState: + def eval_and_stable_update(self, fn: Callable[[Any], Tuple], state: _IterOptState): """ Like :meth:`eval_and_update` but when the value of the objective function or the gradients are not finite, we will not update the input `state` @@ -90,14 +90,14 @@ def eval_and_stable_update( :return: a pair of the output of objective function and the new optimizer state. 
""" params = self.get_params(state) - out, grads = value_and_grad(fn)(params) + (out, aux), grads = value_and_grad(fn, has_aux=True)(params) out, state = lax.cond( jnp.isfinite(out) & jnp.isfinite(ravel_pytree(grads)[0]).all(), lambda _: (out, self.update(grads, state)), lambda _: (jnp.nan, state), None, ) - return out, state + return (out, aux), state def get_params(self, state: _IterOptState) -> _Params: """ @@ -265,15 +265,21 @@ def __init__(self, method="BFGS", **kwargs): self._method = method self._kwargs = kwargs - def eval_and_update(self, fn: Callable, state: _IterOptState) -> _IterOptState: + def eval_and_update(self, fn: Callable[[Any], Tuple], state: _IterOptState): i, (flat_params, unravel_fn) = state + + def loss_fn(x): + x = unravel_fn(x) + out, aux = fn(x) + if aux is not None: + raise ValueError( + "Minimize does not support models with mutable states." + ) + return out + results = minimize( - lambda x: fn(unravel_fn(x)), - flat_params, - (), - method=self._method, - **self._kwargs + loss_fn, flat_params, (), method=self._method, **self._kwargs ) flat_params, out = results.x, results.fun state = (i + 1, _MinimizeState(flat_params, unravel_fn)) - return out, state + return (out, None), state diff --git a/numpyro/primitives.py b/numpyro/primitives.py index 8c998702c..f5be24bb4 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -238,6 +238,38 @@ def deterministic(name, value): return msg["value"] +def mutable(name, init_value=None): + """ + This primitive is used to store a mutable value that can be changed + during model execution:: + + a = numpyro.mutable("a", {"value": 1.}) + a["value"] = 2. + assert numpyro.mutable("a")["value"] == 2. + + For example, this can be used to store and update information like + running mean/variance in a neural network batch normalization layer. + + :param str name: name of the mutable site. + :param init_value: mutable value to record in the trace. + """ + if not _PYRO_STACK: + return init_value + + initial_msg = { + "type": "mutable", + "name": name, + "fn": identity, + "args": (init_value,), + "kwargs": {}, + "value": init_value, + } + + # ...and use apply_stack to send it to the Messengers + msg = apply_stack(initial_msg) + return msg["value"] + + def _inspect(): """ EXPERIMENTAL Inspect the Pyro stack. 
diff --git a/test/contrib/test_control_flow.py b/test/contrib/test_control_flow.py index 1a4392273..520ebfa7e 100644 --- a/test/contrib/test_control_flow.py +++ b/test/contrib/test_control_flow.py @@ -163,7 +163,8 @@ def false_fun(_): cond(cluster > 0, true_fun, false_fun, None) svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO(num_particles=100)) - params, losses = svi.run(random.PRNGKey(0), num_steps=2500) + svi_result = svi.run(random.PRNGKey(0), num_steps=2500) + params = svi_result.params predictive = Predictive( model, diff --git a/test/contrib/test_module.py b/test/contrib/test_module.py index 2e411f62d..62ef56869 100644 --- a/test/contrib/test_module.py +++ b/test/contrib/test_module.py @@ -48,7 +48,6 @@ class TestHaikuModule(hk.Module): def __init__(self, dim: int = 100): super().__init__() self._dim = dim - return def __call__(self, w, x): l1 = hk.Linear(self._dim, name="w_linear")(w) @@ -64,7 +63,7 @@ def __call__(self, w, x): def flax_model_by_shape(x, y): import flax - linear_module = flax.nn.Dense.partial(features=100) + linear_module = flax.linen.Dense(features=100) nn = flax_module("nn", linear_module, input_shape=(100,)) mean = nn(x) numpyro.sample("y", numpyro.distributions.Normal(mean, 0.1), obs=y) @@ -73,7 +72,7 @@ def flax_model_by_shape(x, y): def flax_model_by_kwargs(x, y): import flax - linear_module = flax.nn.Dense.partial(features=100) + linear_module = flax.linen.Dense(features=100) nn = flax_module("nn", linear_module, inputs=x) mean = nn(x) numpyro.sample("y", numpyro.distributions.Normal(mean, 0.1), obs=y) @@ -149,12 +148,12 @@ def test_update_params(): @pytest.mark.parametrize("backend", ["flax", "haiku"]) @pytest.mark.parametrize("init", ["shape", "kwargs"]) -def test_random_module__mcmc(backend, init): +def test_random_module_mcmc(backend, init): if backend == "flax": import flax - linear_module = flax.nn.Dense.partial(features=1) + linear_module = flax.linen.Dense(features=1) bias_name = "bias" weight_name = "kernel" random_module = random_flax_module @@ -206,3 +205,82 @@ def model(data, labels): true_coefs, atol=0.22, ) + + +@pytest.mark.parametrize("dropout", [True, False]) +@pytest.mark.parametrize("batchnorm", [True, False]) +def test_haiku_state_dropout_smoke(dropout, batchnorm): + import haiku as hk + + def fn(x): + if dropout: + x = hk.dropout(hk.next_rng_key(), 0.5, x) + if batchnorm: + x = hk.BatchNorm(create_offset=True, create_scale=True, decay_rate=0.001)( + x, is_training=True + ) + return x + + def model(): + transform = hk.transform_with_state if batchnorm else hk.transform + nn = haiku_module("nn", transform(fn), apply_rng=dropout, input_shape=(4, 3)) + x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2)) + if dropout: + y = nn(numpyro.prng_key(), x) + else: + y = nn(x) + numpyro.deterministic("y", y) + + with handlers.trace(model) as tr, handlers.seed(rng_seed=0): + model() + + if batchnorm: + assert set(tr.keys()) == {"nn$params", "nn$state", "x", "y"} + assert tr["nn$state"]["type"] == "mutable" + else: + assert set(tr.keys()) == {"nn$params", "x", "y"} + + +@pytest.mark.parametrize("dropout", [True, False]) +@pytest.mark.parametrize("batchnorm", [True, False]) +def test_flax_state_dropout_smoke(dropout, batchnorm): + import flax.linen as nn + + class Net(nn.Module): + @nn.compact + def __call__(self, x): + x = nn.Dense(10)(x) + if dropout: + x = nn.Dropout(0.5, deterministic=False)(x) + if batchnorm: + x = nn.BatchNorm( + use_bias=True, + use_scale=True, + momentum=0.999, + use_running_average=False, + )(x) + 
return x + + def model(): + net = flax_module( + "nn", + Net(), + apply_rng=["dropout"] if dropout else None, + mutable=["batch_stats"] if batchnorm else None, + input_shape=(4, 3), + ) + x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2)) + if dropout: + y = net(x, rngs={"dropout": numpyro.prng_key()}) + else: + y = net(x) + numpyro.deterministic("y", y) + + with handlers.trace(model) as tr, handlers.seed(rng_seed=0): + model() + + if batchnorm: + assert set(tr.keys()) == {"nn$params", "nn$state", "x", "y"} + assert tr["nn$state"]["type"] == "mutable" + else: + assert set(tr.keys()) == {"nn$params", "x", "y"} diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index 606f2f08d..197ea1f18 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -471,7 +471,8 @@ def model(y=None): optimiser = numpyro.optim.Adam(step_size=0.01) svi = SVI(model, guide, optimiser, Trace_ELBO()) - params, losses = svi.run(random.PRNGKey(0), num_steps=500, y=y_train) + svi_result = svi.run(random.PRNGKey(0), num_steps=500, y=y_train) + params = svi_result.params posterior_samples = guide.sample_posterior( random.PRNGKey(0), params, sample_shape=(1000,) ) diff --git a/test/infer/test_svi.py b/test/infer/test_svi.py index 6e51122dd..6f890db98 100644 --- a/test/infer/test_svi.py +++ b/test/infer/test_svi.py @@ -16,7 +16,14 @@ from numpyro.distributions import constraints from numpyro.distributions.transforms import AffineTransform, SigmoidTransform from numpyro.handlers import substitute -from numpyro.infer import SVI, RenyiELBO, Trace_ELBO, TraceGraph_ELBO +from numpyro.infer import ( + SVI, + RenyiELBO, + Trace_ELBO, + TraceGraph_ELBO, + TraceMeanField_ELBO, +) +from numpyro.primitives import mutable as numpyro_mutable from numpyro.util import fori_loop @@ -96,7 +103,8 @@ def guide(data): numpyro.sample("beta", dist.Beta(alpha_q, beta_q)) svi = SVI(model, guide, optim.Adam(0.05), Trace_ELBO()) - params, losses = svi.run(random.PRNGKey(1), 1000, data, progress_bar=progress_bar) + svi_result = svi.run(random.PRNGKey(1), 1000, data, progress_bar=progress_bar) + params, losses = svi_result.params, svi_result.losses assert losses.shape == (1000,) assert_allclose( params["alpha_q"] / (params["alpha_q"] + params["beta_q"]), @@ -235,6 +243,36 @@ def guide(): svi.run(random.PRNGKey(0), 10) +@pytest.mark.parametrize("stable_update", [True, False]) +@pytest.mark.parametrize("num_particles", [1, 10]) +@pytest.mark.parametrize("elbo", [Trace_ELBO, TraceMeanField_ELBO]) +def test_mutable_state(stable_update, num_particles, elbo): + def model(): + x = numpyro.sample("x", dist.Normal(-1, 1)) + numpyro_mutable("x1p", x + 1) + + def guide(): + loc = numpyro.param("loc", 0.0) + p = numpyro_mutable("loc1p", {"value": None}) + # we can modify the content of `p` if it is a dict + p["value"] = loc + 2 + numpyro.sample("x", dist.Normal(loc, 0.1)) + + svi = SVI(model, guide, optim.Adam(0.1), elbo(num_particles=num_particles)) + if num_particles > 1: + with pytest.raises(ValueError, match="mutable state"): + svi_result = svi.run(random.PRNGKey(0), 1000, stable_update=stable_update) + return + svi_result = svi.run(random.PRNGKey(0), 1000, stable_update=stable_update) + params = svi_result.params + mutable_state = svi_result.state.mutable_state + assert set(mutable_state) == {"x1p", "loc1p"} + assert_allclose(mutable_state["loc1p"]["value"], params["loc"] + 2, atol=0.1) + # here, the initial loc has value 0., hence x1p will have init value near 1 + # it won't be updated during SVI 
run because it is not a mutable state + assert_allclose(mutable_state["x1p"], 1.0, atol=0.2) + + def test_tracegraph_normal_normal(): # normal-normal; known covariance lam0 = jnp.array([0.1, 0.1]) # precision of prior