Add auto-batched (low-rank) multivariate normal guides. #1737

Merged · 6 commits · Feb 21, 2024
59 changes: 59 additions & 0 deletions numpyro/distributions/transforms.py
@@ -40,6 +40,7 @@
"LowerCholeskyAffine",
"PermuteTransform",
"PowerTransform",
"ReshapeTransform",
"SigmoidTransform",
"SimplexToOrderedTransform",
"SoftplusTransform",
@@ -1141,6 +1142,64 @@ def __eq__(self, other):
        return isinstance(other, UnpackTransform) and self.unpack_fn is other.unpack_fn


def _get_target_shape(shape, forward_shape, inverse_shape):
    # Keep the leading batch dimensions and replace the trailing event shape.
    batch_ndims = len(shape) - len(inverse_shape)
    return shape[:batch_ndims] + forward_shape


class ReshapeTransform(Transform):
    """
    Reshape a sample, leaving batch dimensions unchanged.

    :param forward_shape: Shape to transform the sample to.
    :param inverse_shape: Shape of the sample for the inverse transform.
    """

    domain = constraints.real
    codomain = constraints.real

    def __init__(self, forward_shape, inverse_shape) -> None:
        forward_size = math.prod(forward_shape)
        inverse_size = math.prod(inverse_shape)
        if forward_size != inverse_size:
            raise ValueError(
                f"forward shape {forward_shape} (size {forward_size}) and inverse "
                f"shape {inverse_shape} (size {inverse_size}) are not compatible"
            )
        self._forward_shape = forward_shape
        self._inverse_shape = inverse_shape

    def forward_shape(self, shape):
        return _get_target_shape(shape, self._forward_shape, self._inverse_shape)

    def inverse_shape(self, shape):
        return _get_target_shape(shape, self._inverse_shape, self._forward_shape)

    def __call__(self, x):
        return jnp.reshape(x, self.forward_shape(jnp.shape(x)))

    def _inverse(self, y):
        return jnp.reshape(y, self.inverse_shape(jnp.shape(y)))

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # Reshaping only moves elements, so the Jacobian determinant is one.
        return 0.0

    def tree_flatten(self):
        aux_data = {
            "_forward_shape": self._forward_shape,
            "_inverse_shape": self._inverse_shape,
        }
        return (), ((), aux_data)

    def __eq__(self, other):
        return (
            isinstance(other, ReshapeTransform)
            and self._forward_shape == other._forward_shape
            and self._inverse_shape == other._inverse_shape
        )

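As a quick illustration of the new transform (the shapes below are illustrative, not taken from the diff), a round trip preserves batch dimensions and contributes no Jacobian term:

```python
import jax.numpy as jnp

from numpyro.distributions.transforms import ReshapeTransform

t = ReshapeTransform(forward_shape=(2, 3), inverse_shape=(6,))
x = jnp.arange(24.0).reshape((4, 6))  # batch shape (4,), event shape (6,)
y = t(x)                              # shape (4, 2, 3): batch dims untouched
assert t.inv(y).shape == (4, 6)       # inverse restores the original shape
assert t.log_abs_det_jacobian(x, y) == 0.0  # reshaping is volume-preserving
```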


##########################################################
# CONSTRAINT_REGISTRY
##########################################################
154 changes: 154 additions & 0 deletions numpyro/infer/autoguide.py
@@ -5,6 +5,7 @@
from abc import ABC, abstractmethod
from contextlib import ExitStack
from functools import partial
import math
import warnings

import numpy as np
@@ -29,6 +30,7 @@
    IndependentTransform,
    LowerCholeskyAffine,
    PermuteTransform,
    ReshapeTransform,
    UnpackTransform,
    biject_to,
)
@@ -50,6 +52,8 @@
from numpyro.util import find_stack_level, not_jax_tracer

__all__ = [
    "AutoBatchedLowRankMultivariateNormal",
    "AutoBatchedMultivariateNormal",
    "AutoContinuous",
    "AutoGuide",
    "AutoGuideList",
@@ -1808,6 +1812,106 @@ def quantiles(self, params, quantiles):
        return self._unpack_and_constrain(latent, params)


class AutoBatchedMixin:
    """
    Mixin to infer the batch and event shapes of batched auto guides.
    """

    # Available from AutoContinuous.
    latent_dim: int

    def __init__(self, *args, **kwargs):
        self._batch_shape = None
        self._event_shape = None
        # Pop the number of batch dimensions and pass the rest to the other
        # constructor.
        self.batch_ndim = kwargs.pop("batch_ndim")
        super().__init__(*args, **kwargs)

    def _setup_prototype(self, *args, **kwargs):
        super()._setup_prototype(*args, **kwargs)

        # Extract the batch shape.
        batch_shape = None
        for site in self.prototype_trace.values():
            if site["type"] == "sample" and not site["is_observed"]:
                shape = site["value"].shape
                if site["value"].ndim < self.batch_ndim + site["fn"].event_dim:
                    raise ValueError(
                        f"Expected {self.batch_ndim} batch dimensions, but site "
                        f"`{site['name']}` only has shape {shape}."
                    )
                shape = shape[: self.batch_ndim]
                if batch_shape is None:
                    batch_shape = shape
                elif shape != batch_shape:
                    raise ValueError("Encountered inconsistent batch shapes.")
        self._batch_shape = batch_shape

        # Save the event shape of the non-batched part. This will always be a
        # vector.
        batch_size = math.prod(self._batch_shape)
        if self.latent_dim % batch_size:
            raise RuntimeError(
                f"Incompatible batch shape {batch_shape} (size {batch_size}) and "
                f"latent dims {self.latent_dim}."
            )
        self._event_shape = (self.latent_dim // batch_size,)

    def _get_batched_posterior(self):
        raise NotImplementedError

    def _get_posterior(self):
        return dist.TransformedDistribution(
            self._get_batched_posterior(),
            ReshapeTransform(
                (self.latent_dim,), self._batch_shape + self._event_shape
            ),
        )

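To make the mixin's shape bookkeeping concrete, here is a hand-worked sketch with hypothetical sizes (a plate of size 3 with 4 latents per element, `batch_ndim=1`):

```python
latent_dim = 12                   # flattened latent size from AutoContinuous
batch_shape = (3,)                # inferred from the plate
event_shape = (latent_dim // 3,)  # == (4,), per-element latent vector
# The batched posterior has batch_shape (3,) and event_shape (4,);
# ReshapeTransform((12,), (3, 4)) flattens its samples back into the
# single 12-vector that the AutoContinuous machinery expects.
```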

class AutoBatchedMultivariateNormal(AutoBatchedMixin, AutoContinuous):
    """
    This implementation of :class:`AutoContinuous` uses a batched MultivariateNormal
    distribution to construct a guide over the entire latent space.
    The guide does not depend on the model's ``*args, **kwargs``.

    Usage::

        guide = AutoBatchedMultivariateNormal(model, batch_ndim=1, ...)
        svi = SVI(model, guide, ...)
    """

    scale_tril_constraint = constraints.scaled_unit_lower_cholesky

    def __init__(
        self,
        model,
        *,
        prefix="auto",
        init_loc_fn=init_to_uniform,
        init_scale=0.1,
        batch_ndim=1,
    ):
        if init_scale <= 0:
            raise ValueError(
                "Expected init_scale > 0, but got {}.".format(init_scale)
            )
        self._init_scale = init_scale
        super().__init__(
            model,
            prefix=prefix,
            init_loc_fn=init_loc_fn,
            batch_ndim=batch_ndim,
        )

    def _get_batched_posterior(self):
        init_latent = self._init_latent.reshape(self._batch_shape + self._event_shape)
        loc = numpyro.param("{}_loc".format(self.prefix), init_latent)
        init_scale = (
            jnp.ones(self._batch_shape + (1, 1))
            * jnp.identity(init_latent.shape[-1])
            * self._init_scale
        )
        scale_tril = numpyro.param(
            "{}_scale_tril".format(self.prefix),
            init_scale,
            constraint=self.scale_tril_constraint,
        )
        return dist.MultivariateNormal(loc, scale_tril=scale_tril)

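An end-to-end sketch of the new guide, mirroring the tests added below (the model, optimizer, and step count are illustrative choices, not prescribed usage):

```python
import optax
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoBatchedMultivariateNormal


def model():
    # Three exchangeable groups, each with a 4-dimensional latent vector.
    with numpyro.plate("N", 3):
        numpyro.sample("x", dist.Normal().expand([4]).to_event(1))


guide = AutoBatchedMultivariateNormal(model, batch_ndim=1)
svi = SVI(model, guide, optax.adam(0.01), Trace_ELBO())
result = svi.run(random.PRNGKey(0), 1000)
# One full-covariance factor per plate element:
assert result.params["auto_loc"].shape == (3, 4)
assert result.params["auto_scale_tril"].shape == (3, 4, 4)
```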

class AutoLowRankMultivariateNormal(AutoContinuous):
    """
    This implementation of :class:`AutoContinuous` uses a LowRankMultivariateNormal
@@ -1886,6 +1990,56 @@ def quantiles(self, params, quantiles):
        return self._unpack_and_constrain(latent, params)


class AutoBatchedLowRankMultivariateNormal(AutoBatchedMixin, AutoContinuous):
    """
    This implementation of :class:`AutoContinuous` uses a batched
    LowRankMultivariateNormal distribution to construct a guide over the entire
    latent space. The guide does not depend on the model's ``*args, **kwargs``.

    Usage::

        guide = AutoBatchedLowRankMultivariateNormal(model, batch_ndim=1, ...)
        svi = SVI(model, guide, ...)
    """

    scale_constraint = constraints.softplus_positive

    def __init__(
        self,
        model,
        *,
        prefix="auto",
        init_loc_fn=init_to_uniform,
        init_scale=0.1,
        rank=None,
        batch_ndim=1,
    ):
        if init_scale <= 0:
            raise ValueError(
                "Expected init_scale > 0, but got {}.".format(init_scale)
            )
        self._init_scale = init_scale
        self.rank = rank
        super().__init__(
            model,
            prefix=prefix,
            init_loc_fn=init_loc_fn,
            batch_ndim=batch_ndim,
        )

    def _get_batched_posterior(self):
        rank = (
            int(round(self._event_shape[0] ** 0.5)) if self.rank is None else self.rank
        )
        init_latent = self._init_latent.reshape(self._batch_shape + self._event_shape)
        loc = numpyro.param("{}_loc".format(self.prefix), init_latent)
        cov_factor = numpyro.param(
            "{}_cov_factor".format(self.prefix),
            jnp.zeros(self._batch_shape + self._event_shape + (rank,)),
        )
        scale = numpyro.param(
            "{}_scale".format(self.prefix),
            jnp.full(self._batch_shape + self._event_shape, self._init_scale),
            constraint=self.scale_constraint,
        )
        cov_diag = scale * scale
        cov_factor = cov_factor * scale[..., None]
        return dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag)

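Two details of the low-rank parameterization worth noting: when `rank` is `None` it defaults to roughly the square root of the per-batch event size, and the covariance factor is rescaled by the marginal scales. A sketch of the implied covariance with illustrative sizes (batch shape `(5,)`, event size 9):

```python
import jax.numpy as jnp

rank = int(round(9**0.5))             # default rank == 3 for event size 9
cov_factor = jnp.zeros((5, 9, rank))  # mirrors the auto_cov_factor init
scale = jnp.full((5, 9), 0.1)         # mirrors the auto_scale init
cov_diag = scale * scale
cov_factor = cov_factor * scale[..., None]
# Per batch element, the covariance is W @ W.T + diag(cov_diag):
cov = cov_factor @ jnp.swapaxes(cov_factor, -1, -2) + cov_diag[..., None] * jnp.eye(9)
# With cov_factor initialized to zeros, the guide starts out mean-field
# with marginal scale init_scale.
```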

class AutoLaplaceApproximation(AutoContinuous):
    r"""
    Laplace approximation (quadratic approximation) approximates the posterior
62 changes: 61 additions & 1 deletion test/infer/test_autoguide.py
@@ -7,7 +7,7 @@
from numpy.testing import assert_allclose
import pytest

-from jax import jacobian, jit, lax, random
+from jax import jacobian, jit, lax, random, vmap
from jax.example_libraries.stax import Dense
import jax.numpy as jnp
from jax.tree_util import tree_all, tree_map
@@ -23,6 +23,8 @@
from numpyro.handlers import substitute
from numpyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO
from numpyro.infer.autoguide import (
    AutoBatchedLowRankMultivariateNormal,
    AutoBatchedMultivariateNormal,
    AutoBNAFNormal,
    AutoDAIS,
    AutoDelta,
@@ -1251,3 +1253,61 @@ def model():
    assert_allclose(
        samples["x"].mean(axis=0), jnp.arange(-5, 5), atol=0.2, rtol=0.1
    )


@pytest.mark.parametrize(
    "auto_class",
    [
        AutoBatchedMultivariateNormal,
        AutoBatchedLowRankMultivariateNormal,
    ],
)
def test_auto_batched(auto_class) -> None:
    # Model for batched multivariate normal.
    off_diag = jnp.asarray([-0.2, 0, 0.5])
    covs = off_diag[:, None, None] + jnp.eye(4)

    def model():
        with numpyro.plate("N", off_diag.shape[0]):
            numpyro.sample("x", dist.MultivariateNormal(0, covs))

    # Run inference.
    guide = auto_class(model)
    svi = SVI(model, guide, optax.adam(0.001), Trace_ELBO())
    result = svi.run(random.PRNGKey(0), 10000)
    samples = guide.sample_posterior(
        random.PRNGKey(1), result.params, sample_shape=(1000,)
    )

    # Verify off-diagonal entries are correlated.
    empirical_covs = vmap(jnp.cov)(jnp.moveaxis(samples["x"], 0, 2))
    i, j = jnp.triu_indices(3, 1)
    empirical_off_diag = empirical_covs[:, i, j].mean(axis=1)
    corrcoef = jnp.corrcoef(off_diag, empirical_off_diag)[0, 1]
    assert corrcoef > 0.99


@pytest.mark.parametrize(
    "auto_class",
    [
        AutoBatchedMultivariateNormal,
        AutoBatchedLowRankMultivariateNormal,
    ],
)
def test_auto_batched_shapes(auto_class) -> None:
    def model(n, m):
        distribution = dist.Normal().expand([7]).to_event(1)
        with numpyro.plate("n", n):
            x = numpyro.sample("x", distribution)
        with numpyro.plate("m", m):
            y = numpyro.sample("y", distribution)
        return x, y

    with numpyro.handlers.seed(rng_seed=0):
        auto_class(model)(3, 3)

    with pytest.raises(ValueError, match="inconsistent batch shapes"):
        auto_class(model)(3, 4)

    with pytest.raises(ValueError, match="Expected 2 batch dimensions"):
        auto_class(model, batch_ndim=2)(3, 3)