From 84feca498d07aeb9be3714749a8f232581e2536a Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 28 Oct 2021 15:32:20 -0400 Subject: [PATCH 01/22] Implement Effect_ELBO and AutoRegressiveMessenger --- pyro/infer/autoguide/effect.py | 86 +++++++++++++++++ pyro/infer/effect_elbo.py | 167 +++++++++++++++++++++++++++++++++ 2 files changed, 253 insertions(+) create mode 100644 pyro/infer/autoguide/effect.py create mode 100644 pyro/infer/effect_elbo.py diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py new file mode 100644 index 0000000000..f324ec6457 --- /dev/null +++ b/pyro/infer/autoguide/effect.py @@ -0,0 +1,86 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, Union + +import torch + +import pyro.distributions as dist +import pyro.poutine as poutine +from pyro.distributions import constraints +from pyro.distributions.torch_distribution import TorchDistribution +from pyro.infer.effect_elbo import GuideMessenger +from pyro.infer.nn import PyroModule, PyroParam + +from .utils import deep_getattr, deep_setattr, helpful_support_errors + + +class AutoMessengerMeta(type(PyroModule), type(GuideMessenger)): + pass + + +class AutoRegressiveMessenger(GuideMessenger, PyroModule, metaclass=AutoMessengerMeta): + """ + Automatic :class:`~pyro.infer.effect_elbo.GuideMessenger` , intended for + use with :class:`~pyro.infer.effect_elbo.Effect_ELBO` or similar. + + The posterior at any site is a learned affine transform of the prior, + conditioned on upstream posterior samples. The affine tramsform operates in + unconstrained space. This supports only continuous inference. + + Derived classes may override particular sites and use this simply as a + default, e.g.:: + + class MyGuideMessenger(AutoRegressiveMessenger): + def get_posterior(self, name, prior, upstream_values): + if name == "x": + # Use a custom distribution at site x. + loc = pyro.param("x_loc", lambda: torch.zeros(prior.shape())) + scale = pyro.param("x_scale", lambda: torch.ones(prior.shape())) + return dist.Normal(loc, scale).to_event(prior.event_dim()) + # Fall back to autoregressive. + return super().get_posterior(name, prior, upstream_values) + """ + + def get_posterior( + self, + name: str, + prior: TorchDistribution, + upstream_values: Dict[str, torch.Tensor], + ) -> Union[TorchDistribution, torch.Tensor]: + with helpful_support_errors({"name": name, "fn": prior}): + transform = constraints.biject_to(prior.support) + loc, scale = self.get_params(name, prior) + affine = dist.transforms.AffineTransform( + loc, scale, event_dim=transform.domain.event_dim, cache_size=1 + ) + posterior = dist.TransformedDistribution( + prior, [transform.inv.with_cache(), affine, transform.with_cache()] + ) + return posterior + + def get_params(self, name, prior): + try: + loc = deep_getattr(self.locs, name) + scale = deep_getattr(self.scales, name) + return loc, scale + except AttributeError: + pass + + # Initialize. 
+        with poutine.block(), torch.no_grad():
+            constrained = prior.sample()
+            transform = constraints.transform_to(prior.support)
+            unconstrained = transform.inv(constrained)
+            event_dim = transform.domain.event_dim
+            deep_setattr(
+                self.loc,
+                name,
+                PyroParam(torch.zeros_like(unconstrained), event_dim=event_dim),
+            )
+            deep_setattr(
+                self.scale,
+                name,
+                PyroParam(torch.ones_like(unconstrained), event_dim=event_dim),
+            )
+        return self.get_params(name, prior)
diff --git a/pyro/infer/effect_elbo.py b/pyro/infer/effect_elbo.py
new file mode 100644
index 0000000000..c2abb8b29b
--- /dev/null
+++ b/pyro/infer/effect_elbo.py
@@ -0,0 +1,167 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import ABCMeta, abstractmethod
+from collections import OrderedDict
+from typing import Callable, Dict, Tuple, Union
+
+import torch
+
+import pyro.distributions as dist
+from pyro.distributions.torch_distribution import TorchDistribution
+from pyro.infer.elbo import ELBO
+from pyro.infer.util import is_validation_enabled
+from pyro.poutine.trace_messenger import TraceMessenger
+from pyro.poutine.trace_struct import TraceStruct
+from pyro.poutine.util import prune_subsample_sites, site_is_subsample
+from pyro.util import check_model_guide_match, check_site_shape
+
+from .trace_elbo import JitTrace_ELBO, Trace_ELBO
+
+
+class GuideMessengerMeta(type(TraceMessenger), ABCMeta):
+    pass
+
+
+class GuideMessenger(TraceMessenger, metaclass=GuideMessengerMeta):
+    """
+    Abstract base class for effect-based guides for use in :class:`Effect_ELBO`
+    .
+
+    Derived classes must implement the :meth:`get_posterior` method.
+    """
+
+    def __enter__(self, *args, **kwargs) -> TraceStruct:
+        self.args_kwargs = args, kwargs
+        self.upstream_values = OrderedDict()
+        return super().__enter__()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        del self.args_kwargs
+        del self.upstream_values
+        return super().__exit__(self, exc_type, exc_value, traceback)
+
+    def _pyro_sample(self, msg):
+        if not msg["is_observed"] and not site_is_subsample(msg):
+            msg["infer"]["prior"] = msg["fn"]
+            posterior = self.get_posterior(msg["name"], msg["fn"], self.upstream_values)
+            if isinstance(posterior, torch.Tensor):
+                posterior = dist.Delta(posterior, event_dim=msg["fn"].event_dim)
+            msg["fn"] = posterior
+        return super()._pyro_sample(msg)
+
+    def _pyro_post_sample(self, msg):
+        self.upstream_values[msg["name"]] = msg["value"]
+        return super()._pyro_post_sample(msg)
+
+    @abstractmethod
+    def get_posterior(
+        self,
+        name: str,
+        prior: TorchDistribution,
+        upstream_values: Dict[str, torch.Tensor],
+    ) -> Union[TorchDistribution, torch.Tensor]:
+        """
+        Abstract method to compute a posterior distribution or sample a
+        posterior value given a prior distribution and values of upstream
+        sample sites.
+
+        Implementations may use ``pyro.param`` and ``pyro.sample`` inside this
+        function, but ``pyro.sample`` statements should set
+        ``infer={"is_auxiliary": True}`` .
+
+        Implementations may access further information for computations:
+
+        - ``args, kwargs = self.args_kwargs`` are the inputs to the model, and
+          may be useful for amortization.
+        - ``self.trace`` is a trace of upstream sites, and may be useful for
+          other information such as ``self.trace.nodes["my_site"]["fn"]`` or
+          ``self.trace.nodes["my_site"]["cond_indep_stack"]`` .
+
+        :param str name: The name of the sample site to sample.
+ :param prior: The prior distribution of this sample site + (conditioned on upstream samples from the posterior). + :type prior: ~pyro.distributions.TorchDistribution + :param dict upstream_values: + :returns: A posterior distribution or sample from the posterior + distribution. + :rtype: ~pyro.distributions.TorchDistribution or torch.Tensor + """ + raise NotImplementedError + + def get_traces(self) -> Tuple[TraceStruct, TraceStruct]: + """ + :returns: a pair ``(model_trace, guide_trace)`` + :rtype: tuple + """ + guide_trace = self.trace.copy() + model_trace = self.trace.copy() + for name, guide_site in list(guide_trace.nodes.items()): + if guide_site["type"] != "sample" or guide_site["is_observed"]: + del guide_trace.nodes[name] + continue + model_trace[name]["fn"] = guide_site["infer"]["prior"] + return model_trace, guide_trace + + +class EffectMixin(ELBO): + """ + Mixin class to turn a trace-based ELBO implementation into an effect-based + implementation. + """ + + def _get_trace( + self, model: Callable, guide: GuideMessenger, args: tuple, kwargs: dict + ): + # This differs from Trace_ELBO in that the guide is assumed to be an + # effect handler. + assert isinstance(guide, GuideMessenger) + with guide(*args, **kwargs): + model(*args, **kwargs) + model_trace, guide_trace = guide.get_traces() + if getattr(self, "max_plate_nesting") is None: + self.max_plate_nesting = max( + [0] + + [ + -f.dim + for site in guide.trace.nodes.values() + for f in site["cond_indep_stack"] + if f.vectorized + ] + ) + + # The rest follows pyro.infer.enum.get_importance_trace(). + max_plate_nesting = self.max_plate_nesting + if is_validation_enabled(): + check_model_guide_match(model_trace, guide_trace, max_plate_nesting) + guide_trace = prune_subsample_sites(guide_trace) + model_trace = prune_subsample_sites(model_trace) + model_trace.compute_log_prob() + guide_trace.compute_score_parts() + if is_validation_enabled(): + for site in model_trace.nodes.values(): + if site["type"] == "sample": + check_site_shape(site, max_plate_nesting) + for site in guide_trace.nodes.values(): + if site["type"] == "sample": + check_site_shape(site, max_plate_nesting) + + return model_trace, guide_trace + + +class Effect_ELBO(EffectMixin, Trace_ELBO): + """ + Similar to :class:`~pyro.infer.trace_elbo.Trace_ELBO` but supporting guides + that are effect handlers rather than traceable functions. + """ + + pass + + +class JitEffect_ELBO(EffectMixin, JitTrace_ELBO): + """ + Similar to :class:`~pyro.infer.trace_elbo.JitTrace_ELBO` but supporting guides + that are effect handlers rather than traceable functions. 
+ """ + + pass From afe472c7e547e4daed9622368d9bde1d0579c8cc Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 28 Oct 2021 15:44:52 -0400 Subject: [PATCH 02/22] Fix max_plate_nesting typo --- pyro/infer/effect_elbo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/infer/effect_elbo.py b/pyro/infer/effect_elbo.py index c2abb8b29b..79f4981da5 100644 --- a/pyro/infer/effect_elbo.py +++ b/pyro/infer/effect_elbo.py @@ -119,7 +119,7 @@ def _get_trace( with guide(*args, **kwargs): model(*args, **kwargs) model_trace, guide_trace = guide.get_traces() - if getattr(self, "max_plate_nesting") is None: + if self.max_plate_nesting == -float("inf"): self.max_plate_nesting = max( [0] + [ From 75a0a6f55ad24db7c7d2a8e44a415e364227ef46 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 28 Oct 2021 16:18:40 -0400 Subject: [PATCH 03/22] Add docs --- docs/source/infer.autoguide.rst | 8 ++++++++ docs/source/inference_algos.rst | 5 +++++ pyro/infer/__init__.py | 16 ++++++++++------ pyro/infer/autoguide/__init__.py | 2 ++ pyro/infer/autoguide/effect.py | 14 +++++++------- pyro/infer/effect_elbo.py | 20 ++++++++------------ 6 files changed, 40 insertions(+), 25 deletions(-) diff --git a/docs/source/infer.autoguide.rst b/docs/source/infer.autoguide.rst index 6ea6564233..1a48fcdfaa 100644 --- a/docs/source/infer.autoguide.rst +++ b/docs/source/infer.autoguide.rst @@ -125,6 +125,14 @@ AutoGaussian :member-order: bysource :show-inheritance: +AutoRegressiveMessenger +----------------------- +.. autoclass:: pyro.infer.autoguide.AutoRegressiveMessenger + :members: + :undoc-members: + :member-order: bysource + :show-inheritance: + .. _autoguide-initialization: Initialization diff --git a/docs/source/inference_algos.rst b/docs/source/inference_algos.rst index 153028d1f7..af7aa65a99 100644 --- a/docs/source/inference_algos.rst +++ b/docs/source/inference_algos.rst @@ -57,6 +57,11 @@ ELBO :show-inheritance: :member-order: bysource +.. 
automodule:: pyro.infer.effect_elbo + :members: + :show-inheritance: + :member-order: bysource + Importance ---------- diff --git a/pyro/infer/__init__.py b/pyro/infer/__init__.py index 5485b7476e..f2985dba2e 100644 --- a/pyro/infer/__init__.py +++ b/pyro/infer/__init__.py @@ -4,6 +4,7 @@ from pyro.infer.abstract_infer import EmpiricalMarginal, TracePosterior, TracePredictive from pyro.infer.csis import CSIS from pyro.infer.discrete import infer_discrete +from pyro.infer.effect_elbo import Effect_ELBO, GuideMessenger, JitEffect_ELBO from pyro.infer.elbo import ELBO from pyro.infer.energy_distance import EnergyDistance from pyro.infer.enum import config_enumerate @@ -27,17 +28,16 @@ from pyro.infer.util import enable_validation, is_validation_enabled __all__ = [ - "config_enumerate", "CSIS", - "enable_validation", - "is_validation_enabled", "ELBO", + "Effect_ELBO", "EmpiricalMarginal", "EnergyDistance", + "GuideMessenger", "HMC", - "Importance", "IMQSteinKernel", - "infer_discrete", + "Importance", + "JitEffect_ELBO", "JitTraceEnum_ELBO", "JitTraceGraph_ELBO", "JitTraceMeanField_ELBO", @@ -51,13 +51,17 @@ "SMCFilter", "SVGD", "SVI", - "TraceTMC_ELBO", "TraceEnum_ELBO", "TraceGraph_ELBO", "TraceMeanField_ELBO", "TracePosterior", "TracePredictive", + "TraceTMC_ELBO", "TraceTailAdaptive_ELBO", "Trace_ELBO", "Trace_MMD", + "config_enumerate", + "enable_validation", + "infer_discrete", + "is_validation_enabled", ] diff --git a/pyro/infer/autoguide/__init__.py b/pyro/infer/autoguide/__init__.py index 18db07b79c..ef8f9d83ac 100644 --- a/pyro/infer/autoguide/__init__.py +++ b/pyro/infer/autoguide/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 +from pyro.infer.autoguide.effect import AutoRegressiveMessenger from pyro.infer.autoguide.gaussian import AutoGaussian from pyro.infer.autoguide.guides import ( AutoCallable, @@ -44,6 +45,7 @@ "AutoMultivariateNormal", "AutoNormal", "AutoNormalizingFlow", + "AutoRegressiveMessenger", "AutoStructured", "init_to_feasible", "init_to_generated", diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py index f324ec6457..3d69186d62 100644 --- a/pyro/infer/autoguide/effect.py +++ b/pyro/infer/autoguide/effect.py @@ -10,12 +10,12 @@ from pyro.distributions import constraints from pyro.distributions.torch_distribution import TorchDistribution from pyro.infer.effect_elbo import GuideMessenger -from pyro.infer.nn import PyroModule, PyroParam +from pyro.nn.module import PyroModule, PyroParam from .utils import deep_getattr, deep_setattr, helpful_support_errors -class AutoMessengerMeta(type(PyroModule), type(GuideMessenger)): +class AutoMessengerMeta(type(GuideMessenger), type(PyroModule)): pass @@ -25,8 +25,8 @@ class AutoRegressiveMessenger(GuideMessenger, PyroModule, metaclass=AutoMessenge use with :class:`~pyro.infer.effect_elbo.Effect_ELBO` or similar. The posterior at any site is a learned affine transform of the prior, - conditioned on upstream posterior samples. The affine tramsform operates in - unconstrained space. This supports only continuous inference. + conditioned on upstream posterior samples. The affine transform operates in + unconstrained space. This supports only continuous latent variables. 
Derived classes may override particular sites and use this simply as a default, e.g.:: @@ -50,7 +50,7 @@ def get_posterior( ) -> Union[TorchDistribution, torch.Tensor]: with helpful_support_errors({"name": name, "fn": prior}): transform = constraints.biject_to(prior.support) - loc, scale = self.get_params(name, prior) + loc, scale = self._get_params(name, prior) affine = dist.transforms.AffineTransform( loc, scale, event_dim=transform.domain.event_dim, cache_size=1 ) @@ -59,7 +59,7 @@ def get_posterior( ) return posterior - def get_params(self, name, prior): + def _get_params(self, name, prior): try: loc = deep_getattr(self.locs, name) scale = deep_getattr(self.scales, name) @@ -83,4 +83,4 @@ def get_params(self, name, prior): name, PyroParam(torch.ones_like(unconstrained), event_dim=event_dim), ) - return self.get_params(name, prior) + return self._get_params(name, prior) diff --git a/pyro/infer/effect_elbo.py b/pyro/infer/effect_elbo.py index 79f4981da5..14605c97de 100644 --- a/pyro/infer/effect_elbo.py +++ b/pyro/infer/effect_elbo.py @@ -1,7 +1,7 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from abc import ABCMeta, abstractmethod +from abc import ABC, abstractmethod from collections import OrderedDict from typing import Callable, Dict, Tuple, Union @@ -12,26 +12,22 @@ from pyro.infer.elbo import ELBO from pyro.infer.util import is_validation_enabled from pyro.poutine.trace_messenger import TraceMessenger -from pyro.poutine.trace_struct import TraceStruct +from pyro.poutine.trace_struct import Trace from pyro.poutine.util import prune_subsample_sites, site_is_subsample from pyro.util import check_model_guide_match, check_site_shape from .trace_elbo import JitTrace_ELBO, Trace_ELBO -class GuideMessengerMeta(type(TraceMessenger), ABCMeta): - pass - - -class GuideMessenger(TraceMessenger, metaclass=GuideMessengerMeta): +class GuideMessenger(TraceMessenger, ABC): """ Abstract base class for effect-based guides for use in :class:`Effect_ELBO` - . + and similar. Derived classes must implement the :meth:`get_posterior` method. """ - def __enter__(self, *args, **kwargs) -> TraceStruct: + def __enter__(self, *args, **kwargs) -> Trace: self.args_kwargs = args, kwargs self.upstream_values = OrderedDict() return super().__enter__() @@ -89,7 +85,7 @@ def get_posterior( """ raise NotImplementedError - def get_traces(self) -> Tuple[TraceStruct, TraceStruct]: + def get_traces(self) -> Tuple[Trace, Trace]: """ :returns: a pair ``(model_trace, guide_trace)`` :rtype: tuple @@ -152,7 +148,7 @@ def _get_trace( class Effect_ELBO(EffectMixin, Trace_ELBO): """ Similar to :class:`~pyro.infer.trace_elbo.Trace_ELBO` but supporting guides - that are effect handlers rather than traceable functions. + that are :class:`GuideMessenger` s rather than traceable functions. """ pass @@ -161,7 +157,7 @@ class Effect_ELBO(EffectMixin, Trace_ELBO): class JitEffect_ELBO(EffectMixin, JitTrace_ELBO): """ Similar to :class:`~pyro.infer.trace_elbo.JitTrace_ELBO` but supporting guides - that are effect handlers rather than traceable functions. + that are :class:`GuideMessenger` s rather than traceable functions. 
""" pass From b656276dca59214571c015aace6202c8ba921427 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 28 Oct 2021 19:26:56 -0400 Subject: [PATCH 04/22] Get smoke tests passing --- pyro/infer/autoguide/__init__.py | 3 +- pyro/infer/autoguide/effect.py | 58 +++++++++++++++++++++++++------- pyro/infer/effect_elbo.py | 37 ++++++++++---------- pyro/infer/elbo.py | 33 +++++++++--------- tests/infer/test_autoguide.py | 38 ++++++++++++++++----- 5 files changed, 114 insertions(+), 55 deletions(-) diff --git a/pyro/infer/autoguide/__init__.py b/pyro/infer/autoguide/__init__.py index ef8f9d83ac..ebd0997941 100644 --- a/pyro/infer/autoguide/__init__.py +++ b/pyro/infer/autoguide/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from pyro.infer.autoguide.effect import AutoRegressiveMessenger +from pyro.infer.autoguide.effect import AutoMessenger, AutoRegressiveMessenger from pyro.infer.autoguide.gaussian import AutoGaussian from pyro.infer.autoguide.guides import ( AutoCallable, @@ -42,6 +42,7 @@ "AutoIAFNormal", "AutoLaplaceApproximation", "AutoLowRankMultivariateNormal", + "AutoMessenger", "AutoMultivariateNormal", "AutoNormal", "AutoNormalizingFlow", diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py index 3d69186d62..737e182046 100644 --- a/pyro/infer/autoguide/effect.py +++ b/pyro/infer/autoguide/effect.py @@ -4,13 +4,14 @@ from typing import Dict, Union import torch +from torch.distributions import biject_to, constraints import pyro.distributions as dist import pyro.poutine as poutine -from pyro.distributions import constraints from pyro.distributions.torch_distribution import TorchDistribution from pyro.infer.effect_elbo import GuideMessenger -from pyro.nn.module import PyroModule, PyroParam +from pyro.nn.module import PyroModule, PyroParam, pyro_method +from pyro.poutine.runtime import get_plates from .utils import deep_getattr, deep_setattr, helpful_support_errors @@ -19,7 +20,34 @@ class AutoMessengerMeta(type(GuideMessenger), type(PyroModule)): pass -class AutoRegressiveMessenger(GuideMessenger, PyroModule, metaclass=AutoMessengerMeta): +class AutoMessenger(GuideMessenger, PyroModule, metaclass=AutoMessengerMeta): + """ + Base class for :class:`pyro.infer.effect_elbo.GuideMessenger` autoguides. + """ + + # Drop args for backwards compatibility with AutoGuide. + def __init__(self, model, *, init_loc_fn=None): + super().__init__() + + def __call__(self, *args, **kwargs): + self._outer_plates = get_plates() + return super().__call__(*args, **kwargs) + + def _remove_outer_plates(self, value, event_dim): + """ + Removes particle plates from initial values of parameters. + """ + for f in self._outer_plates: + dim = -f.dim - event_dim + if -value.dim() <= dim: + dim = dim + value.dim() + value = value[(slice(None),) * dim + slice(1)] + for dim in range(value.dim() - event_dim): + value = value.squeeze(0) + return value + + +class AutoRegressiveMessenger(AutoMessenger): """ Automatic :class:`~pyro.infer.effect_elbo.GuideMessenger` , intended for use with :class:`~pyro.infer.effect_elbo.Effect_ELBO` or similar. 
@@ -42,6 +70,7 @@ def get_posterior(self, name, prior, upstream_values): return super().get_posterior(name, prior, upstream_values) """ + @pyro_method def get_posterior( self, name: str, @@ -49,7 +78,7 @@ def get_posterior( upstream_values: Dict[str, torch.Tensor], ) -> Union[TorchDistribution, torch.Tensor]: with helpful_support_errors({"name": name, "fn": prior}): - transform = constraints.biject_to(prior.support) + transform = biject_to(prior.support) loc, scale = self._get_params(name, prior) affine = dist.transforms.AffineTransform( loc, scale, event_dim=transform.domain.event_dim, cache_size=1 @@ -69,18 +98,23 @@ def _get_params(self, name, prior): # Initialize. with poutine.block(), torch.no_grad(): - constrained = prior.sample() - transform = constraints.transform_to(prior.support) + constrained = prior.sample().detach() + transform = biject_to(prior.support) unconstrained = transform.inv(constrained) event_dim = transform.domain.event_dim + prototype = self._remove_outer_plates(unconstrained, event_dim) deep_setattr( - self.loc, - name, - PyroParam(torch.zeros_like(unconstrained), event_dim=event_dim), + self, + "locs." + name, + PyroParam(torch.zeros_like(prototype), event_dim=event_dim), ) deep_setattr( - self.scale, - name, - PyroParam(torch.ones_like(unconstrained), event_dim=event_dim), + self, + "scales." + name, + PyroParam( + torch.ones_like(prototype), + constraint=constraints.positive, + event_dim=event_dim, + ), ) return self._get_params(name, prior) diff --git a/pyro/infer/effect_elbo.py b/pyro/infer/effect_elbo.py index 14605c97de..af309a4e96 100644 --- a/pyro/infer/effect_elbo.py +++ b/pyro/infer/effect_elbo.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Callable, Dict, Tuple, Union +from typing import Dict, Tuple, Union import torch @@ -27,24 +27,24 @@ class GuideMessenger(TraceMessenger, ABC): Derived classes must implement the :meth:`get_posterior` method. 
""" - def __enter__(self, *args, **kwargs) -> Trace: + def __call__(self, *args, **kwargs): self.args_kwargs = args, kwargs self.upstream_values = OrderedDict() - return super().__enter__() + return self def __exit__(self, exc_type, exc_value, traceback): del self.args_kwargs del self.upstream_values - return super().__exit__(self, exc_type, exc_value, traceback) + return super().__exit__(exc_type, exc_value, traceback) def _pyro_sample(self, msg): - if not msg["is_observed"] and not site_is_subsample(msg): - msg["infer"]["prior"] = msg["fn"] - posterior = self.get_posterior(msg["name"], msg["fn"], self.upstream_values) - if isinstance(posterior, torch.Tensor): - posterior = dist.Delta(posterior, event_dim=msg["fn"].event_dim) - msg["fn"] = posterior - return super()._pyro_sample(msg) + if msg["is_observed"] or site_is_subsample(msg): + return + msg["infer"]["prior"] = msg["fn"] + posterior = self.get_posterior(msg["name"], msg["fn"], self.upstream_values) + if isinstance(posterior, torch.Tensor): + posterior = dist.Delta(posterior, event_dim=msg["fn"].event_dim) + msg["fn"] = posterior def _pyro_post_sample(self, msg): self.upstream_values[msg["name"]] = msg["value"] @@ -90,13 +90,15 @@ def get_traces(self) -> Tuple[Trace, Trace]: :returns: a pair ``(model_trace, guide_trace)`` :rtype: tuple """ - guide_trace = self.trace.copy() - model_trace = self.trace.copy() + guide_trace = prune_subsample_sites(self.trace) + model_trace = model_trace = guide_trace.copy() for name, guide_site in list(guide_trace.nodes.items()): if guide_site["type"] != "sample" or guide_site["is_observed"]: del guide_trace.nodes[name] continue - model_trace[name]["fn"] = guide_site["infer"]["prior"] + model_site = model_trace.nodes[name].copy() + model_site["fn"] = guide_site["infer"]["prior"] + model_trace.nodes[name] = model_site return model_trace, guide_trace @@ -106,14 +108,13 @@ class EffectMixin(ELBO): implementation. """ - def _get_trace( - self, model: Callable, guide: GuideMessenger, args: tuple, kwargs: dict - ): + def _get_trace(self, model, guide, args, kwargs): # This differs from Trace_ELBO in that the guide is assumed to be an # effect handler. - assert isinstance(guide, GuideMessenger) with guide(*args, **kwargs): model(*args, **kwargs) + while not isinstance(guide, GuideMessenger): + guide = guide.func.args[1] # unwrap plates model_trace, guide_trace = guide.get_traces() if self.max_plate_nesting == -float("inf"): self.max_plate_nesting = max( diff --git a/pyro/infer/elbo.py b/pyro/infer/elbo.py index 6e7b45f17e..c65b672531 100644 --- a/pyro/infer/elbo.py +++ b/pyro/infer/elbo.py @@ -94,10 +94,17 @@ def _guess_max_plate_nesting(self, model, guide, args, kwargs): """ # Ignore validation to allow model-enumerated sites absent from the guide. with poutine.block(): - guide_trace = poutine.trace(guide).get_trace(*args, **kwargs) - model_trace = poutine.trace( - poutine.replay(model, trace=guide_trace) - ).get_trace(*args, **kwargs) + if isinstance(guide, poutine.messenger.Messenger): + # Subclasses of GuideMessenger. + with guide(*args, **kwargs): + model(*args, **kwargs) + model_trace, guide_trace = guide.get_traces() + else: + # Traditional callable guides. 
+ guide_trace = poutine.trace(guide).get_trace(*args, **kwargs) + model_trace = poutine.trace( + poutine.replay(model, trace=guide_trace) + ).get_trace(*args, **kwargs) guide_trace = prune_subsample_sites(guide_trace) model_trace = prune_subsample_sites(model_trace) sites = [ @@ -139,17 +146,13 @@ def _vectorized_num_particles(self, fn): :return: wrapped callable. """ - def wrapped_fn(*args, **kwargs): - if self.num_particles == 1: - return fn(*args, **kwargs) - with pyro.plate( - "num_particles_vectorized", - self.num_particles, - dim=-self.max_plate_nesting, - ): - return fn(*args, **kwargs) - - return wrapped_fn + if self.num_particles == 1: + return fn + return pyro.plate( + "num_particles_vectorized", + self.num_particles, + dim=-self.max_plate_nesting, + )(fn) def _get_vectorized_trace(self, model, guide, args, kwargs): """ diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 73711a4280..f522d1aae9 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -16,6 +16,8 @@ import pyro.poutine as poutine from pyro.infer import ( SVI, + Effect_ELBO, + JitEffect_ELBO, JitTrace_ELBO, JitTraceEnum_ELBO, JitTraceGraph_ELBO, @@ -35,8 +37,10 @@ AutoIAFNormal, AutoLaplaceApproximation, AutoLowRankMultivariateNormal, + AutoMessenger, AutoMultivariateNormal, AutoNormal, + AutoRegressiveMessenger, AutoStructured, init_to_feasible, init_to_mean, @@ -57,6 +61,20 @@ ) +def promote_elbo(Guide, Elbo): + """ + Promote e.g. Trace_ELBO --> Effect_ELBO for AutoMessengers. + """ + if issubclass(Guide, AutoMessenger): + if Elbo is Trace_ELBO: + return Effect_ELBO + if Elbo is JitTrace_ELBO: + return Effect_ELBO # DEBUG + return JitEffect_ELBO + pytest.skip("not implemented") + return Elbo + + @pytest.mark.parametrize( "auto_class", [ @@ -202,6 +220,7 @@ def dependency_z6_z5(z5): AutoStructured_shapes, AutoGaussian, AutoGaussianFunsor, + AutoRegressiveMessenger, ], ) @pytest.mark.filterwarnings("ignore::FutureWarning") @@ -218,6 +237,7 @@ def model(): ) pyro.sample("z7", dist.LKJCholesky(2, torch.tensor(1.0))) + Elbo = promote_elbo(auto_class, Elbo) guide = auto_class(model, init_loc_fn=init_loc_fn) elbo = Elbo( num_particles=num_particles, @@ -1278,6 +1298,7 @@ def __init__(self, model): AutoStructured_exact_mvn, AutoGaussian, AutoGaussianFunsor, + AutoRegressiveMessenger, ], ) def test_exact(Guide): @@ -1306,9 +1327,8 @@ def model(data): expected_loss = float(g.event_logsumexp() - g.condition(data).event_logsumexp()) guide = Guide(model) - elbo = JitTrace_ELBO( - num_particles=100, vectorize_particles=True, ignore_jit_warnings=True - ) + Elbo = promote_elbo(Guide, JitTrace_ELBO) + elbo = Elbo(num_particles=100, vectorize_particles=True, ignore_jit_warnings=True) num_steps = 500 optim = ClippedAdam({"lr": 0.05, "lrd": 0.1 ** (1 / num_steps)}) svi = SVI(model, guide, optim, elbo) @@ -1342,6 +1362,7 @@ def model(data): AutoStructured_exact_mvn, AutoGaussian, AutoGaussianFunsor, + AutoRegressiveMessenger, ], ) def test_exact_batch(Guide): @@ -1367,9 +1388,8 @@ def model(data): ) guide = Guide(model) - elbo = JitTrace_ELBO( - num_particles=100, vectorize_particles=True, ignore_jit_warnings=True - ) + Elbo = promote_elbo(Guide, JitTrace_ELBO) + elbo = Elbo(num_particles=100, vectorize_particles=True, ignore_jit_warnings=True) num_steps = 500 optim = ClippedAdam({"lr": 0.05, "lrd": 0.1 ** (1 / num_steps)}) svi = SVI(model, guide, optim, elbo) @@ -1402,6 +1422,7 @@ def model(data): AutoStructured, AutoGaussian, AutoGaussianFunsor, + AutoRegressiveMessenger, ], ) 
def test_exact_tree(Guide): @@ -1437,9 +1458,8 @@ def model(data): expected_loss = float(g.event_logsumexp() - g_cond.event_logsumexp()) guide = Guide(model) - elbo = JitTrace_ELBO( - num_particles=100, vectorize_particles=True, ignore_jit_warnings=True - ) + Elbo = promote_elbo(Guide, JitTrace_ELBO) + elbo = Elbo(num_particles=100, vectorize_particles=True, ignore_jit_warnings=True) num_steps = 500 optim = ClippedAdam({"lr": 0.05, "lrd": 0.1 ** (1 / num_steps)}) svi = SVI(model, guide, optim, elbo) From 6c29e1bc1461d254dbb616fa666ba1b26783dbae Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 28 Oct 2021 19:40:06 -0400 Subject: [PATCH 05/22] Fix tests --- pyro/infer/autoguide/effect.py | 8 +++++--- pyro/infer/effect_elbo.py | 18 ++++++++++-------- tests/infer/test_autoguide.py | 34 +++++++++++++++++++++++----------- 3 files changed, 38 insertions(+), 22 deletions(-) diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py index 737e182046..48480885d1 100644 --- a/pyro/infer/autoguide/effect.py +++ b/pyro/infer/autoguide/effect.py @@ -22,7 +22,8 @@ class AutoMessengerMeta(type(GuideMessenger), type(PyroModule)): class AutoMessenger(GuideMessenger, PyroModule, metaclass=AutoMessengerMeta): """ - Base class for :class:`pyro.infer.effect_elbo.GuideMessenger` autoguides. + EXPERIMENTAL Base class for :class:`pyro.infer.effect_elbo.GuideMessenger` + autoguides. """ # Drop args for backwards compatibility with AutoGuide. @@ -49,8 +50,9 @@ def _remove_outer_plates(self, value, event_dim): class AutoRegressiveMessenger(AutoMessenger): """ - Automatic :class:`~pyro.infer.effect_elbo.GuideMessenger` , intended for - use with :class:`~pyro.infer.effect_elbo.Effect_ELBO` or similar. + EXPERIMENTAL Automatic :class:`~pyro.infer.effect_elbo.GuideMessenger` , + intended for use with :class:`~pyro.infer.effect_elbo.Effect_ELBO` or + similar. The posterior at any site is a learned affine transform of the prior, conditioned on upstream posterior samples. The affine transform operates in diff --git a/pyro/infer/effect_elbo.py b/pyro/infer/effect_elbo.py index af309a4e96..ce2d307f10 100644 --- a/pyro/infer/effect_elbo.py +++ b/pyro/infer/effect_elbo.py @@ -21,8 +21,8 @@ class GuideMessenger(TraceMessenger, ABC): """ - Abstract base class for effect-based guides for use in :class:`Effect_ELBO` - and similar. + EXPERIMENTAL Abstract base class for effect-based guides for use in + :class:`Effect_ELBO` and similar. Derived classes must implement the :meth:`get_posterior` method. """ @@ -104,8 +104,8 @@ def get_traces(self) -> Tuple[Trace, Trace]: class EffectMixin(ELBO): """ - Mixin class to turn a trace-based ELBO implementation into an effect-based - implementation. + EXPERIMENTAL Mixin class to turn a trace-based ELBO implementation into an + effect-based implementation. """ def _get_trace(self, model, guide, args, kwargs): @@ -148,8 +148,9 @@ def _get_trace(self, model, guide, args, kwargs): class Effect_ELBO(EffectMixin, Trace_ELBO): """ - Similar to :class:`~pyro.infer.trace_elbo.Trace_ELBO` but supporting guides - that are :class:`GuideMessenger` s rather than traceable functions. + EXPERIMENTAL Similar to :class:`~pyro.infer.trace_elbo.Trace_ELBO` but + supporting guides that are :class:`GuideMessenger` s rather than traceable + functions. 
""" pass @@ -157,8 +158,9 @@ class Effect_ELBO(EffectMixin, Trace_ELBO): class JitEffect_ELBO(EffectMixin, JitTrace_ELBO): """ - Similar to :class:`~pyro.infer.trace_elbo.JitTrace_ELBO` but supporting guides - that are :class:`GuideMessenger` s rather than traceable functions. + EXPERIMENTAL Similar to :class:`~pyro.infer.trace_elbo.JitTrace_ELBO` but + supporting guides that are :class:`GuideMessenger` s rather than traceable + functions. """ pass diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index f522d1aae9..cfeea95f02 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -69,7 +69,7 @@ def promote_elbo(Guide, Elbo): if Elbo is Trace_ELBO: return Effect_ELBO if Elbo is JitTrace_ELBO: - return Effect_ELBO # DEBUG + return Effect_ELBO # DEBUG work around "Trying to backward a second time" return JitEffect_ELBO pytest.skip("not implemented") return Elbo @@ -1338,9 +1338,13 @@ def model(data): guide.requires_grad_(False) with torch.no_grad(): # Check moments. - vectorize = pyro.plate("particles", 10000, dim=-2) - guide_trace = poutine.trace(vectorize(guide)).get_trace(data) - samples = poutine.replay(vectorize(model), guide_trace)(data) + with pyro.plate("particles", 10000, dim=-2): + if isinstance(guide, poutine.messenger.Messenger): + with guide(data): + samples = model(data) + else: + guide_trace = poutine.trace(guide).get_trace(data) + samples = poutine.replay(model, guide_trace)(data) actual_mean = samples.mean().item() actual_std = samples.std().item() assert_close(actual_mean, expected_mean, atol=0.05) @@ -1399,9 +1403,13 @@ def model(data): guide.requires_grad_(False) with torch.no_grad(): # Check moments. - vectorize = pyro.plate("particles", 10000, dim=-2) - guide_trace = poutine.trace(vectorize(guide)).get_trace(data) - samples = poutine.replay(vectorize(model), guide_trace)(data) + with pyro.plate("particles", 10000, dim=-2): + if isinstance(guide, poutine.messenger.Messenger): + with guide(data): + samples = model(data) + else: + guide_trace = poutine.trace(guide).get_trace(data) + samples = poutine.replay(model, guide_trace)(data) actual_mean = samples.mean(0) actual_std = samples.std(0) assert_close(actual_mean, expected_mean, atol=0.05) @@ -1426,7 +1434,7 @@ def model(data): ], ) def test_exact_tree(Guide): - is_exact = Guide not in (AutoNormal, AutoDiagonalNormal) + is_exact = Guide not in (AutoNormal, AutoDiagonalNormal, AutoRegressiveMessenger) def model(data): x = pyro.sample("x", dist.Normal(0, 1)) @@ -1470,9 +1478,13 @@ def model(data): guide.requires_grad_(False) with torch.no_grad(): # Check moments. 
- vectorize = pyro.plate("particles", 10000, dim=-2) - guide_trace = poutine.trace(vectorize(guide)).get_trace(data) - samples = poutine.replay(vectorize(model), guide_trace)(data) + with pyro.plate("particles", 10000, dim=-2): + if isinstance(guide, poutine.messenger.Messenger): + with guide(data): + samples = model(data) + else: + guide_trace = poutine.trace(guide).get_trace(data) + samples = poutine.replay(model, guide_trace)(data) for name in ["x", "y"]: actual_mean = samples[name].mean(0).squeeze() actual_std = samples[name].std(0).squeeze() From 299340259284baa14697922869bb0a1d0e848ff0 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 29 Oct 2021 09:17:41 -0400 Subject: [PATCH 06/22] Add AutoNormalMessenger --- docs/source/infer.autoguide.rst | 8 ++++ pyro/infer/autoguide/__init__.py | 7 +++- pyro/infer/autoguide/effect.py | 70 ++++++++++++++++++++++++++++++++ tests/infer/test_autoguide.py | 6 ++- 4 files changed, 89 insertions(+), 2 deletions(-) diff --git a/docs/source/infer.autoguide.rst b/docs/source/infer.autoguide.rst index 1a48fcdfaa..f82915d595 100644 --- a/docs/source/infer.autoguide.rst +++ b/docs/source/infer.autoguide.rst @@ -125,6 +125,14 @@ AutoGaussian :member-order: bysource :show-inheritance: +AutoNormalMessenger +------------------- +.. autoclass:: pyro.infer.autoguide.AutoNormalMessenger + :members: + :undoc-members: + :member-order: bysource + :show-inheritance: + AutoRegressiveMessenger ----------------------- .. autoclass:: pyro.infer.autoguide.AutoRegressiveMessenger diff --git a/pyro/infer/autoguide/__init__.py b/pyro/infer/autoguide/__init__.py index ebd0997941..30f928943c 100644 --- a/pyro/infer/autoguide/__init__.py +++ b/pyro/infer/autoguide/__init__.py @@ -1,7 +1,11 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from pyro.infer.autoguide.effect import AutoMessenger, AutoRegressiveMessenger +from pyro.infer.autoguide.effect import ( + AutoMessenger, + AutoNormalMessenger, + AutoRegressiveMessenger, +) from pyro.infer.autoguide.gaussian import AutoGaussian from pyro.infer.autoguide.guides import ( AutoCallable, @@ -45,6 +49,7 @@ "AutoMessenger", "AutoMultivariateNormal", "AutoNormal", + "AutoNormalMessenger", "AutoNormalizingFlow", "AutoRegressiveMessenger", "AutoStructured", diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py index 48480885d1..f82b66f772 100644 --- a/pyro/infer/autoguide/effect.py +++ b/pyro/infer/autoguide/effect.py @@ -48,6 +48,76 @@ def _remove_outer_plates(self, value, event_dim): return value +class AutoNormalMessenger(AutoMessenger): + """ + EXPERIMENTAL Automatic :class:`~pyro.infer.effect_elbo.GuideMessenger` , + intended for use with :class:`~pyro.infer.effect_elbo.Effect_ELBO` or + similar. + + The mean-field posterior at any site is a transformed normal distribution. + + Derived classes may override particular sites and use this simply as a + default, e.g.:: + + class MyGuideMessenger(AutoNormalMessenger): + def get_posterior(self, name, prior, upstream_values): + if name == "x": + # Use a custom distribution at site x. + loc = pyro.param("x_loc", lambda: torch.zeros(prior.shape())) + scale = pyro.param("x_scale", lambda: torch.ones(prior.shape())) + return dist.Normal(loc, scale).to_event(prior.event_dim()) + # Fall back to autoregressive. 
+ return super().get_posterior(name, prior, upstream_values) + """ + + @pyro_method + def get_posterior( + self, + name: str, + prior: TorchDistribution, + upstream_values: Dict[str, torch.Tensor], + ) -> Union[TorchDistribution, torch.Tensor]: + with helpful_support_errors({"name": name, "fn": prior}): + transform = biject_to(prior.support) + loc, scale = self._get_params(name, prior) + posterior = dist.TransformedDistribution( + dist.Normal(loc, scale).to_event(transform.domain.event_dim), + transform.with_cache() + ).expand(prior.batch_shape) + return posterior + + def _get_params(self, name, prior): + try: + loc = deep_getattr(self.locs, name) + scale = deep_getattr(self.scales, name) + return loc, scale + except AttributeError: + pass + + # Initialize. + with poutine.block(), torch.no_grad(): + constrained = prior.sample().detach() + transform = biject_to(prior.support) + unconstrained = transform.inv(constrained) + event_dim = transform.domain.event_dim + prototype = self._remove_outer_plates(unconstrained, event_dim) + deep_setattr( + self, + "locs." + name, + PyroParam(torch.zeros_like(prototype), event_dim=event_dim), + ) + deep_setattr( + self, + "scales." + name, + PyroParam( + torch.ones_like(prototype), + constraint=constraints.positive, + event_dim=event_dim, + ), + ) + return self._get_params(name, prior) + + class AutoRegressiveMessenger(AutoMessenger): """ EXPERIMENTAL Automatic :class:`~pyro.infer.effect_elbo.GuideMessenger` , diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index cfeea95f02..5602dc57a3 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -40,6 +40,7 @@ AutoMessenger, AutoMultivariateNormal, AutoNormal, + AutoNormalMessenger, AutoRegressiveMessenger, AutoStructured, init_to_feasible, @@ -1298,6 +1299,7 @@ def __init__(self, model): AutoStructured_exact_mvn, AutoGaussian, AutoGaussianFunsor, + AutoNormalMessenger, AutoRegressiveMessenger, ], ) @@ -1366,6 +1368,7 @@ def model(data): AutoStructured_exact_mvn, AutoGaussian, AutoGaussianFunsor, + AutoNormalMessenger, AutoRegressiveMessenger, ], ) @@ -1430,11 +1433,12 @@ def model(data): AutoStructured, AutoGaussian, AutoGaussianFunsor, + AutoNormalMessenger, AutoRegressiveMessenger, ], ) def test_exact_tree(Guide): - is_exact = Guide not in (AutoNormal, AutoDiagonalNormal, AutoRegressiveMessenger) + is_exact = Guide not in (AutoNormal, AutoDiagonalNormal, AutoNormalMessenger, AutoRegressiveMessenger) def model(data): x = pyro.sample("x", dist.Normal(0, 1)) From 6eb02e4e8dc4576cbaa47f9575d2d1e5bfbadd52 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 29 Oct 2021 09:28:58 -0400 Subject: [PATCH 07/22] Add example to docstring --- pyro/infer/autoguide/effect.py | 49 +++++++++++++++++++++++++++------- pyro/infer/effect_elbo.py | 9 ++++--- tests/infer/test_autoguide.py | 7 ++++- 3 files changed, 52 insertions(+), 13 deletions(-) diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py index f82b66f772..82405bcf2d 100644 --- a/pyro/infer/autoguide/effect.py +++ b/pyro/infer/autoguide/effect.py @@ -59,14 +59,44 @@ class AutoNormalMessenger(AutoMessenger): Derived classes may override particular sites and use this simply as a default, e.g.:: + def model(data): + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.Normal(0, 1)) + c = pyro.sample("c", dist.Normal(a + b, 1)) + pyro.sample("obs", dist.Normal(c, 1), obs=data) + class MyGuideMessenger(AutoNormalMessenger): def get_posterior(self, name, prior, upstream_values): 
-                if name == "x":
-                    # Use a custom distribution at site x.
-                    loc = pyro.param("x_loc", lambda: torch.zeros(prior.shape()))
-                    scale = pyro.param("x_scale", lambda: torch.ones(prior.shape()))
-                    return dist.Normal(loc, scale).to_event(prior.event_dim())
-                # Fall back to autoregressive.
+                if name == "c":
+                    # Use a custom distribution at site c.
+                    bias = pyro.param("c_bias", lambda: torch.zeros(()))
+                    weight = pyro.param("c_weight", lambda: torch.ones(()),
+                                        constraint=constraints.positive)
+                    scale = pyro.param("c_scale", lambda: torch.ones(()),
+                                       constraint=constraints.positive)
+                    a = upstream_values["a"]
+                    b = upstream_values["b"]
+                    loc = bias + weight * (a + b)
+                    return dist.Normal(loc, scale)
+                # Fall back to mean field.
                 return super().get_posterior(name, prior, upstream_values)
+
+    Note that above we manually computed ``loc = bias + weight * (a + b)``.
+    Alternatively we could reuse the model-side computation by setting ``loc =
+    bias + weight * prior.loc``::
+
+        class MyGuideMessenger_v2(AutoNormalMessenger):
+            def get_posterior(self, name, prior, upstream_values):
+                if name == "c":
+                    # Use a custom distribution at site c.
+                    bias = pyro.param("c_bias", lambda: torch.zeros(()))
+                    scale = pyro.param("c_scale", lambda: torch.ones(()),
+                                       constraint=constraints.positive)
+                    weight = pyro.param("c_weight", lambda: torch.ones(()),
+                                        constraint=constraints.positive)
+                    loc = bias + weight * prior.loc
+                    return dist.Normal(loc, scale)
+                # Fall back to mean field.
+                return super().get_posterior(name, prior, upstream_values)
     """
 
     @pyro_method
@@ -82,8 +112,8 @@ def get_posterior(
             loc, scale = self._get_params(name, prior)
             posterior = dist.TransformedDistribution(
                 dist.Normal(loc, scale).to_event(transform.domain.event_dim),
-                transform.with_cache()
-            ).expand(prior.batch_shape)
+                transform.with_cache(),
+            )
             return posterior
@@ -136,7 +166,8 @@ def get_posterior(self, name, prior, upstream_values):
                 if name == "x":
                     # Use a custom distribution at site x.
                     loc = pyro.param("x_loc", lambda: torch.zeros(prior.shape()))
-                    scale = pyro.param("x_scale", lambda: torch.ones(prior.shape()))
+                    scale = pyro.param("x_scale", lambda: torch.ones(prior.shape()),
+                                       constraint=constraints.positive)
                 return dist.Normal(loc, scale).to_event(prior.event_dim())
                 # Fall back to autoregressive.
return super().get_posterior(name, prior, upstream_values) diff --git a/pyro/infer/effect_elbo.py b/pyro/infer/effect_elbo.py index ce2d307f10..04375863ed 100644 --- a/pyro/infer/effect_elbo.py +++ b/pyro/infer/effect_elbo.py @@ -40,10 +40,13 @@ def __exit__(self, exc_type, exc_value, traceback): def _pyro_sample(self, msg): if msg["is_observed"] or site_is_subsample(msg): return - msg["infer"]["prior"] = msg["fn"] - posterior = self.get_posterior(msg["name"], msg["fn"], self.upstream_values) + prior = msg["fn"] + msg["infer"]["prior"] = prior + posterior = self.get_posterior(msg["name"], prior, self.upstream_values) if isinstance(posterior, torch.Tensor): - posterior = dist.Delta(posterior, event_dim=msg["fn"].event_dim) + posterior = dist.Delta(posterior, event_dim=prior.event_dim) + if posterior.batch_shape != prior.batch_shape: + posterior = posterior.expand(prior.batch_shape) msg["fn"] = posterior def _pyro_post_sample(self, msg): diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 5602dc57a3..1190f7cba8 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -1438,7 +1438,12 @@ def model(data): ], ) def test_exact_tree(Guide): - is_exact = Guide not in (AutoNormal, AutoDiagonalNormal, AutoNormalMessenger, AutoRegressiveMessenger) + is_exact = Guide not in ( + AutoNormal, + AutoDiagonalNormal, + AutoNormalMessenger, + AutoRegressiveMessenger, + ) def model(data): x = pyro.sample("x", dist.Normal(0, 1)) From 84d5e60b37ba0edd76a57d00b1347bc7c623982c Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 29 Oct 2021 21:20:58 -0400 Subject: [PATCH 08/22] Support calling guide(*args, **kwargs) --- pyro/infer/autoguide/effect.py | 7 +++-- pyro/infer/effect_elbo.py | 52 ++++++++++++++++++++-------------- pyro/infer/elbo.py | 15 +++------- tests/infer/test_autoguide.py | 24 ++++------------ 4 files changed, 46 insertions(+), 52 deletions(-) diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py index 82405bcf2d..76e7f4147a 100644 --- a/pyro/infer/autoguide/effect.py +++ b/pyro/infer/autoguide/effect.py @@ -28,11 +28,14 @@ class AutoMessenger(GuideMessenger, PyroModule, metaclass=AutoMessengerMeta): # Drop args for backwards compatibility with AutoGuide. def __init__(self, model, *, init_loc_fn=None): - super().__init__() + super().__init__(model) def __call__(self, *args, **kwargs): self._outer_plates = get_plates() - return super().__call__(*args, **kwargs) + try: + return super().__call__(*args, **kwargs) + finally: + del self._outer_plates def _remove_outer_plates(self, value, event_dim): """ diff --git a/pyro/infer/effect_elbo.py b/pyro/infer/effect_elbo.py index 04375863ed..f3821673c2 100644 --- a/pyro/infer/effect_elbo.py +++ b/pyro/infer/effect_elbo.py @@ -27,15 +27,32 @@ class GuideMessenger(TraceMessenger, ABC): Derived classes must implement the :meth:`get_posterior` method. 
""" - def __call__(self, *args, **kwargs): + def __init__(self, model): + super().__init__() + # Do not register model as submodule + self._model = (model,) + + @property + def model(self): + return self._model[0] + + def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]: self.args_kwargs = args, kwargs self.upstream_values = OrderedDict() - return self - - def __exit__(self, exc_type, exc_value, traceback): - del self.args_kwargs - del self.upstream_values - return super().__exit__(exc_type, exc_value, traceback) + try: + with self: + self.model(*args, **kwargs) + finally: + del self.args_kwargs + del self.upstream_values + + model_trace, guide_trace = self.get_traces() + samples = { + name: site["value"] + for name, site in model_trace.nodes.items() + if site["type"] == "sample" + } + return samples def _pyro_sample(self, msg): if msg["is_observed"] or site_is_subsample(msg): @@ -51,6 +68,12 @@ def _pyro_sample(self, msg): def _pyro_post_sample(self, msg): self.upstream_values[msg["name"]] = msg["value"] + + # Manually apply outer plates. + prior = msg["infer"].get("prior") + if prior is not None and prior.batch_shape != msg["fn"].batch_shape: + msg["infer"]["prior"] = prior.expand(msg["fn"].batch_shape) + return super()._pyro_post_sample(msg) @abstractmethod @@ -114,28 +137,15 @@ class EffectMixin(ELBO): def _get_trace(self, model, guide, args, kwargs): # This differs from Trace_ELBO in that the guide is assumed to be an # effect handler. - with guide(*args, **kwargs): - model(*args, **kwargs) + guide(*args, **kwargs) while not isinstance(guide, GuideMessenger): guide = guide.func.args[1] # unwrap plates model_trace, guide_trace = guide.get_traces() - if self.max_plate_nesting == -float("inf"): - self.max_plate_nesting = max( - [0] - + [ - -f.dim - for site in guide.trace.nodes.values() - for f in site["cond_indep_stack"] - if f.vectorized - ] - ) # The rest follows pyro.infer.enum.get_importance_trace(). max_plate_nesting = self.max_plate_nesting if is_validation_enabled(): check_model_guide_match(model_trace, guide_trace, max_plate_nesting) - guide_trace = prune_subsample_sites(guide_trace) - model_trace = prune_subsample_sites(model_trace) model_trace.compute_log_prob() guide_trace.compute_score_parts() if is_validation_enabled(): diff --git a/pyro/infer/elbo.py b/pyro/infer/elbo.py index c65b672531..3abe07d748 100644 --- a/pyro/infer/elbo.py +++ b/pyro/infer/elbo.py @@ -94,17 +94,10 @@ def _guess_max_plate_nesting(self, model, guide, args, kwargs): """ # Ignore validation to allow model-enumerated sites absent from the guide. with poutine.block(): - if isinstance(guide, poutine.messenger.Messenger): - # Subclasses of GuideMessenger. - with guide(*args, **kwargs): - model(*args, **kwargs) - model_trace, guide_trace = guide.get_traces() - else: - # Traditional callable guides. - guide_trace = poutine.trace(guide).get_trace(*args, **kwargs) - model_trace = poutine.trace( - poutine.replay(model, trace=guide_trace) - ).get_trace(*args, **kwargs) + guide_trace = poutine.trace(guide).get_trace(*args, **kwargs) + model_trace = poutine.trace( + poutine.replay(model, trace=guide_trace) + ).get_trace(*args, **kwargs) guide_trace = prune_subsample_sites(guide_trace) model_trace = prune_subsample_sites(model_trace) sites = [ diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 1190f7cba8..1e9b404eab 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -1341,12 +1341,8 @@ def model(data): with torch.no_grad(): # Check moments. 
with pyro.plate("particles", 10000, dim=-2): - if isinstance(guide, poutine.messenger.Messenger): - with guide(data): - samples = model(data) - else: - guide_trace = poutine.trace(guide).get_trace(data) - samples = poutine.replay(model, guide_trace)(data) + guide_trace = poutine.trace(guide).get_trace(data) + samples = poutine.replay(model, guide_trace)(data) actual_mean = samples.mean().item() actual_std = samples.std().item() assert_close(actual_mean, expected_mean, atol=0.05) @@ -1407,12 +1403,8 @@ def model(data): with torch.no_grad(): # Check moments. with pyro.plate("particles", 10000, dim=-2): - if isinstance(guide, poutine.messenger.Messenger): - with guide(data): - samples = model(data) - else: - guide_trace = poutine.trace(guide).get_trace(data) - samples = poutine.replay(model, guide_trace)(data) + guide_trace = poutine.trace(guide).get_trace(data) + samples = poutine.replay(model, guide_trace)(data) actual_mean = samples.mean(0) actual_std = samples.std(0) assert_close(actual_mean, expected_mean, atol=0.05) @@ -1488,12 +1480,8 @@ def model(data): with torch.no_grad(): # Check moments. with pyro.plate("particles", 10000, dim=-2): - if isinstance(guide, poutine.messenger.Messenger): - with guide(data): - samples = model(data) - else: - guide_trace = poutine.trace(guide).get_trace(data) - samples = poutine.replay(model, guide_trace)(data) + guide_trace = poutine.trace(guide).get_trace(data) + samples = poutine.replay(model, guide_trace)(data) for name in ["x", "y"]: actual_mean = samples[name].mean(0).squeeze() actual_std = samples[name].std(0).squeeze() From 3d5250e133c2abe4cc4e90b88c6f11baa6414e7e Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 30 Oct 2021 10:49:34 -0400 Subject: [PATCH 09/22] Support init_loc_fn, init_scale --- pyro/infer/autoguide/effect.py | 110 +++++++++++++++++++++------------ pyro/infer/effect_elbo.py | 14 ++--- 2 files changed, 76 insertions(+), 48 deletions(-) diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py index 76e7f4147a..ff7ab73d0f 100644 --- a/pyro/infer/autoguide/effect.py +++ b/pyro/infer/autoguide/effect.py @@ -1,18 +1,19 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, Union +from typing import Callable, Dict, Union import torch from torch.distributions import biject_to, constraints import pyro.distributions as dist import pyro.poutine as poutine -from pyro.distributions.torch_distribution import TorchDistribution +from pyro.distributions.distribution import Distribution from pyro.infer.effect_elbo import GuideMessenger from pyro.nn.module import PyroModule, PyroParam, pyro_method from pyro.poutine.runtime import get_plates +from .initialization import init_to_feasible, init_to_mean from .utils import deep_getattr, deep_setattr, helpful_support_errors @@ -26,10 +27,6 @@ class AutoMessenger(GuideMessenger, PyroModule, metaclass=AutoMessengerMeta): autoguides. """ - # Drop args for backwards compatibility with AutoGuide. - def __init__(self, model, *, init_loc_fn=None): - super().__init__(model) - def __call__(self, *args, **kwargs): self._outer_plates = get_plates() try: @@ -37,7 +34,7 @@ def __call__(self, *args, **kwargs): finally: del self._outer_plates - def _remove_outer_plates(self, value, event_dim): + def _remove_outer_plates(self, value: torch.Tensor, event_dim: int) -> torch.Tensor: """ Removes particle plates from initial values of parameters. 
""" @@ -101,15 +98,34 @@ def get_posterior(self, name, prior, upstream_values): return dist.Normal(loc, scale) # Fall back to mean field. return super().get_posterior(name, prior, upstream_values) + + :param callable model: A Pyro model. + :param callable init_loc_fn: A per-site initialization function. + See :ref:`autoguide-initialization` section for available functions. + :param float init_scale: Initial scale for the standard deviation of each + (unconstrained transformed) latent variable. """ + def __init__( + self, + model: Callable, + *, + init_loc_fn: Callable = init_to_mean(fallback=init_to_feasible), + init_scale: float = 0.1, + ): + if not isinstance(init_scale, float) or not (init_scale > 0): + raise ValueError("Expected init_scale > 0. but got {}".format(init_scale)) + super().__init__(model) + self.init_loc_fn = init_loc_fn + self._init_scale = init_scale + @pyro_method def get_posterior( self, name: str, - prior: TorchDistribution, + prior: Distribution, upstream_values: Dict[str, torch.Tensor], - ) -> Union[TorchDistribution, torch.Tensor]: + ) -> Union[Distribution, torch.Tensor]: with helpful_support_errors({"name": name, "fn": prior}): transform = biject_to(prior.support) loc, scale = self._get_params(name, prior) @@ -119,7 +135,7 @@ def get_posterior( ) return posterior - def _get_params(self, name, prior): + def _get_params(self, name: str, prior: Distribution): try: loc = deep_getattr(self.locs, name) scale = deep_getattr(self.scales, name) @@ -129,24 +145,19 @@ def _get_params(self, name, prior): # Initialize. with poutine.block(), torch.no_grad(): - constrained = prior.sample().detach() - transform = biject_to(prior.support) - unconstrained = transform.inv(constrained) + with helpful_support_errors({"name": name, "fn": prior}): + transform = biject_to(prior.support) event_dim = transform.domain.event_dim - prototype = self._remove_outer_plates(unconstrained, event_dim) - deep_setattr( - self, - "locs." + name, - PyroParam(torch.zeros_like(prototype), event_dim=event_dim), - ) + constrained = self.init_loc_fn({"name": name, "fn": prior}).detach() + unconstrained = transform.inv(constrained) + init_loc = self._remove_outer_plates(unconstrained, event_dim) + init_scale = torch.full_like(init_loc, self._init_scale) + + deep_setattr(self, "locs." + name, PyroParam(init_loc, event_dim=event_dim)) deep_setattr( self, "scales." + name, - PyroParam( - torch.ones_like(prototype), - constraint=constraints.positive, - event_dim=event_dim, - ), + PyroParam(init_scale, constraint=constraints.positive, event_dim=event_dim), ) return self._get_params(name, prior) @@ -174,15 +185,34 @@ def get_posterior(self, name, prior, upstream_values): return dist.Normal(loc, scale).to_event(prior.event_dim()) # Fall back to autoregressive. return super().get_posterior(name, prior, upstream_values) + + :param callable model: A Pyro model. + :param callable init_loc_fn: A per-site initialization function. + See :ref:`autoguide-initialization` section for available functions. + :param float init_scale: Initial scale for the standard deviation of each + (unconstrained transformed) latent variable. """ + def __init__( + self, + model: Callable, + *, + init_loc_fn: Callable = init_to_mean(fallback=init_to_feasible), + init_scale: float = 0.1, + ): + if not isinstance(init_scale, float) or not (init_scale > 0): + raise ValueError("Expected init_scale > 0. 
but got {}".format(init_scale)) + super().__init__(model) + self.init_loc_fn = init_loc_fn + self._init_scale = init_scale + @pyro_method def get_posterior( self, name: str, - prior: TorchDistribution, + prior: Distribution, upstream_values: Dict[str, torch.Tensor], - ) -> Union[TorchDistribution, torch.Tensor]: + ) -> Union[Distribution, torch.Tensor]: with helpful_support_errors({"name": name, "fn": prior}): transform = biject_to(prior.support) loc, scale = self._get_params(name, prior) @@ -194,7 +224,7 @@ def get_posterior( ) return posterior - def _get_params(self, name, prior): + def _get_params(self, name: str, prior: Distribution): try: loc = deep_getattr(self.locs, name) scale = deep_getattr(self.scales, name) @@ -204,23 +234,21 @@ def _get_params(self, name, prior): # Initialize. with poutine.block(), torch.no_grad(): - constrained = prior.sample().detach() - transform = biject_to(prior.support) - unconstrained = transform.inv(constrained) + with helpful_support_errors({"name": name, "fn": prior}): + transform = biject_to(prior.support) event_dim = transform.domain.event_dim - prototype = self._remove_outer_plates(unconstrained, event_dim) - deep_setattr( - self, - "locs." + name, - PyroParam(torch.zeros_like(prototype), event_dim=event_dim), - ) + constrained = self.init_loc_fn({"name": name, "fn": prior}).detach() + unconstrained = transform.inv(constrained) + # Initialize the distribution to be an affine combination: + # init_scale * prior + (1 - init_scale) * init_loc + init_loc = self._remove_outer_plates(unconstrained, event_dim) + init_loc = init_loc * (1 - self._init_scale) + init_scale = torch.full_like(init_loc, self._init_scale) + + deep_setattr(self, "locs." + name, PyroParam(init_loc, event_dim=event_dim)) deep_setattr( self, "scales." + name, - PyroParam( - torch.ones_like(prototype), - constraint=constraints.positive, - event_dim=event_dim, - ), + PyroParam(init_scale, constraint=constraints.positive, event_dim=event_dim), ) return self._get_params(name, prior) diff --git a/pyro/infer/effect_elbo.py b/pyro/infer/effect_elbo.py index f3821673c2..2ad827321d 100644 --- a/pyro/infer/effect_elbo.py +++ b/pyro/infer/effect_elbo.py @@ -3,12 +3,12 @@ from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Dict, Tuple, Union +from typing import Callable, Dict, Tuple, Union import torch import pyro.distributions as dist -from pyro.distributions.torch_distribution import TorchDistribution +from pyro.distributions.distribution import Distribution from pyro.infer.elbo import ELBO from pyro.infer.util import is_validation_enabled from pyro.poutine.trace_messenger import TraceMessenger @@ -27,7 +27,7 @@ class GuideMessenger(TraceMessenger, ABC): Derived classes must implement the :meth:`get_posterior` method. """ - def __init__(self, model): + def __init__(self, model: Callable): super().__init__() # Do not register model as submodule self._model = (model,) @@ -80,9 +80,9 @@ def _pyro_post_sample(self, msg): def get_posterior( self, name: str, - prior: TorchDistribution, + prior: Distribution, upstream_values: Dict[str, torch.Tensor], - ) -> Union[TorchDistribution, torch.Tensor]: + ) -> Union[Distribution, torch.Tensor]: """ Abstract method to compute a posterior distribution or sample a posterior value given a prior distribution and values of upstream @@ -103,11 +103,11 @@ def get_posterior( :param str name: The name of the sample site to sample. 
:param prior: The prior distribution of this sample site (conditioned on upstream samples from the posterior). - :type prior: ~pyro.distributions.TorchDistribution + :type prior: ~pyro.distributions.Distribution :param dict upstream_values: :returns: A posterior distribution or sample from the posterior distribution. - :rtype: ~pyro.distributions.TorchDistribution or torch.Tensor + :rtype: ~pyro.distributions.Distribution or torch.Tensor """ raise NotImplementedError From 0ac9ae103e9d03bd7b201b6a832a5ddf8b1c82e3 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 30 Oct 2021 11:40:39 -0400 Subject: [PATCH 10/22] Add more tests --- pyro/infer/autoguide/effect.py | 41 +++++++++++++++++++++++++----- tests/infer/test_autoguide.py | 46 ++++++++++++++++++++++++++-------- 2 files changed, 70 insertions(+), 17 deletions(-) diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py index ff7ab73d0f..0ad6668611 100644 --- a/pyro/infer/autoguide/effect.py +++ b/pyro/infer/autoguide/effect.py @@ -34,15 +34,30 @@ def __call__(self, *args, **kwargs): finally: del self._outer_plates + def call(self, *args, **kwargs): + """ + Method that calls :meth:`forward` and returns parameter values of the + guide as a `tuple` instead of a `dict`, which is a requirement for + JIT tracing. Unlike :meth:`forward`, this method can be traced by + :func:`torch.jit.trace_module`. + + .. warning:: + This method may be removed once PyTorch JIT tracer starts accepting + `dict` as valid return types. See + `issue _`. + """ + result = self(*args, **kwargs) + return tuple(v for _, v in sorted(result.items())) + def _remove_outer_plates(self, value: torch.Tensor, event_dim: int) -> torch.Tensor: """ Removes particle plates from initial values of parameters. """ for f in self._outer_plates: - dim = -f.dim - event_dim + dim = f.dim - event_dim if -value.dim() <= dim: dim = dim + value.dim() - value = value[(slice(None),) * dim + slice(1)] + value = value[(slice(None),) * dim + (slice(1),)] for dim in range(value.dim() - event_dim): value = value.squeeze(0) return value @@ -118,6 +133,7 @@ def __init__( super().__init__(model) self.init_loc_fn = init_loc_fn self._init_scale = init_scale + self._computing_median = False @pyro_method def get_posterior( @@ -126,6 +142,9 @@ def get_posterior( prior: Distribution, upstream_values: Dict[str, torch.Tensor], ) -> Union[Distribution, torch.Tensor]: + if self._computing_median: + return self._get_posterior_median(name, prior) + with helpful_support_errors({"name": name, "fn": prior}): transform = biject_to(prior.support) loc, scale = self._get_params(name, prior) @@ -145,8 +164,7 @@ def _get_params(self, name: str, prior: Distribution): # Initialize. 
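        # Parameters are created lazily on the first call at each site, with
        # shapes read off of the prior; poutine.block() hides any Pyro
        # operations issued during initialization from enclosing handlers,
        # and torch.no_grad() keeps initial values out of the autograd graph.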
with poutine.block(), torch.no_grad(): - with helpful_support_errors({"name": name, "fn": prior}): - transform = biject_to(prior.support) + transform = biject_to(prior.support) event_dim = transform.domain.event_dim constrained = self.init_loc_fn({"name": name, "fn": prior}).detach() unconstrained = transform.inv(constrained) @@ -161,6 +179,18 @@ def _get_params(self, name: str, prior: Distribution): ) return self._get_params(name, prior) + def median(self, *args, **kwargs): + self._computing_median = True + try: + return self(*args, **kwargs) + finally: + self._computing_median = False + + def _get_posterior_median(self, name, prior): + transform = biject_to(prior.support) + loc, scale = self._get_params(name, prior) + return transform(loc) + class AutoRegressiveMessenger(AutoMessenger): """ @@ -234,8 +264,7 @@ def _get_params(self, name: str, prior: Distribution): # Initialize. with poutine.block(), torch.no_grad(): - with helpful_support_errors({"name": name, "fn": prior}): - transform = biject_to(prior.support) + transform = biject_to(prior.support) event_dim = transform.domain.event_dim constrained = self.init_loc_fn({"name": name, "fn": prior}).detach() unconstrained = transform.inv(constrained) diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 1e9b404eab..c07be18d08 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -55,7 +55,12 @@ from pyro.optim import Adam, ClippedAdam from pyro.poutine.util import prune_subsample_sites from pyro.util import check_model_guide_match -from tests.common import assert_close, assert_equal, xfail_if_not_implemented +from tests.common import ( + assert_close, + assert_equal, + xfail_if_not_implemented, + xfail_param, +) AutoGaussianFunsor = pytest.param( AutoGaussianFunsor, marks=[pytest.mark.stage("funsor")] @@ -66,7 +71,7 @@ def promote_elbo(Guide, Elbo): """ Promote e.g. Trace_ELBO --> Effect_ELBO for AutoMessengers. 
""" - if issubclass(Guide, AutoMessenger): + if isinstance(Guide, type) and issubclass(Guide, AutoMessenger): if Elbo is Trace_ELBO: return Effect_ELBO if Elbo is JitTrace_ELBO: @@ -123,6 +128,8 @@ def model(): AutoLaplaceApproximation, AutoGaussian, AutoGaussianFunsor, + AutoNormalMessenger, + AutoRegressiveMessenger, ], ) def test_factor(auto_class, Elbo): @@ -135,6 +142,7 @@ def model(log_factor): pyro.sample("z3", dist.Normal(torch.zeros(3), torch.ones(3))) guide = auto_class(model) + Elbo = promote_elbo(auto_class, Elbo) elbo = Elbo(strict_enumeration_warning=False) elbo.loss(model, guide, torch.tensor(0.0)) # initialize param store @@ -377,6 +385,7 @@ def __init__(self, model): AutoStructured_median, AutoGaussian, AutoGaussianFunsor, + AutoNormalMessenger, ], ) @pytest.mark.parametrize("Elbo", [JitTrace_ELBO, JitTraceGraph_ELBO, JitTraceEnum_ELBO]) @@ -388,6 +397,7 @@ def model(): guide = auto_class(model) optim = Adam({"lr": 0.02, "betas": (0.8, 0.99)}) + Elbo = promote_elbo(auto_class, Elbo) elbo = Elbo( strict_enumeration_warning=False, num_particles=500, @@ -411,6 +421,13 @@ def model(): assert_equal(median["z"], torch.tensor(0.5), prec=0.1) +def serialization_model(): + pyro.sample("x", dist.Normal(0.0, 1.0)) + with pyro.plate("plate", 2): + pyro.sample("y", dist.LogNormal(0.0, 1.0)) + pyro.sample("z", dist.Beta(2.0, 2.0)) + + @pytest.mark.parametrize("jit", [False, True], ids=["nojit", "jit"]) @pytest.mark.parametrize( "auto_class", @@ -440,17 +457,12 @@ def model(): ), ], ), + AutoNormalMessenger, + xfail_param(AutoRegressiveMessenger, reason="jit does not support _Dirichlet"), ], ) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) -def test_serialization(auto_class, Elbo, jit): - def model(): - pyro.sample("x", dist.Normal(0.0, 1.0)) - with pyro.plate("plate", 2): - pyro.sample("y", dist.LogNormal(0.0, 1.0)) - pyro.sample("z", dist.Beta(2.0, 2.0)) - - guide = auto_class(model) +def test_serialization(auto_class, jit): + guide = auto_class(serialization_model) guide() if auto_class is AutoLaplaceApproximation: guide = guide.laplace_approximation() @@ -712,6 +724,7 @@ def model(): AutoLowRankMultivariateNormal, AutoGaussian, AutoGaussianFunsor, + AutoNormalMessenger, ], ) def test_init_loc_fn(auto_class): @@ -776,6 +789,7 @@ def model(): functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), functools.partial(AutoNormal, init_loc_fn=init_to_median), functools.partial(AutoGaussian, init_loc_fn=init_to_median), + AutoNormalMessenger, ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) @@ -793,6 +807,7 @@ def forward(self): model = Model() guide = auto_class(model) + Elbo = promote_elbo(auto_class, Elbo) infer = SVI( model, guide, Adam({"lr": 0.005}), Elbo(strict_enumeration_warning=False) ) @@ -864,6 +879,8 @@ def forward(self): functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), AutoGaussian, AutoGaussianFunsor, + AutoNormalMessenger, + AutoRegressiveMessenger, ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) @@ -894,6 +911,7 @@ def forward(self, x, y=None): x, y = torch.randn(N, D), torch.randn(N) model = LinearRegression() guide = auto_class(model) + Elbo = promote_elbo(auto_class, Elbo) infer = SVI( model, guide, Adam({"lr": 0.005}), Elbo(strict_enumeration_warning=False) ) @@ -1005,6 +1023,8 @@ def forward(self, x, y=None): AutoStructured, AutoGaussian, AutoGaussianFunsor, + AutoNormalMessenger, + AutoRegressiveMessenger, ], ) def 
test_replay_plates(auto_class, sample_shape): @@ -1191,6 +1211,8 @@ def model(): AutoLaplaceApproximation, AutoGaussian, AutoGaussianFunsor, + AutoNormalMessenger, + AutoRegressiveMessenger, ], ) @pytest.mark.parametrize( @@ -1224,6 +1246,8 @@ def model(): AutoLaplaceApproximation, AutoGaussian, AutoGaussianFunsor, + AutoNormalMessenger, + AutoRegressiveMessenger, ], ) @pytest.mark.parametrize( From 0637e0024ffa18fda56f0dd792224831e81dea2c Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 30 Oct 2021 12:07:55 -0400 Subject: [PATCH 11/22] Fix jit tests --- docs/source/inference_algos.rst | 2 ++ pyro/infer/autoguide/effect.py | 9 +++++++-- pyro/infer/effect_elbo.py | 15 +++++++++++++++ tests/infer/test_autoguide.py | 23 +++++++++++++++++------ 4 files changed, 41 insertions(+), 8 deletions(-) diff --git a/docs/source/inference_algos.rst b/docs/source/inference_algos.rst index af7aa65a99..746e111a5d 100644 --- a/docs/source/inference_algos.rst +++ b/docs/source/inference_algos.rst @@ -59,6 +59,8 @@ ELBO .. automodule:: pyro.infer.effect_elbo :members: + :undoc-members: + :special-members: __call__ :show-inheritance: :member-order: bysource diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py index 0ad6668611..0cbda2e336 100644 --- a/pyro/infer/autoguide/effect.py +++ b/pyro/infer/autoguide/effect.py @@ -27,7 +27,12 @@ class AutoMessenger(GuideMessenger, PyroModule, metaclass=AutoMessengerMeta): autoguides. """ + @pyro_method def __call__(self, *args, **kwargs): + # Since this guide creates parameters lazily, we need to avoid batching + # those parameters by a particle plate, in case the first time this + # guide is called is inside a particle plate. We assume all plates + # outside the model are particle plates. self._outer_plates = get_plates() try: return super().__call__(*args, **kwargs) @@ -135,7 +140,6 @@ def __init__( self._init_scale = init_scale self._computing_median = False - @pyro_method def get_posterior( self, name: str, @@ -216,6 +220,8 @@ def get_posterior(self, name, prior, upstream_values): # Fall back to autoregressive. return super().get_posterior(name, prior, upstream_values) + .. warning:: This guide currently does not support ``JitEffect_ELBO``. + :param callable model: A Pyro model. :param callable init_loc_fn: A per-site initialization function. See :ref:`autoguide-initialization` section for available functions. @@ -236,7 +242,6 @@ def __init__( self.init_loc_fn = init_loc_fn self._init_scale = init_scale - @pyro_method def get_posterior( self, name: str, diff --git a/pyro/infer/effect_elbo.py b/pyro/infer/effect_elbo.py index 2ad827321d..f2b908f491 100644 --- a/pyro/infer/effect_elbo.py +++ b/pyro/infer/effect_elbo.py @@ -37,6 +37,14 @@ def model(self): return self._model[0] def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]: + """ + Draws posterior samples from the guide and replays the model against + those samples. + + :returns: A dict mapping sample site name to sample value. + This includes latent, deterministic, and observed values. + :rtype: dict + """ self.args_kwargs = args, kwargs self.upstream_values = OrderedDict() try: @@ -113,6 +121,13 @@ def get_posterior( def get_traces(self) -> Tuple[Trace, Trace]: """ + This can be called after running :meth:`__call__` . In contrast to the + trace-replay pattern of generating a pair of traces, + :class:`GuideMessenger` interleaves model and guide computations, so + only a single ``guide(*args, **kwargs)`` call is needed to create both + traces. 
This function merely extract the relevant information from this + guide's ``.trace`` attribute. + :returns: a pair ``(model_trace, guide_trace)`` :rtype: tuple """ diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index c07be18d08..c1e85eff22 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -67,7 +67,7 @@ ) -def promote_elbo(Guide, Elbo): +def promote_elbo(Guide, Elbo, jit=True): """ Promote e.g. Trace_ELBO --> Effect_ELBO for AutoMessengers. """ @@ -75,8 +75,7 @@ def promote_elbo(Guide, Elbo): if Elbo is Trace_ELBO: return Effect_ELBO if Elbo is JitTrace_ELBO: - return Effect_ELBO # DEBUG work around "Trying to backward a second time" - return JitEffect_ELBO + return JitEffect_ELBO if jit else Effect_ELBO pytest.skip("not implemented") return Elbo @@ -1353,7 +1352,11 @@ def model(data): expected_loss = float(g.event_logsumexp() - g.condition(data).event_logsumexp()) guide = Guide(model) - Elbo = promote_elbo(Guide, JitTrace_ELBO) + Elbo = promote_elbo( + Guide, + JitTrace_ELBO, + jit=(Guide is not AutoRegressiveMessenger), # currently fails with jit + ) elbo = Elbo(num_particles=100, vectorize_particles=True, ignore_jit_warnings=True) num_steps = 500 optim = ClippedAdam({"lr": 0.05, "lrd": 0.1 ** (1 / num_steps)}) @@ -1415,7 +1418,11 @@ def model(data): ) guide = Guide(model) - Elbo = promote_elbo(Guide, JitTrace_ELBO) + Elbo = promote_elbo( + Guide, + JitTrace_ELBO, + jit=(Guide is not AutoRegressiveMessenger), # currently fails with jit + ) elbo = Elbo(num_particles=100, vectorize_particles=True, ignore_jit_warnings=True) num_steps = 500 optim = ClippedAdam({"lr": 0.05, "lrd": 0.1 ** (1 / num_steps)}) @@ -1491,7 +1498,11 @@ def model(data): expected_loss = float(g.event_logsumexp() - g_cond.event_logsumexp()) guide = Guide(model) - Elbo = promote_elbo(Guide, JitTrace_ELBO) + Elbo = promote_elbo( + Guide, + JitTrace_ELBO, + jit=(Guide is not AutoRegressiveMessenger), # currently fails with jit + ) elbo = Elbo(num_particles=100, vectorize_particles=True, ignore_jit_warnings=True) num_steps = 500 optim = ClippedAdam({"lr": 0.05, "lrd": 0.1 ** (1 / num_steps)}) From bb092a984ff68d9506f8c56de71459fe26badb25 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 30 Oct 2021 12:32:30 -0400 Subject: [PATCH 12/22] Add more docs --- docs/source/infer.autoguide.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/source/infer.autoguide.rst b/docs/source/infer.autoguide.rst index f82915d595..0b85e65179 100644 --- a/docs/source/infer.autoguide.rst +++ b/docs/source/infer.autoguide.rst @@ -125,6 +125,14 @@ AutoGaussian :member-order: bysource :show-inheritance: +AutoMessenger +------------- +.. autoclass:: pyro.infer.autoguide.AutoMessenger + :members: + :undoc-members: + :member-order: bysource + :show-inheritance: + AutoNormalMessenger ------------------- .. autoclass:: pyro.infer.autoguide.AutoNormalMessenger From bc889e45b3dd4a1a75c30dfa445b3ac941c7bd63 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 30 Oct 2021 12:35:16 -0400 Subject: [PATCH 13/22] Revert unnecessary change --- tests/infer/test_autoguide.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index c1e85eff22..a7e4ca36c5 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -1367,9 +1367,9 @@ def model(data): guide.requires_grad_(False) with torch.no_grad(): # Check moments. 
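        # That is, draw 10000 vectorized posterior samples and compare
        # their empirical mean/std against the expected values above.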
- with pyro.plate("particles", 10000, dim=-2): - guide_trace = poutine.trace(guide).get_trace(data) - samples = poutine.replay(model, guide_trace)(data) + vectorize = pyro.plate("particles", 10000, dim=-2) + guide_trace = poutine.trace(vectorize(guide)).get_trace(data) + samples = poutine.replay(vectorize(model), guide_trace)(data) actual_mean = samples.mean().item() actual_std = samples.std().item() assert_close(actual_mean, expected_mean, atol=0.05) @@ -1433,9 +1433,9 @@ def model(data): guide.requires_grad_(False) with torch.no_grad(): # Check moments. - with pyro.plate("particles", 10000, dim=-2): - guide_trace = poutine.trace(guide).get_trace(data) - samples = poutine.replay(model, guide_trace)(data) + vectorize = pyro.plate("particles", 10000, dim=-2) + guide_trace = poutine.trace(vectorize(guide)).get_trace(data) + samples = poutine.replay(vectorize(model), guide_trace)(data) actual_mean = samples.mean(0) actual_std = samples.std(0) assert_close(actual_mean, expected_mean, atol=0.05) @@ -1514,9 +1514,9 @@ def model(data): guide.requires_grad_(False) with torch.no_grad(): # Check moments. - with pyro.plate("particles", 10000, dim=-2): - guide_trace = poutine.trace(guide).get_trace(data) - samples = poutine.replay(model, guide_trace)(data) + vectorize = pyro.plate("particles", 10000, dim=-2) + guide_trace = poutine.trace(vectorize(guide)).get_trace(data) + samples = poutine.replay(vectorize(model), guide_trace)(data) for name in ["x", "y"]: actual_mean = samples[name].mean(0).squeeze() actual_std = samples[name].std(0).squeeze() From 56d26afb348daa5debf540f18a31b85de701250b Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 30 Oct 2021 18:44:22 -0400 Subject: [PATCH 14/22] Document relationship to AutoNormal --- pyro/infer/autoguide/effect.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py index 0cbda2e336..a5280c4159 100644 --- a/pyro/infer/autoguide/effect.py +++ b/pyro/infer/autoguide/effect.py @@ -75,8 +75,12 @@ class AutoNormalMessenger(AutoMessenger): similar. The mean-field posterior at any site is a transformed normal distribution. + This posterior is equivalent to :class:`~pyro.infer.autoguide.AutoNormal` + or :class:`~pyro.infer.autoguide.AutoDiagonalNormal`, but allows + customization via subclassing. - Derived classes may override particular sites and use this simply as a + Derived classes may override the :meth:`get_posterior` behavior at + particular sites and use the mean-field normal behavior simply as a default, e.g.:: def model(data): @@ -206,8 +210,9 @@ class AutoRegressiveMessenger(AutoMessenger): conditioned on upstream posterior samples. The affine transform operates in unconstrained space. This supports only continuous latent variables. 
- Derived classes may override particular sites and use this simply as a - default, e.g.:: + Derived classes may override the :meth:`get_posterior` behavior at + particular sites and use the regressive behavior simply as a default, + e.g.:: class MyGuideMessenger(AutoRegressiveMessenger): def get_posterior(self, name, prior, upstream_values): From 26bd9bfc03138b6a7fd01121bae4d23f2fbf880f Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sun, 31 Oct 2021 15:56:45 -0400 Subject: [PATCH 15/22] Support subsampling and amortization --- pyro/infer/autoguide/effect.py | 43 ++++++++++++++++++------- tests/infer/test_autoguide.py | 57 ++++++++++++++++++++++++++++++++-- 2 files changed, 87 insertions(+), 13 deletions(-) diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py index a5280c4159..1141335868 100644 --- a/pyro/infer/autoguide/effect.py +++ b/pyro/infer/autoguide/effect.py @@ -1,16 +1,16 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Dict, Union +from typing import Callable, Dict, Tuple, Union import torch from torch.distributions import biject_to, constraints import pyro.distributions as dist -import pyro.poutine as poutine from pyro.distributions.distribution import Distribution from pyro.infer.effect_elbo import GuideMessenger from pyro.nn.module import PyroModule, PyroParam, pyro_method +from pyro.ops.tensor_utils import periodic_repeat from pyro.poutine.runtime import get_plates from .initialization import init_to_feasible, init_to_mean @@ -25,7 +25,15 @@ class AutoMessenger(GuideMessenger, PyroModule, metaclass=AutoMessengerMeta): """ EXPERIMENTAL Base class for :class:`pyro.infer.effect_elbo.GuideMessenger` autoguides. + + :param callable model: A Pyro model. + :param tuple amortized_plates: A tuple of names of plates over which guide + parameters should be shared. This is useful for subsampling, where a + guide parameter can be shared across all plates. """ + def __init__(self, model: Callable, *, amortized_plates: Tuple[str, ...] = ()): + self.amortized_plates = amortized_plates + super().__init__(model) @pyro_method def __call__(self, *args, **kwargs): @@ -33,7 +41,7 @@ def __call__(self, *args, **kwargs): # those parameters by a particle plate, in case the first time this # guide is called is inside a particle plate. We assume all plates # outside the model are particle plates. - self._outer_plates = get_plates() + self._outer_plates = tuple(f.name for f in get_plates()) try: return super().__call__(*args, **kwargs) finally: @@ -54,17 +62,22 @@ def call(self, *args, **kwargs): result = self(*args, **kwargs) return tuple(v for _, v in sorted(result.items())) + @torch.no_grad() def _remove_outer_plates(self, value: torch.Tensor, event_dim: int) -> torch.Tensor: """ Removes particle plates from initial values of parameters. 
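        When ``amortized_plates`` are specified, a plate whose name appears
        there is averaged out of the initial value, so a single parameter is
        shared across that plate, e.g.
        ``AutoNormalMessenger(model, amortized_plates=("batch",))``; a
        subsampled plate's initial value is instead tiled up to the plate's
        ``full_size`` via ``periodic_repeat()``.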
""" - for f in self._outer_plates: + for f in get_plates(): + full_size = getattr(f, "full_size", f.size) dim = f.dim - event_dim - if -value.dim() <= dim: - dim = dim + value.dim() - value = value[(slice(None),) * dim + (slice(1),)] + if f in self._outer_plates or f.name in self.amortized_plates: + if -value.dim() <= dim: + value = value.mean(dim, keepdim=True) + elif f.size != full_size: + value = periodic_repeat(value, full_size, dim).contiguous() for dim in range(value.dim() - event_dim): value = value.squeeze(0) + print(f"DEBUG {tuple(value.shape)}") return value @@ -128,6 +141,9 @@ def get_posterior(self, name, prior, upstream_values): See :ref:`autoguide-initialization` section for available functions. :param float init_scale: Initial scale for the standard deviation of each (unconstrained transformed) latent variable. + :param tuple amortized_plates: A tuple of names of plates over which guide + parameters should be shared. This is useful for subsampling, where a + guide parameter can be shared across all plates. """ def __init__( @@ -136,10 +152,11 @@ def __init__( *, init_loc_fn: Callable = init_to_mean(fallback=init_to_feasible), init_scale: float = 0.1, + amortized_plates: Tuple[str, ...] = (), ): if not isinstance(init_scale, float) or not (init_scale > 0): raise ValueError("Expected init_scale > 0. but got {}".format(init_scale)) - super().__init__(model) + super().__init__(model, amortized_plates=amortized_plates) self.init_loc_fn = init_loc_fn self._init_scale = init_scale self._computing_median = False @@ -171,7 +188,7 @@ def _get_params(self, name: str, prior: Distribution): pass # Initialize. - with poutine.block(), torch.no_grad(): + with torch.no_grad(): transform = biject_to(prior.support) event_dim = transform.domain.event_dim constrained = self.init_loc_fn({"name": name, "fn": prior}).detach() @@ -232,6 +249,9 @@ def get_posterior(self, name, prior, upstream_values): See :ref:`autoguide-initialization` section for available functions. :param float init_scale: Initial scale for the standard deviation of each (unconstrained transformed) latent variable. + :param tuple amortized_plates: A tuple of names of plates over which guide + parameters should be shared. This is useful for subsampling, where a + guide parameter can be shared across all plates. """ def __init__( @@ -240,10 +260,11 @@ def __init__( *, init_loc_fn: Callable = init_to_mean(fallback=init_to_feasible), init_scale: float = 0.1, + amortized_plates: Tuple[str, ...] = (), ): if not isinstance(init_scale, float) or not (init_scale > 0): raise ValueError("Expected init_scale > 0. but got {}".format(init_scale)) - super().__init__(model) + super().__init__(model, amortized_plates=amortized_plates) self.init_loc_fn = init_loc_fn self._init_scale = init_scale @@ -273,7 +294,7 @@ def _get_params(self, name: str, prior: Distribution): pass # Initialize. 
- with poutine.block(), torch.no_grad(): + with torch.no_grad(): transform = biject_to(prior.support) event_dim = transform.domain.event_dim constrained = self.init_loc_fn({"name": name, "fn": prior}).detach() diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index a7e4ca36c5..604591d946 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -1048,7 +1048,15 @@ def model(): assert d.shape == sample_shape + (2, 3) -@pytest.mark.parametrize("auto_class", [AutoDelta, AutoNormal]) +@pytest.mark.parametrize( + "auto_class", + [ + AutoDelta, + AutoNormal, + AutoNormalMessenger, + AutoRegressiveMessenger, + ], +) def test_subsample_model(auto_class): def model(x, y=None, batch_size=None): loc = pyro.param("loc", lambda: torch.tensor(0.0)) @@ -1074,11 +1082,56 @@ def model(x, y=None, batch_size=None): pyro.get_param_store().clear() pyro.set_rng_seed(123456789) - svi = SVI(model, guide, Adam({"lr": 0.02}), Trace_ELBO()) + Elbo = promote_elbo(auto_class, Trace_ELBO) + svi = SVI(model, guide, Adam({"lr": 0.02}), Elbo()) for step in range(5): svi.step(x, y, batch_size=batch_size) +@pytest.mark.parametrize( + "auto_class", + [ + AutoNormalMessenger, + AutoRegressiveMessenger, + ], +) +def test_subsample_model_amortized(auto_class): + def model(x, y=None, batch_size=None): + loc = pyro.param("loc", lambda: torch.tensor(0.0)) + scale = pyro.param( + "scale", lambda: torch.tensor(1.0), constraint=constraints.positive + ) + with pyro.plate("batch", len(x), subsample_size=batch_size): + batch_x = pyro.subsample(x, event_dim=0) + batch_y = pyro.subsample(y, event_dim=0) if y is not None else None + mean = loc + scale * batch_x + sigma = pyro.sample("sigma", dist.LogNormal(0.0, 1.0)) + return pyro.sample("obs", dist.Normal(mean, sigma), obs=batch_y) + + guide1 = auto_class(model) + guide2 = auto_class(model, amortized_plates=("batch",)) + + full_size = 50 + batch_size = 20 + pyro.set_rng_seed(123456789) + x = torch.randn(full_size) + with torch.no_grad(): + y = model(x) + assert y.shape == x.shape + + for guide in guide1, guide2: + pyro.get_param_store().clear() + pyro.set_rng_seed(123456789) + svi = SVI(model, guide, Adam({"lr": 0.02}), Effect_ELBO()) + for step in range(5): + svi.step(x, y, batch_size=batch_size) + + params1 = dict(guide1.named_parameters()) + params2 = dict(guide2.named_parameters()) + assert params1["locs.sigma_unconstrained"].shape == (50,) + assert params2["locs.sigma_unconstrained"].shape == () + + @pytest.mark.parametrize("init_fn", [None, init_to_mean, init_to_median]) @pytest.mark.parametrize("auto_class", [AutoDelta, AutoNormal, AutoGuideList]) def test_subsample_guide(auto_class, init_fn): From 8556734edf88d887efbbdda2f2f6858ed507838d Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sun, 31 Oct 2021 16:02:07 -0400 Subject: [PATCH 16/22] lint --- pyro/infer/autoguide/effect.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py index 1141335868..e2b8ec56fc 100644 --- a/pyro/infer/autoguide/effect.py +++ b/pyro/infer/autoguide/effect.py @@ -31,6 +31,7 @@ class AutoMessenger(GuideMessenger, PyroModule, metaclass=AutoMessengerMeta): parameters should be shared. This is useful for subsampling, where a guide parameter can be shared across all plates. """ + def __init__(self, model: Callable, *, amortized_plates: Tuple[str, ...] 
= ()): self.amortized_plates = amortized_plates super().__init__(model) From 7715687663dd11f5ff2a154218cdb87053e253d5 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sun, 31 Oct 2021 16:38:38 -0400 Subject: [PATCH 17/22] Remove debug statement --- pyro/infer/autoguide/effect.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py index e2b8ec56fc..776fe22cce 100644 --- a/pyro/infer/autoguide/effect.py +++ b/pyro/infer/autoguide/effect.py @@ -78,7 +78,6 @@ def _remove_outer_plates(self, value: torch.Tensor, event_dim: int) -> torch.Ten value = periodic_repeat(value, full_size, dim).contiguous() for dim in range(value.dim() - event_dim): value = value.squeeze(0) - print(f"DEBUG {tuple(value.shape)}") return value From 5c3e5bcb8edb2aac1d32e64230d5c35eac9650b1 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 1 Nov 2021 12:24:32 -0400 Subject: [PATCH 18/22] Add a poutine.unwrap() helper function --- pyro/infer/effect_elbo.py | 5 +++-- pyro/poutine/__init__.py | 2 ++ pyro/poutine/messenger.py | 14 ++++++++++++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/pyro/infer/effect_elbo.py b/pyro/infer/effect_elbo.py index f2b908f491..469a923c68 100644 --- a/pyro/infer/effect_elbo.py +++ b/pyro/infer/effect_elbo.py @@ -8,6 +8,7 @@ import torch import pyro.distributions as dist +import pyro.poutine as poutine from pyro.distributions.distribution import Distribution from pyro.infer.elbo import ELBO from pyro.infer.util import is_validation_enabled @@ -153,8 +154,8 @@ def _get_trace(self, model, guide, args, kwargs): # This differs from Trace_ELBO in that the guide is assumed to be an # effect handler. guide(*args, **kwargs) - while not isinstance(guide, GuideMessenger): - guide = guide.func.args[1] # unwrap plates + guide = poutine.unwrap(guide) + assert isinstance(guide, GuideMessenger) model_trace, guide_trace = guide.get_traces() # The rest follows pyro.infer.enum.get_importance_trace(). diff --git a/pyro/poutine/__init__.py b/pyro/poutine/__init__.py index aed31645c5..4cb8102d09 100644 --- a/pyro/poutine/__init__.py +++ b/pyro/poutine/__init__.py @@ -21,6 +21,7 @@ trace, uncondition, ) +from .messenger import unwrap from .runtime import NonlocalExit, get_mask from .trace_struct import Trace from .util import enable_validation, is_validation_enabled @@ -49,4 +50,5 @@ "trace", "Trace", "uncondition", + "unwrap", ] diff --git a/pyro/poutine/messenger.py b/pyro/poutine/messenger.py index bd610946a7..753eecaa36 100644 --- a/pyro/poutine/messenger.py +++ b/pyro/poutine/messenger.py @@ -24,6 +24,20 @@ def __get__(self, instance, owner): return partial(self.func, instance) +def unwrap(fn): + """ + Recursively unwraps poutines. 
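+
+    For example (a usage sketch), this recovers a guide that has been
+    wrapped in a vectorizing plate::
+
+        vectorize = pyro.plate("particles", 10, dim=-2)
+        wrapped = vectorize(guide)
+        assert poutine.unwrap(wrapped) is guide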
+ """ + while True: + if isinstance(fn, _bound_partial): + fn = fn.func + continue + if isinstance(fn, partial): + fn = fn.args[1] + continue + return fn + + class Messenger: """ Context manager class that modifies behavior From f03401a57edb6e711cc3833bd496be8610e3902e Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 1 Nov 2021 13:02:28 -0400 Subject: [PATCH 19/22] Eliminate Effect_ELBO and mixin stuff --- pyro/infer/__init__.py | 4 - pyro/infer/autoguide/effect.py | 45 ++++----- pyro/infer/enum.py | 23 +++-- pyro/infer/traceenum_elbo.py | 2 + .../effect_elbo.py => poutine/guide.py} | 94 ++++--------------- pyro/poutine/messenger.py | 2 +- tests/infer/test_autoguide.py | 62 +++++------- 7 files changed, 80 insertions(+), 152 deletions(-) rename pyro/{infer/effect_elbo.py => poutine/guide.py} (61%) diff --git a/pyro/infer/__init__.py b/pyro/infer/__init__.py index f2985dba2e..9af5465402 100644 --- a/pyro/infer/__init__.py +++ b/pyro/infer/__init__.py @@ -4,7 +4,6 @@ from pyro.infer.abstract_infer import EmpiricalMarginal, TracePosterior, TracePredictive from pyro.infer.csis import CSIS from pyro.infer.discrete import infer_discrete -from pyro.infer.effect_elbo import Effect_ELBO, GuideMessenger, JitEffect_ELBO from pyro.infer.elbo import ELBO from pyro.infer.energy_distance import EnergyDistance from pyro.infer.enum import config_enumerate @@ -30,14 +29,11 @@ __all__ = [ "CSIS", "ELBO", - "Effect_ELBO", "EmpiricalMarginal", "EnergyDistance", - "GuideMessenger", "HMC", "IMQSteinKernel", "Importance", - "JitEffect_ELBO", "JitTraceEnum_ELBO", "JitTraceGraph_ELBO", "JitTraceMeanField_ELBO", diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py index 776fe22cce..982baa0f83 100644 --- a/pyro/infer/autoguide/effect.py +++ b/pyro/infer/autoguide/effect.py @@ -1,16 +1,16 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Dict, Tuple, Union +from typing import Callable, Tuple, Union import torch from torch.distributions import biject_to, constraints import pyro.distributions as dist from pyro.distributions.distribution import Distribution -from pyro.infer.effect_elbo import GuideMessenger from pyro.nn.module import PyroModule, PyroParam, pyro_method from pyro.ops.tensor_utils import periodic_repeat +from pyro.poutine.guide import GuideMessenger from pyro.poutine.runtime import get_plates from .initialization import init_to_feasible, init_to_mean @@ -23,8 +23,7 @@ class AutoMessengerMeta(type(GuideMessenger), type(PyroModule)): class AutoMessenger(GuideMessenger, PyroModule, metaclass=AutoMessengerMeta): """ - EXPERIMENTAL Base class for :class:`pyro.infer.effect_elbo.GuideMessenger` - autoguides. + Base class for :class:`pyro.poutine.guide.GuideMessenger` autoguides. :param callable model: A Pyro model. :param tuple amortized_plates: A tuple of names of plates over which guide @@ -83,9 +82,8 @@ def _remove_outer_plates(self, value: torch.Tensor, event_dim: int) -> torch.Ten class AutoNormalMessenger(AutoMessenger): """ - EXPERIMENTAL Automatic :class:`~pyro.infer.effect_elbo.GuideMessenger` , - intended for use with :class:`~pyro.infer.effect_elbo.Effect_ELBO` or - similar. + Automatic :class:`~pyro.poutine.guide.GuideMessenger` with mean-field + normal posterior. The mean-field posterior at any site is a transformed normal distribution. 
This posterior is equivalent to :class:`~pyro.infer.autoguide.AutoNormal` @@ -103,7 +101,7 @@ def model(data): pyro.sample("obs", dist.Normal(c, 1), obs=data) class MyGuideMessenger(AutoNormalMessenger): - def get_posterior(self, name, prior, upstream_values): + def get_posterior(self, name, prior): if name == "c": # Use a custom distribution at site c. bias = pyro.param("c_bias", lambda: torch.zeros(())) @@ -111,19 +109,19 @@ def get_posterior(self, name, prior, upstream_values): constraint=constraints.positive) scale = pyro.param("c_scale", lambda: torch.ones(()), constraint=constraints.positive) - a = upstream_values["a"] - b = upstream_values["b"] + a = self.upstream_value("a") + b = self.upstream_value("b") loc = bias + weight * (a + b) return dist.Normal(loc, scale) # Fall back to mean field. - return super().get_posterior(name, prior, upstream_values) + return super().get_posterior(name, prior) Note that above we manually computed ``loc = bias + weight * (a + b)``. Alternatively we could reuse the model-side computation by setting ``loc = bias + weight * prior.loc``:: class MyGuideMessenger_v2(AutoNormalMessenger): - def get_posterior(self, name, prior, upstream_values): + def get_posterior(self, name, prior): if name == "c": # Use a custom distribution at site c. bias = pyro.param("c_bias", lambda: torch.zeros(())) @@ -134,7 +132,7 @@ def get_posterior(self, name, prior, upstream_values): loc = bias + weight * prior.loc return dist.Normal(loc, scale) # Fall back to mean field. - return super().get_posterior(name, prior, upstream_values) + return super().get_posterior(name, prior) :param callable model: A Pyro model. :param callable init_loc_fn: A per-site initialization function. @@ -162,10 +160,7 @@ def __init__( self._computing_median = False def get_posterior( - self, - name: str, - prior: Distribution, - upstream_values: Dict[str, torch.Tensor], + self, name: str, prior: Distribution ) -> Union[Distribution, torch.Tensor]: if self._computing_median: return self._get_posterior_median(name, prior) @@ -219,9 +214,8 @@ def _get_posterior_median(self, name, prior): class AutoRegressiveMessenger(AutoMessenger): """ - EXPERIMENTAL Automatic :class:`~pyro.infer.effect_elbo.GuideMessenger` , - intended for use with :class:`~pyro.infer.effect_elbo.Effect_ELBO` or - similar. + Automatic :class:`~pyro.poutine.guide.GuideMessenger` with prior dependency + structure. The posterior at any site is a learned affine transform of the prior, conditioned on upstream posterior samples. The affine transform operates in @@ -232,7 +226,7 @@ class AutoRegressiveMessenger(AutoMessenger): e.g.:: class MyGuideMessenger(AutoRegressiveMessenger): - def get_posterior(self, name, prior, upstream_values): + def get_posterior(self, name, prior): if name == "x": # Use a custom distribution at site x. loc = pyro.param("x_loc", lambda: torch.zeros(prior.shape())) @@ -240,9 +234,9 @@ def get_posterior(self, name, prior, upstream_values): constraint=constraints.positive return dist.Normal(loc, scale).to_event(prior.event_dim()) # Fall back to autoregressive. - return super().get_posterior(name, prior, upstream_values) + return super().get_posterior(name, prior) - .. warning:: This guide currently does not support ``JitEffect_ELBO``. + .. warning:: This guide currently does not support jit-based elbos. :param callable model: A Pyro model. :param callable init_loc_fn: A per-site initialization function. 
@@ -269,10 +263,7 @@ def __init__( self._init_scale = init_scale def get_posterior( - self, - name: str, - prior: Distribution, - upstream_values: Dict[str, torch.Tensor], + self, name: str, prior: Distribution ) -> Union[Distribution, torch.Tensor]: with helpful_support_errors({"name": name, "fn": prior}): transform = biject_to(prior.support) diff --git a/pyro/infer/enum.py b/pyro/infer/enum.py index 190762fb38..3a07e86899 100644 --- a/pyro/infer/enum.py +++ b/pyro/infer/enum.py @@ -49,12 +49,23 @@ def get_importance_trace( Returns a single trace from the guide, which can optionally be detached, and the model that is run against it. """ - guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(*args, **kwargs) - if detach: - guide_trace.detach_() - model_trace = poutine.trace( - poutine.replay(model, trace=guide_trace), graph_type=graph_type - ).get_trace(*args, **kwargs) + # Dispatch between callables vs GuideMessengers. + unwrapped_guide = poutine.unwrap(guide) + if isinstance(unwrapped_guide, poutine.messenger.Messenger): + if detach: + raise NotImplementedError("GuideMessenger does not support detach") + guide(*args, **kwargs) + model_trace, guide_trace = unwrapped_guide.get_traces() + else: + guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace( + *args, **kwargs + ) + if detach: + guide_trace.detach_() + model_trace = poutine.trace( + poutine.replay(model, trace=guide_trace), graph_type=graph_type + ).get_trace(*args, **kwargs) + if is_validation_enabled(): check_model_guide_match(model_trace, guide_trace, max_plate_nesting) diff --git a/pyro/infer/traceenum_elbo.py b/pyro/infer/traceenum_elbo.py index 019c7f821e..1141718381 100644 --- a/pyro/infer/traceenum_elbo.py +++ b/pyro/infer/traceenum_elbo.py @@ -368,6 +368,8 @@ def _get_traces(self, model, guide, args, kwargs): Runs the guide and runs the model against the guide with the result packaged as a trace generator. """ + if isinstance(poutine.unwrap(guide), poutine.messenger.Messenger): + raise NotImplementedError("TraceEnum_ELBO does not support GuideMessenger") if self.max_plate_nesting == float("inf"): self._guess_max_plate_nesting(model, guide, args, kwargs) if self.vectorize_particles: diff --git a/pyro/infer/effect_elbo.py b/pyro/poutine/guide.py similarity index 61% rename from pyro/infer/effect_elbo.py rename to pyro/poutine/guide.py index 469a923c68..2d543f5a5b 100644 --- a/pyro/infer/effect_elbo.py +++ b/pyro/poutine/guide.py @@ -2,28 +2,21 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from collections import OrderedDict from typing import Callable, Dict, Tuple, Union import torch import pyro.distributions as dist -import pyro.poutine as poutine from pyro.distributions.distribution import Distribution -from pyro.infer.elbo import ELBO -from pyro.infer.util import is_validation_enabled -from pyro.poutine.trace_messenger import TraceMessenger -from pyro.poutine.trace_struct import Trace -from pyro.poutine.util import prune_subsample_sites, site_is_subsample -from pyro.util import check_model_guide_match, check_site_shape -from .trace_elbo import JitTrace_ELBO, Trace_ELBO +from .trace_messenger import TraceMessenger +from .trace_struct import Trace +from .util import prune_subsample_sites, site_is_subsample class GuideMessenger(TraceMessenger, ABC): """ - EXPERIMENTAL Abstract base class for effect-based guides for use in - :class:`Effect_ELBO` and similar. + EXPERIMENTAL Abstract base class for effect-based guides. 
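+
+    As a minimal illustrative sketch (assuming the usual ``torch``,
+    ``pyro``, ``dist``, ``biject_to`` and ``constraints`` imports; the
+    class name is hypothetical), a mean-field normal guide could be
+    written as::
+
+        class MeanFieldNormalGuide(GuideMessenger):
+            def get_posterior(self, name, prior):
+                transform = biject_to(prior.support)
+                event_dim = transform.domain.event_dim
+                loc = pyro.param(
+                    name + "_loc",
+                    lambda: transform.inv(prior.sample()).detach(),
+                    event_dim=event_dim,
+                )
+                scale = pyro.param(
+                    name + "_scale",
+                    lambda: torch.full_like(loc, 0.1),
+                    constraint=constraints.positive,
+                    event_dim=event_dim,
+                )
+                base = dist.Normal(loc, scale).to_event(event_dim)
+                return dist.TransformedDistribution(base, [transform])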
Derived classes must implement the :meth:`get_posterior` method. """ @@ -37,6 +30,9 @@ def __init__(self, model: Callable): def model(self): return self._model[0] + def upstream_value(self, name: str) -> torch.Tensor: + return self.trace.nodes[name]["value"] + def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]: """ Draws posterior samples from the guide and replays the model against @@ -47,13 +43,11 @@ def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]: :rtype: dict """ self.args_kwargs = args, kwargs - self.upstream_values = OrderedDict() try: with self: self.model(*args, **kwargs) finally: del self.args_kwargs - del self.upstream_values model_trace, guide_trace = self.get_traces() samples = { @@ -68,7 +62,7 @@ def _pyro_sample(self, msg): return prior = msg["fn"] msg["infer"]["prior"] = prior - posterior = self.get_posterior(msg["name"], prior, self.upstream_values) + posterior = self.get_posterior(msg["name"], prior) if isinstance(posterior, torch.Tensor): posterior = dist.Delta(posterior, event_dim=prior.event_dim) if posterior.batch_shape != prior.batch_shape: @@ -76,21 +70,15 @@ def _pyro_sample(self, msg): msg["fn"] = posterior def _pyro_post_sample(self, msg): - self.upstream_values[msg["name"]] = msg["value"] - # Manually apply outer plates. prior = msg["infer"].get("prior") if prior is not None and prior.batch_shape != msg["fn"].batch_shape: msg["infer"]["prior"] = prior.expand(msg["fn"].batch_shape) - return super()._pyro_post_sample(msg) @abstractmethod def get_posterior( - self, - name: str, - prior: Distribution, - upstream_values: Dict[str, torch.Tensor], + self, name: str, prior: Distribution ) -> Union[Distribution, torch.Tensor]: """ Abstract method to compute a posterior distribution or sample a @@ -103,17 +91,18 @@ def get_posterior( Implementations may access further information for computations: - - ``args, kwargs = self.args_kwargs`` are the inputs to the model, and - may be useful for amortization. + - ``self.upstream_value(name)`` returns the value of an upstream sample + or deterministic site. - ``self.trace`` is a trace of upstream sites, and may be useful for other information such as ``self.trace.nodes["my_site"]["fn"]`` or ``self.trace.nodes["my_site"]["cond_indep_stack"]`` . + - ``args, kwargs = self.args_kwargs`` are the inputs to the model, and + may be useful for amortization. :param str name: The name of the sample site to sample. :param prior: The prior distribution of this sample site (conditioned on upstream samples from the posterior). :type prior: ~pyro.distributions.Distribution - :param dict upstream_values: :returns: A posterior distribution or sample from the posterior distribution. :rtype: ~pyro.distributions.Distribution or torch.Tensor @@ -122,8 +111,10 @@ def get_posterior( def get_traces(self) -> Tuple[Trace, Trace]: """ - This can be called after running :meth:`__call__` . In contrast to the - trace-replay pattern of generating a pair of traces, + This can be called after running :meth:`__call__` to extract a pair of + traces. + + In contrast to the trace-replay pattern of generating a pair of traces, :class:`GuideMessenger` interleaves model and guide computations, so only a single ``guide(*args, **kwargs)`` call is needed to create both traces. 
This function merely extract the relevant information from this @@ -142,54 +133,3 @@ def get_traces(self) -> Tuple[Trace, Trace]: model_site["fn"] = guide_site["infer"]["prior"] model_trace.nodes[name] = model_site return model_trace, guide_trace - - -class EffectMixin(ELBO): - """ - EXPERIMENTAL Mixin class to turn a trace-based ELBO implementation into an - effect-based implementation. - """ - - def _get_trace(self, model, guide, args, kwargs): - # This differs from Trace_ELBO in that the guide is assumed to be an - # effect handler. - guide(*args, **kwargs) - guide = poutine.unwrap(guide) - assert isinstance(guide, GuideMessenger) - model_trace, guide_trace = guide.get_traces() - - # The rest follows pyro.infer.enum.get_importance_trace(). - max_plate_nesting = self.max_plate_nesting - if is_validation_enabled(): - check_model_guide_match(model_trace, guide_trace, max_plate_nesting) - model_trace.compute_log_prob() - guide_trace.compute_score_parts() - if is_validation_enabled(): - for site in model_trace.nodes.values(): - if site["type"] == "sample": - check_site_shape(site, max_plate_nesting) - for site in guide_trace.nodes.values(): - if site["type"] == "sample": - check_site_shape(site, max_plate_nesting) - - return model_trace, guide_trace - - -class Effect_ELBO(EffectMixin, Trace_ELBO): - """ - EXPERIMENTAL Similar to :class:`~pyro.infer.trace_elbo.Trace_ELBO` but - supporting guides that are :class:`GuideMessenger` s rather than traceable - functions. - """ - - pass - - -class JitEffect_ELBO(EffectMixin, JitTrace_ELBO): - """ - EXPERIMENTAL Similar to :class:`~pyro.infer.trace_elbo.JitTrace_ELBO` but - supporting guides that are :class:`GuideMessenger` s rather than traceable - functions. - """ - - pass diff --git a/pyro/poutine/messenger.py b/pyro/poutine/messenger.py index 753eecaa36..2fe31feba4 100644 --- a/pyro/poutine/messenger.py +++ b/pyro/poutine/messenger.py @@ -33,7 +33,7 @@ def unwrap(fn): fn = fn.func continue if isinstance(fn, partial): - fn = fn.args[1] + fn = fn.args[1] # extract from partial(handler, fn) continue return fn diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 604591d946..8e39528bff 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -16,8 +16,6 @@ import pyro.poutine as poutine from pyro.infer import ( SVI, - Effect_ELBO, - JitEffect_ELBO, JitTrace_ELBO, JitTraceEnum_ELBO, JitTraceGraph_ELBO, @@ -37,7 +35,6 @@ AutoIAFNormal, AutoLaplaceApproximation, AutoLowRankMultivariateNormal, - AutoMessenger, AutoMultivariateNormal, AutoNormal, AutoNormalMessenger, @@ -67,17 +64,11 @@ ) -def promote_elbo(Guide, Elbo, jit=True): - """ - Promote e.g. Trace_ELBO --> Effect_ELBO for AutoMessengers. 
- """ - if isinstance(Guide, type) and issubclass(Guide, AutoMessenger): - if Elbo is Trace_ELBO: - return Effect_ELBO - if Elbo is JitTrace_ELBO: - return JitEffect_ELBO if jit else Effect_ELBO - pytest.skip("not implemented") - return Elbo +def xfail_messenger(auto_class, Elbo): + if isinstance(auto_class, type): + if issubclass(auto_class, poutine.messenger.Messenger): + if Elbo in (TraceEnum_ELBO, JitTraceEnum_ELBO): + pytest.xfail(reason="not implemented") @pytest.mark.parametrize( @@ -132,6 +123,8 @@ def model(): ], ) def test_factor(auto_class, Elbo): + xfail_messenger(auto_class, Elbo) + def model(log_factor): pyro.sample("z1", dist.Normal(0.0, 1.0)) pyro.factor("f1", log_factor) @@ -141,7 +134,6 @@ def model(log_factor): pyro.sample("z3", dist.Normal(torch.zeros(3), torch.ones(3))) guide = auto_class(model) - Elbo = promote_elbo(auto_class, Elbo) elbo = Elbo(strict_enumeration_warning=False) elbo.loss(model, guide, torch.tensor(0.0)) # initialize param store @@ -233,6 +225,8 @@ def dependency_z6_z5(z5): ) @pytest.mark.filterwarnings("ignore::FutureWarning") def test_shapes(auto_class, init_loc_fn, Elbo, num_particles): + xfail_messenger(auto_class, Elbo) + def model(): pyro.sample("z1", dist.Normal(0.0, 1.0)) pyro.sample("z2", dist.Normal(torch.zeros(2), torch.ones(2)).to_event(1)) @@ -245,7 +239,6 @@ def model(): ) pyro.sample("z7", dist.LKJCholesky(2, torch.tensor(1.0))) - Elbo = promote_elbo(auto_class, Elbo) guide = auto_class(model, init_loc_fn=init_loc_fn) elbo = Elbo( num_particles=num_particles, @@ -389,6 +382,8 @@ def __init__(self, model): ) @pytest.mark.parametrize("Elbo", [JitTrace_ELBO, JitTraceGraph_ELBO, JitTraceEnum_ELBO]) def test_median(auto_class, Elbo): + xfail_messenger(auto_class, Elbo) + def model(): pyro.sample("x", dist.Normal(0.0, 1.0)) pyro.sample("y", dist.LogNormal(0.0, 1.0)) @@ -396,7 +391,6 @@ def model(): guide = auto_class(model) optim = Adam({"lr": 0.02, "betas": (0.8, 0.99)}) - Elbo = promote_elbo(auto_class, Elbo) elbo = Elbo( strict_enumeration_warning=False, num_particles=500, @@ -793,6 +787,8 @@ def model(): ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) def test_median_module(auto_class, Elbo): + xfail_messenger(auto_class, Elbo) + class Model(PyroModule): def __init__(self): super().__init__() @@ -806,7 +802,6 @@ def forward(self): model = Model() guide = auto_class(model) - Elbo = promote_elbo(auto_class, Elbo) infer = SVI( model, guide, Adam({"lr": 0.005}), Elbo(strict_enumeration_warning=False) ) @@ -884,6 +879,7 @@ def forward(self): ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) def test_linear_regression_smoke(auto_class, Elbo): + xfail_messenger(auto_class, Elbo) N, D = 10, 3 class RandomLinear(nn.Linear, PyroModule): @@ -910,7 +906,6 @@ def forward(self, x, y=None): x, y = torch.randn(N, D), torch.randn(N) model = LinearRegression() guide = auto_class(model) - Elbo = promote_elbo(auto_class, Elbo) infer = SVI( model, guide, Adam({"lr": 0.005}), Elbo(strict_enumeration_warning=False) ) @@ -1082,8 +1077,7 @@ def model(x, y=None, batch_size=None): pyro.get_param_store().clear() pyro.set_rng_seed(123456789) - Elbo = promote_elbo(auto_class, Trace_ELBO) - svi = SVI(model, guide, Adam({"lr": 0.02}), Elbo()) + svi = SVI(model, guide, Adam({"lr": 0.02}), Trace_ELBO()) for step in range(5): svi.step(x, y, batch_size=batch_size) @@ -1122,7 +1116,7 @@ def model(x, y=None, batch_size=None): for guide in guide1, guide2: pyro.get_param_store().clear() 
pyro.set_rng_seed(123456789) - svi = SVI(model, guide, Adam({"lr": 0.02}), Effect_ELBO()) + svi = SVI(model, guide, Adam({"lr": 0.02}), Trace_ELBO()) for step in range(5): svi.step(x, y, batch_size=batch_size) @@ -1405,11 +1399,9 @@ def model(data): expected_loss = float(g.event_logsumexp() - g.condition(data).event_logsumexp()) guide = Guide(model) - Elbo = promote_elbo( - Guide, - JitTrace_ELBO, - jit=(Guide is not AutoRegressiveMessenger), # currently fails with jit - ) + Elbo = JitTrace_ELBO + if Guide is AutoRegressiveMessenger: + Elbo = Trace_ELBO # currently fails with jit elbo = Elbo(num_particles=100, vectorize_particles=True, ignore_jit_warnings=True) num_steps = 500 optim = ClippedAdam({"lr": 0.05, "lrd": 0.1 ** (1 / num_steps)}) @@ -1471,11 +1463,9 @@ def model(data): ) guide = Guide(model) - Elbo = promote_elbo( - Guide, - JitTrace_ELBO, - jit=(Guide is not AutoRegressiveMessenger), # currently fails with jit - ) + Elbo = JitTrace_ELBO + if Guide is AutoRegressiveMessenger: + Elbo = Trace_ELBO # currently fails with jit elbo = Elbo(num_particles=100, vectorize_particles=True, ignore_jit_warnings=True) num_steps = 500 optim = ClippedAdam({"lr": 0.05, "lrd": 0.1 ** (1 / num_steps)}) @@ -1551,11 +1541,9 @@ def model(data): expected_loss = float(g.event_logsumexp() - g_cond.event_logsumexp()) guide = Guide(model) - Elbo = promote_elbo( - Guide, - JitTrace_ELBO, - jit=(Guide is not AutoRegressiveMessenger), # currently fails with jit - ) + Elbo = JitTrace_ELBO + if Guide is AutoRegressiveMessenger: + Elbo = Trace_ELBO # currently fails with jit elbo = Elbo(num_particles=100, vectorize_particles=True, ignore_jit_warnings=True) num_steps = 500 optim = ClippedAdam({"lr": 0.05, "lrd": 0.1 ** (1 / num_steps)}) From 8749c96982ce1dbc32ab28def405a96937856265 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 1 Nov 2021 13:17:25 -0400 Subject: [PATCH 20/22] Fix docs --- docs/source/inference_algos.rst | 7 ------- docs/source/pyro.poutine.txt | 10 ++++++++++ pyro/infer/autoguide/effect.py | 9 ++++----- pyro/poutine/guide.py | 22 ++++++++++++++-------- 4 files changed, 28 insertions(+), 20 deletions(-) diff --git a/docs/source/inference_algos.rst b/docs/source/inference_algos.rst index 746e111a5d..153028d1f7 100644 --- a/docs/source/inference_algos.rst +++ b/docs/source/inference_algos.rst @@ -57,13 +57,6 @@ ELBO :show-inheritance: :member-order: bysource -.. automodule:: pyro.infer.effect_elbo - :members: - :undoc-members: - :special-members: __call__ - :show-inheritance: - :member-order: bysource - Importance ---------- diff --git a/docs/source/pyro.poutine.txt b/docs/source/pyro.poutine.txt index 207a5cfb8b..a37cac6b77 100644 --- a/docs/source/pyro.poutine.txt +++ b/docs/source/pyro.poutine.txt @@ -173,3 +173,13 @@ ____________________ :members: :undoc-members: :show-inheritance: + +GuideMessenger +______________ + +.. automodule:: pyro.poutine.guide + :members: + :undoc-members: + :special-members: __call__ + :member-order: bysource + :show-inheritance: diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py index 982baa0f83..8c12e72197 100644 --- a/pyro/infer/autoguide/effect.py +++ b/pyro/infer/autoguide/effect.py @@ -23,7 +23,7 @@ class AutoMessengerMeta(type(GuideMessenger), type(PyroModule)): class AutoMessenger(GuideMessenger, PyroModule, metaclass=AutoMessengerMeta): """ - Base class for :class:`pyro.poutine.guide.GuideMessenger` autoguides. + Base class for :class:`~pyro.poutine.guide.GuideMessenger` autoguides. :param callable model: A Pyro model. 
:param tuple amortized_plates: A tuple of names of plates over which guide @@ -82,8 +82,7 @@ def _remove_outer_plates(self, value: torch.Tensor, event_dim: int) -> torch.Ten class AutoNormalMessenger(AutoMessenger): """ - Automatic :class:`~pyro.poutine.guide.GuideMessenger` with mean-field - normal posterior. + :class:`AutoMessenger` with mean-field normal posterior. The mean-field posterior at any site is a transformed normal distribution. This posterior is equivalent to :class:`~pyro.infer.autoguide.AutoNormal` @@ -214,8 +213,8 @@ def _get_posterior_median(self, name, prior): class AutoRegressiveMessenger(AutoMessenger): """ - Automatic :class:`~pyro.poutine.guide.GuideMessenger` with prior dependency - structure. + :class:`AutoMessenger` with recursively affine-transformed priors using + prior dependency structure. The posterior at any site is a learned affine transform of the prior, conditioned on upstream posterior samples. The affine transform operates in diff --git a/pyro/poutine/guide.py b/pyro/poutine/guide.py index 2d543f5a5b..a1cc27b62a 100644 --- a/pyro/poutine/guide.py +++ b/pyro/poutine/guide.py @@ -16,7 +16,7 @@ class GuideMessenger(TraceMessenger, ABC): """ - EXPERIMENTAL Abstract base class for effect-based guides. + Abstract base class for effect-based guides. Derived classes must implement the :meth:`get_posterior` method. """ @@ -30,9 +30,6 @@ def __init__(self, model: Callable): def model(self): return self._model[0] - def upstream_value(self, name: str) -> torch.Tensor: - return self.trace.nodes[name]["value"] - def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]: """ Draws posterior samples from the guide and replays the model against @@ -82,8 +79,8 @@ def get_posterior( ) -> Union[Distribution, torch.Tensor]: """ Abstract method to compute a posterior distribution or sample a - posterior value given a prior distribution and values of upstream - sample sites. + posterior value given a prior distribution conditioned on upstream + posterior samples. Implementations may use ``pyro.param`` and ``pyro.sample`` inside this function, but ``pyro.sample`` statements should set @@ -91,8 +88,8 @@ def get_posterior( Implementations may access further information for computations: - - ``self.upstream_value(name)`` returns the value of an upstream sample - or deterministic site. + - ``value = self.upstream_value(name)`` is the value of an upstream + sample or deterministic site. - ``self.trace`` is a trace of upstream sites, and may be useful for other information such as ``self.trace.nodes["my_site"]["fn"]`` or ``self.trace.nodes["my_site"]["cond_indep_stack"]`` . @@ -109,6 +106,15 @@ def get_posterior( """ raise NotImplementedError + def upstream_value(self, name: str) -> torch.Tensor: + """ + For use in :meth:`get_posterior` . 
+ + :returns: The value of an upstream sample or deterministic site + :rtype: torch.Tensor + """ + return self.trace.nodes[name]["value"] + def get_traces(self) -> Tuple[Trace, Trace]: """ This can be called after running :meth:`__call__` to extract a pair of From c161ebb9fd871f577f6c5a87e2723a55de27ad93 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 1 Nov 2021 13:30:31 -0400 Subject: [PATCH 21/22] Revert unnecessary change --- pyro/infer/__init__.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pyro/infer/__init__.py b/pyro/infer/__init__.py index 9af5465402..5485b7476e 100644 --- a/pyro/infer/__init__.py +++ b/pyro/infer/__init__.py @@ -27,13 +27,17 @@ from pyro.infer.util import enable_validation, is_validation_enabled __all__ = [ + "config_enumerate", "CSIS", + "enable_validation", + "is_validation_enabled", "ELBO", "EmpiricalMarginal", "EnergyDistance", "HMC", - "IMQSteinKernel", "Importance", + "IMQSteinKernel", + "infer_discrete", "JitTraceEnum_ELBO", "JitTraceGraph_ELBO", "JitTraceMeanField_ELBO", @@ -47,17 +51,13 @@ "SMCFilter", "SVGD", "SVI", + "TraceTMC_ELBO", "TraceEnum_ELBO", "TraceGraph_ELBO", "TraceMeanField_ELBO", "TracePosterior", "TracePredictive", - "TraceTMC_ELBO", "TraceTailAdaptive_ELBO", "Trace_ELBO", "Trace_MMD", - "config_enumerate", - "enable_validation", - "infer_discrete", - "is_validation_enabled", ] From 872b82d430c044dbfbe9de00103c382068e7472e Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 1 Nov 2021 14:54:11 -0400 Subject: [PATCH 22/22] Fix poutine.unwrap() --- pyro/poutine/messenger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/poutine/messenger.py b/pyro/poutine/messenger.py index 2fe31feba4..457da3aba9 100644 --- a/pyro/poutine/messenger.py +++ b/pyro/poutine/messenger.py @@ -32,7 +32,7 @@ def unwrap(fn): if isinstance(fn, _bound_partial): fn = fn.func continue - if isinstance(fn, partial): + if isinstance(fn, partial) and len(fn.args) >= 2: fn = fn.args[1] # extract from partial(handler, fn) continue return fn
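Taken together, the series leaves effect-based guides usable with the
standard trace-based ELBOs. A minimal end-to-end sketch (illustrative
only, assuming the patched APIs above)::

    import torch

    import pyro
    import pyro.distributions as dist
    from pyro.infer import SVI, Trace_ELBO
    from pyro.infer.autoguide import AutoNormalMessenger
    from pyro.optim import Adam

    def model(data):
        loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

    data = torch.randn(100) + 0.5
    guide = AutoNormalMessenger(model)
    svi = SVI(model, guide, Adam({"lr": 0.02}), Trace_ELBO())
    for step in range(100):
        svi.step(data)

    samples = guide(data)         # dict of latent, deterministic, and observed values
    medians = guide.median(data)  # per-site posterior medians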