Implement GuideMessenger, AutoNormalMessenger, AutoRegressiveMessenger #2953

Merged: 22 commits, Nov 1, 2021
Changes from 11 commits
16 changes: 16 additions & 0 deletions docs/source/infer.autoguide.rst
@@ -125,6 +125,22 @@ AutoGaussian
:member-order: bysource
:show-inheritance:

AutoNormalMessenger
-------------------
.. autoclass:: pyro.infer.autoguide.AutoNormalMessenger
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:

AutoRegressiveMessenger
-----------------------
.. autoclass:: pyro.infer.autoguide.AutoRegressiveMessenger
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:

.. _autoguide-initialization:

Initialization
7 changes: 7 additions & 0 deletions docs/source/inference_algos.rst
@@ -57,6 +57,13 @@ ELBO
:show-inheritance:
:member-order: bysource

.. automodule:: pyro.infer.effect_elbo
:members:
:undoc-members:
:special-members: __call__
:show-inheritance:
:member-order: bysource

Importance
----------

16 changes: 10 additions & 6 deletions pyro/infer/__init__.py
@@ -4,6 +4,7 @@
from pyro.infer.abstract_infer import EmpiricalMarginal, TracePosterior, TracePredictive
from pyro.infer.csis import CSIS
from pyro.infer.discrete import infer_discrete
from pyro.infer.effect_elbo import Effect_ELBO, GuideMessenger, JitEffect_ELBO
from pyro.infer.elbo import ELBO
from pyro.infer.energy_distance import EnergyDistance
from pyro.infer.enum import config_enumerate
@@ -27,17 +28,16 @@
from pyro.infer.util import enable_validation, is_validation_enabled

__all__ = [
"config_enumerate",
"CSIS",
"enable_validation",
"is_validation_enabled",
"ELBO",
"Effect_ELBO",
"EmpiricalMarginal",
"EnergyDistance",
"GuideMessenger",
"HMC",
"Importance",
"IMQSteinKernel",
"infer_discrete",
"Importance",
"JitEffect_ELBO",
"JitTraceEnum_ELBO",
"JitTraceGraph_ELBO",
"JitTraceMeanField_ELBO",
@@ -51,13 +51,17 @@
"SMCFilter",
"SVGD",
"SVI",
"TraceTMC_ELBO",
"TraceEnum_ELBO",
"TraceGraph_ELBO",
"TraceMeanField_ELBO",
"TracePosterior",
"TracePredictive",
"TraceTMC_ELBO",
"TraceTailAdaptive_ELBO",
"Trace_ELBO",
"Trace_MMD",
"config_enumerate",
"enable_validation",
"infer_discrete",
"is_validation_enabled",
]
8 changes: 8 additions & 0 deletions pyro/infer/autoguide/__init__.py
@@ -1,6 +1,11 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from pyro.infer.autoguide.effect import (
AutoMessenger,
AutoNormalMessenger,
AutoRegressiveMessenger,
)
from pyro.infer.autoguide.gaussian import AutoGaussian
from pyro.infer.autoguide.guides import (
AutoCallable,
@@ -41,9 +46,12 @@
"AutoIAFNormal",
"AutoLaplaceApproximation",
"AutoLowRankMultivariateNormal",
"AutoMessenger",
"AutoMultivariateNormal",
"AutoNormal",
"AutoNormalMessenger",
"AutoNormalizingFlow",
"AutoRegressiveMessenger",
"AutoStructured",
"init_to_feasible",
"init_to_generated",
288 changes: 288 additions & 0 deletions pyro/infer/autoguide/effect.py
@@ -0,0 +1,288 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, Dict, Union

import torch
from torch.distributions import biject_to, constraints

import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.distributions.distribution import Distribution
from pyro.infer.effect_elbo import GuideMessenger
from pyro.nn.module import PyroModule, PyroParam, pyro_method
from pyro.poutine.runtime import get_plates

from .initialization import init_to_feasible, init_to_mean
from .utils import deep_getattr, deep_setattr, helpful_support_errors


class AutoMessengerMeta(type(GuideMessenger), type(PyroModule)):
pass


class AutoMessenger(GuideMessenger, PyroModule, metaclass=AutoMessengerMeta):
"""
EXPERIMENTAL Base class for :class:`pyro.infer.effect_elbo.GuideMessenger`
autoguides.
"""

@pyro_method
def __call__(self, *args, **kwargs):
# Since this guide creates parameters lazily, we need to avoid batching
# those parameters by a particle plate, in case the first time this
# guide is called is inside a particle plate. We assume all plates
# outside the model are particle plates.
self._outer_plates = get_plates()
try:
return super().__call__(*args, **kwargs)
finally:
del self._outer_plates

def call(self, *args, **kwargs):
"""
Method that calls :meth:`forward` and returns parameter values of the
guide as a `tuple` instead of a `dict`, which is a requirement for
JIT tracing. Unlike :meth:`forward`, this method can be traced by
:func:`torch.jit.trace_module`.

.. warning::
This method may be removed once PyTorch JIT tracer starts accepting
`dict` as valid return types. See
`issue <https://github.com/pytorch/pytorch/issues/27743>`_.
"""
result = self(*args, **kwargs)
return tuple(v for _, v in sorted(result.items()))
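
Because ``call`` returns a sorted tuple rather than a dict, a guide built on this class can in principle be traced. A minimal sketch, assuming a trained ``AutoNormalMessenger`` guide and a single tensor argument ``data`` (the model and argument names are placeholders, not part of this diff):

    guide = AutoNormalMessenger(model)
    guide(data)  # run once eagerly so lazily-created parameters exist
    traced = torch.jit.trace_module(guide, {"call": (data,)})
    values = traced.call(data)  # tuple of latent values, ordered by site name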

def _remove_outer_plates(self, value: torch.Tensor, event_dim: int) -> torch.Tensor:
"""
Removes particle plates from initial values of parameters.
"""
for f in self._outer_plates:
dim = f.dim - event_dim
if -value.dim() <= dim:
dim = dim + value.dim()
value = value[(slice(None),) * dim + (slice(1),)]
for dim in range(value.dim() - event_dim):
value = value.squeeze(0)
return value
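
For intuition, a hypothetical shape walk-through (the plate dimension is an assumption): if the guide is first called inside a particle plate of size 4 allocated at dim -2, an initial value of shape ``(4, 3)`` for a site with ``event_dim=0`` is cut down to a single particle and squeezed, so the stored parameter has shape ``(3,)`` and is shared across particles rather than replicated per particle:

    value = torch.randn(4, 3)   # batched by the particle plate at dim -2
    value = value[(slice(1),)]  # keep one particle -> shape (1, 3)
    value = value.squeeze(0)    # drop the singleton plate dim -> shape (3,)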


class AutoNormalMessenger(AutoMessenger):
"""
EXPERIMENTAL Automatic :class:`~pyro.infer.effect_elbo.GuideMessenger`,
intended for use with :class:`~pyro.infer.effect_elbo.Effect_ELBO` or
similar.

The mean-field posterior at any site is a transformed normal distribution.

Derived classes may override particular sites and use this simply as a
default, e.g.::

def model(data):
a = pyro.sample("a", dist.Normal(0, 1))
b = pyro.sample("b", dist.Normal(0, 1))
c = pyro.sample("c", dist.Normal(a + b, 1))
pyro.sample("obs", dist.Normal(c, 1), obs=data)

class MyGuideMessenger(AutoNormalMessenger):
def get_posterior(self, name, prior, upstream_values):
if name == "c":
# Use a custom distribution at site c.
bias = pyro.param("c_bias", lambda: torch.zeros(()))
weight = pyro.param("c_weight", lambda: torch.ones(()),
constraint=constraints.positive)
scale = pyro.param("c_scale", lambda: torch.ones(()),
constraint=constraints.positive)
a = upstream_values["a"]
b = upstream_values["b"]
loc = bias + weight * (a + b)
return dist.Normal(loc, scale)
# Fall back to mean field.
return super().get_posterior(name, prior, upstream_values)

Note that above we manually computed ``loc = bias + weight * (a + b)``.
Alternatively we could reuse the model-side computation by setting ``loc =
bias + weight * prior.loc``::

class MyGuideMessenger_v2(AutoNormalMessenger):
def get_posterior(self, name, prior, upstream_values):
if name == "c":
# Use a custom distribution at site c.
bias = pyro.param("c_bias", lambda: torch.zeros(()))
scale = pyro.param("c_scale", lambda: torch.ones(()),
constraint=constraints.positive)
weight = pyro.param("c_weight", lambda: torch.ones(()),
constraint=constraints.positive)
loc = bias + weight * prior.loc
return dist.Normal(loc, scale)
# Fall back to mean field.
return super().get_posterior(name, prior, upstream_values)

:param callable model: A Pyro model.
:param callable init_loc_fn: A per-site initialization function.
See :ref:`autoguide-initialization` section for available functions.
:param float init_scale: Initial scale for the standard deviation of each
(unconstrained transformed) latent variable.
"""

def __init__(
self,
model: Callable,
*,
init_loc_fn: Callable = init_to_mean(fallback=init_to_feasible),
init_scale: float = 0.1,
):
if not isinstance(init_scale, float) or not (init_scale > 0):
raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
super().__init__(model)
self.init_loc_fn = init_loc_fn
self._init_scale = init_scale
self._computing_median = False

def get_posterior(
self,
name: str,
prior: Distribution,
upstream_values: Dict[str, torch.Tensor],
) -> Union[Distribution, torch.Tensor]:
if self._computing_median:
return self._get_posterior_median(name, prior)

with helpful_support_errors({"name": name, "fn": prior}):
transform = biject_to(prior.support)
loc, scale = self._get_params(name, prior)
vitkl (Contributor) commented on Oct 29, 2021:

The guide will become fully hierarchical if you do this, but it is not fully hierarchical by default, right?

    loc, scale = self._get_params(name, prior)
    loc = loc + prior.loc

Ideally one could add some kind of test of whether this site has dependency sites.

You also mention that it could be useful to encode a more complex dependency:

    loc, scale, weight = self._get_params(name, prior)
    loc = loc + prior.loc * weight

fritzo (Member, Author) replied:

Correct, the intention of this simple guide is to be mean field.

Do you want to try contributing an AutoHierarchicalNormalMessenger guide as a follow-up to this PR? I tried to do something similar with AutoRegressiveMessenger below by sampling from the prior and then shifting in unconstrained space. I was unsure how to implement a general AutoHierarchicalNormalMessenger because not all prior distributions have a .mean method, and even then it is the mean in unconstrained space that we care about. E.g. how do we deal with Gamma or LogNormal or Beta or Dirichlet?

vitkl (Contributor) replied:

I understand your point about the distributions that don't have a mean. What are those distributions, by the way?

I am thinking about this solution:

    loc, scale, weight = self._get_params(name, prior)
    loc = loc + transform.inv(prior.loc) * weight

Does it make sense for all distributions that have a mean?

vitkl (Contributor) replied:

Yes, I will do the AutoHierarchicalNormalMessenger PR - should I wait until this PR is merged?

fritzo (Member, Author) replied:

> What distributions do not have a .mean method, by the way?

• Heavy-tailed distributions may not have a mean, e.g. Cauchy and Stable have infinite variance and no defined mean.
• Non-Euclidean distributions such as VonMises3D and ProjectedNormal have no defined mean.
• Some complex distributions have no computable mean, e.g. TransformedDistribution(Normal(...), MyNormalizingFlow).

> Does prior.loc make sense for all distributions that have a mean?

First, I would opt for prior.mean rather than prior.loc, since e.g. LogNormal(...).loc isn't a mean; rather it is the mean of the pre-transformed normal. Second, note that the transform of the constrained mean is not the same as the unconstrained mean or unconstrained median, e.g. for LogNormal, mean = exp(loc + scale**2 / 2) whereas median = exp(loc).

I think your .mean idea is good enough in most cases, and for cases where it fails, users can subclass and define their own custom .get_posterior() methods.
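
A minimal sketch of the hierarchical variant proposed in this thread (hypothetical: the class name, the per-site ``weight`` parameter, and the reliance on ``prior.mean`` are assumptions, not part of this PR):

    class AutoHierarchicalNormalMessenger(AutoNormalMessenger):
        def get_posterior(self, name, prior, upstream_values):
            transform = biject_to(prior.support)
            loc, scale = self._get_params(name, prior)
            # Shift by the prior mean mapped to unconstrained space; this
            # assumes prior.mean exists and is finite for this site.
            weight = pyro.param(name + "_weight", lambda: torch.ones(()),
                                constraint=constraints.positive)
            loc = loc + transform.inv(prior.mean) * weight
            return dist.TransformedDistribution(
                dist.Normal(loc, scale).to_event(transform.domain.event_dim),
                transform.with_cache(),
            )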

posterior = dist.TransformedDistribution(
dist.Normal(loc, scale).to_event(transform.domain.event_dim),
transform.with_cache(),
)
return posterior

def _get_params(self, name: str, prior: Distribution):
try:
loc = deep_getattr(self.locs, name)
scale = deep_getattr(self.scales, name)
return loc, scale
except AttributeError:
pass

# Initialize.
with poutine.block(), torch.no_grad():
transform = biject_to(prior.support)
event_dim = transform.domain.event_dim
constrained = self.init_loc_fn({"name": name, "fn": prior}).detach()
unconstrained = transform.inv(constrained)
init_loc = self._remove_outer_plates(unconstrained, event_dim)
init_scale = torch.full_like(init_loc, self._init_scale)

deep_setattr(self, "locs." + name, PyroParam(init_loc, event_dim=event_dim))
deep_setattr(
self,
"scales." + name,
PyroParam(init_scale, constraint=constraints.positive, event_dim=event_dim),
)
return self._get_params(name, prior)

def median(self, *args, **kwargs):
self._computing_median = True
try:
return self(*args, **kwargs)
finally:
self._computing_median = False

def _get_posterior_median(self, name, prior):
transform = biject_to(prior.support)
loc, scale = self._get_params(name, prior)
return transform(loc)


class AutoRegressiveMessenger(AutoMessenger):
"""
EXPERIMENTAL Automatic :class:`~pyro.infer.effect_elbo.GuideMessenger`,
intended for use with :class:`~pyro.infer.effect_elbo.Effect_ELBO` or
similar.

The posterior at any site is a learned affine transform of the prior,
conditioned on upstream posterior samples. The affine transform operates in
unconstrained space. This supports only continuous latent variables.

Derived classes may override particular sites and use this simply as a
default, e.g.::

class MyGuideMessenger(AutoRegressiveMessenger):
def get_posterior(self, name, prior, upstream_values):
if name == "x":
# Use a custom distribution at site x.
loc = pyro.param("x_loc", lambda: torch.zeros(prior.shape()))
scale = pyro.param("x_scale", lambda: torch.ones(prior.shape())),
constraint=constraints.positive
return dist.Normal(loc, scale).to_event(prior.event_dim)
# Fall back to autoregressive.
return super().get_posterior(name, prior, upstream_values)

.. warning:: This guide currently does not support ``JitEffect_ELBO``.

:param callable model: A Pyro model.
:param callable init_loc_fn: A per-site initialization function.
See :ref:`autoguide-initialization` section for available functions.
:param float init_scale: Initial scale for the standard deviation of each
(unconstrained transformed) latent variable.
"""

def __init__(
self,
model: Callable,
*,
init_loc_fn: Callable = init_to_mean(fallback=init_to_feasible),
init_scale: float = 0.1,
):
if not isinstance(init_scale, float) or not (init_scale > 0):
raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
super().__init__(model)
self.init_loc_fn = init_loc_fn
self._init_scale = init_scale

def get_posterior(
self,
name: str,
prior: Distribution,
upstream_values: Dict[str, torch.Tensor],
) -> Union[Distribution, torch.Tensor]:
with helpful_support_errors({"name": name, "fn": prior}):
transform = biject_to(prior.support)
loc, scale = self._get_params(name, prior)
affine = dist.transforms.AffineTransform(
loc, scale, event_dim=transform.domain.event_dim, cache_size=1
)
posterior = dist.TransformedDistribution(
prior, [transform.inv.with_cache(), affine, transform.with_cache()]
)
return posterior

def _get_params(self, name: str, prior: Distribution):
try:
loc = deep_getattr(self.locs, name)
scale = deep_getattr(self.scales, name)
return loc, scale
except AttributeError:
pass

# Initialize.
with poutine.block(), torch.no_grad():
transform = biject_to(prior.support)
event_dim = transform.domain.event_dim
constrained = self.init_loc_fn({"name": name, "fn": prior}).detach()
unconstrained = transform.inv(constrained)
# Initialize the distribution to be an affine combination:
# init_scale * prior + (1 - init_scale) * init_loc
init_loc = self._remove_outer_plates(unconstrained, event_dim)
init_loc = init_loc * (1 - self._init_scale)
init_scale = torch.full_like(init_loc, self._init_scale)

deep_setattr(self, "locs." + name, PyroParam(init_loc, event_dim=event_dim))
deep_setattr(
self,
"scales." + name,
PyroParam(init_scale, constraint=constraints.positive, event_dim=event_dim),
)
return self._get_params(name, prior)
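
Finally, a minimal end-to-end usage sketch of the new guides (hedged: this assumes ``Effect_ELBO`` plugs into ``SVI`` like the existing ELBO implementations, which the diff above does not show; the toy model is illustrative only):

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.infer import SVI, Effect_ELBO
    from pyro.infer.autoguide import AutoNormalMessenger

    def model(data):
        loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

    data = torch.randn(100) + 2.0
    guide = AutoNormalMessenger(model)
    svi = SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), Effect_ELBO())
    for step in range(1000):
        svi.step(data)
    print(guide.median(data))  # dict mapping site name -> median, in constrained space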