diff --git a/pyro/infer/__init__.py b/pyro/infer/__init__.py
index 6934bd29fe..3a6a37ce5b 100644
--- a/pyro/infer/__init__.py
+++ b/pyro/infer/__init__.py
@@ -12,7 +12,7 @@
 from pyro.infer.mcmc.hmc import HMC
 from pyro.infer.mcmc.nuts import NUTS
 from pyro.infer.mcmc.rwkernel import RandomWalkKernel
-from pyro.infer.predictive import Predictive, WeighedPredictive
+from pyro.infer.predictive import MHResampler, Predictive, WeighedPredictive
 from pyro.infer.renyi_elbo import RenyiELBO
 from pyro.infer.rws import ReweightedWakeSleep
 from pyro.infer.smcfilter import SMCFilter
@@ -44,6 +44,7 @@
     "JitTraceMeanField_ELBO",
     "JitTrace_ELBO",
     "MCMC",
+    "MHResampler",
     "NUTS",
     "Predictive",
     "RandomWalkKernel",
diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py
index ea89aff5e5..e30099c85e 100644
--- a/pyro/infer/predictive.py
+++ b/pyro/infer/predictive.py
@@ -2,16 +2,16 @@
 # SPDX-License-Identifier: Apache-2.0

 import warnings
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
 from functools import reduce
-from typing import List, Union
+from typing import Callable, List, Union

 import torch

 import pyro
 import pyro.poutine as poutine
 from pyro.infer.importance import LogWeightsMixin
-from pyro.infer.util import plate_log_prob_sum
+from pyro.infer.util import CloneMixin, plate_log_prob_sum
 from pyro.poutine.trace_struct import Trace
 from pyro.poutine.util import prune_subsample_sites

@@ -320,7 +320,7 @@ def get_vectorized_trace(self, *args, **kwargs):


 @dataclass(frozen=True, eq=False)
-class WeighedPredictiveResults(LogWeightsMixin):
+class WeighedPredictiveResults(LogWeightsMixin, CloneMixin):
     """
     Return value of call to instance of :class:`WeighedPredictive`.
     """
@@ -450,3 +450,188 @@ def forward(self, *args, **kwargs):
             guide_log_prob=guide_log_prob,
             model_log_prob=model_log_prob,
         )
+
+
+class MHResampler(torch.nn.Module):
+    r"""
+    Resampler for weighed samples that generates equally weighed samples from the distribution
+    specified by the weighed samples of ``sampler``.
+
+    The resampling is based on the Metropolis-Hastings algorithm.
+    Given an initial sample :math:`x`, subsequent samples are generated by:
+
+    - Sampling from the ``guide`` a new sample candidate :math:`x'` with probability :math:`g(x')`.
+    - Calculating an acceptance probability
+      :math:`A(x', x) = \min\left(1, \frac{P(x')}{P(x)} \frac{g(x)}{g(x')}\right)`,
+      with :math:`P` being the ``model``.
+    - Accepting the new sample candidate :math:`x'` as the next sample with probability
+      :math:`A(x', x)`, and otherwise keeping the current sample :math:`x` as the next sample.
+
+    The above is the Metropolis-Hastings algorithm with the new sample candidate
+    proposal distribution being equal to the ``guide`` and independent of the
+    current sample, such that :math:`g(x')=g(x' \mid x)`.
+
+    :param callable sampler: When called returns :class:`WeighedPredictiveResults`.
+    :param slice source_samples_slice: Select source samples for storage (default is ``slice(0)``, i.e. none).
+    :param slice stored_samples_slice: Select output samples for storage (default is ``slice(0)``, i.e. none).
+
+    The typical use case of :class:`MHResampler` would be to convert weighed samples
+    generated by :class:`WeighedPredictive` into equally weighed samples from the target distribution.
+    Each time an instance of :class:`MHResampler` is called it returns a new set of samples, with the
+    samples generated by the first call being distributed according to the ``guide``, and with each
+    subsequent call the distribution of the samples becoming closer to that of the posterior predictive
+    distribution. It might take some experimentation in order to find out, in each case, how many times one
+    would need to call an instance of :class:`MHResampler` in order to get close enough to the posterior
+    predictive distribution.
+
+    Example::
+
+        def model():
+            ...
+
+        def guide():
+            ...
+
+        def conditioned_model():
+            ...
+
+        # Fit guide
+        elbo = Trace_ELBO(num_particles=100, vectorize_particles=True)
+        svi = SVI(conditioned_model, guide, optim.Adam(dict(lr=3.0)), elbo)
+        for i in range(num_svi_steps):
+            svi.step()
+
+        # Create callable that returns weighed samples
+        posterior_predictive = WeighedPredictive(model,
+                                                 guide=guide,
+                                                 num_samples=num_samples,
+                                                 parallel=parallel,
+                                                 return_sites=["_RETURN"])
+
+        prob = 0.95
+
+        weighed_samples = posterior_predictive(model_guide=conditioned_model)
+        # Calculate quantile directly from weighed samples
+        weighed_samples_quantile = weighed_quantile(weighed_samples.samples['_RETURN'],
+                                                    [prob],
+                                                    weighed_samples.log_weights)[0]
+
+        resampler = MHResampler(posterior_predictive)
+        num_mh_steps = 10
+        for mh_step_count in range(num_mh_steps):
+            resampled_weighed_samples = resampler(model_guide=conditioned_model)
+        # Calculate quantile from resampled weighed samples (samples are equally weighed)
+        resampled_weighed_samples_quantile = quantile(resampled_weighed_samples.samples['_RETURN'],
+                                                      [prob])[0]
+
+        # Quantiles calculated using both methods should be identical
+        assert_close(weighed_samples_quantile, resampled_weighed_samples_quantile, rtol=0.01)
+
+    .. _mhsampler-behavior:
+
+    **Notes on Sampler Behavior:**
+
+    - In case the ``guide`` perfectly tracks the ``model``, this sampler will do nothing
+      as the acceptance probability :math:`A(x', x)` will always be one.
+    - Furthermore, if the guide is approximately separable, i.e. :math:`g(z_A, z_B) \approx g_A(z_A) g_B(z_B)`,
+      with :math:`g_A(z_A)` perfectly tracking the ``model`` and :math:`g_B(z_B)` poorly tracking the ``model``,
+      quantiles of :math:`z_A` calculated from samples taken from :class:`MHResampler` will have much lower
+      variance than quantiles of :math:`z_A` calculated by using :any:`weighed_quantile`, as the effective sample
+      size of the calculation using :any:`weighed_quantile` will be low due to :math:`g_B(z_B)` poorly tracking
+      the ``model``, whereas when using :class:`MHResampler` the poor ``model`` tracking of :math:`g_B(z_B)` has
+      a negligible effect on the effective sample size of the :math:`z_A` samples.
+    """
+
+    def __init__(
+        self,
+        sampler: Callable,
+        source_samples_slice: slice = slice(0),
+        stored_samples_slice: slice = slice(0),
+    ):
+        super().__init__()
+        self.sampler = sampler
+        self.samples = None
+        self.transition_count = torch.tensor(0, dtype=torch.long)
+        self.source_samples = []
+        self.source_samples_slice = source_samples_slice
+        self.stored_samples = []
+        self.stored_samples_slice = stored_samples_slice
+
+    def forward(self, *args, **kwargs):
+        """
+        Perform a single resampling step.
+        Returns :class:`WeighedPredictiveResults`.
+        """
+        with torch.no_grad():
+            new_samples = self.sampler(*args, **kwargs)
+            # Store samples
+            self.source_samples.append(new_samples)
+            self.source_samples = self.source_samples[self.source_samples_slice]
+            if self.samples is None:
+                # First set of samples
+                self.samples = new_samples.clone()
+                self.transition_count = torch.zeros_like(
+                    new_samples.log_weights, dtype=torch.long
+                )
+            else:
+                # Apply Metropolis-Hastings algorithm
+                prob = torch.clamp(
+                    new_samples.log_weights - self.samples.log_weights, max=0.0
+                ).exp()
+                idx = torch.rand(*prob.shape) <= prob
+                self.transition_count[idx] += 1
+                for field_desc in fields(self.samples):
+                    field, new_field = getattr(self.samples, field_desc.name), getattr(
+                        new_samples, field_desc.name
+                    )
+                    if isinstance(field, dict):
+                        for key in field:
+                            field[key][idx] = new_field[key][idx]
+                    else:
+                        field[idx] = new_field[idx]
+            self.stored_samples.append(self.samples.clone())
+            self.stored_samples = self.stored_samples[self.stored_samples_slice]
+        return self.samples
+
+    def get_min_sample_transition_count(self):
+        """
+        Return the transition count of the sample with the minimal number of transitions.
+        """
+        return self.transition_count.min()
+
+    def get_total_transition_count(self):
+        """
+        Return the total number of transitions.
+        """
+        return self.transition_count.sum()
+
+    def get_source_samples(self):
+        """
+        Return source samples that were the input to the Metropolis-Hastings algorithm.
+        """
+        return self.get_samples(self.source_samples)
+
+    def get_stored_samples(self):
+        """
+        Return stored samples that were the output of the Metropolis-Hastings algorithm.
+        """
+        return self.get_samples(self.stored_samples)
+
+    def get_samples(self, samples):
+        """
+        Concatenate the given list of stored sample sets, collected during execution of
+        the Metropolis-Hastings algorithm, into a single set of samples.
+        """
+        retval = dict()
+        for field_desc in fields(self.samples):
+            field_name, value = field_desc.name, getattr(self.samples, field_desc.name)
+            if isinstance(value, dict):
+                retval[field_name] = dict()
+                for key in value:
+                    retval[field_name][key] = torch.cat(
+                        [getattr(sample, field_name)[key] for sample in samples]
+                    )
+            else:
+                retval[field_name] = torch.cat(
+                    [getattr(sample, field_name) for sample in samples]
+                )
+        return self.samples.__class__(**retval)
diff --git a/pyro/infer/util.py b/pyro/infer/util.py
index 13e1d9e12f..2efbb60ed8 100644
--- a/pyro/infer/util.py
+++ b/pyro/infer/util.py
@@ -5,6 +5,7 @@
 import numbers
 from collections import Counter, defaultdict
 from contextlib import contextmanager
+from dataclasses import fields

 import torch
 from opt_einsum import shared_intermediates
@@ -358,3 +359,22 @@ def plate_log_prob_sum(trace: Trace, plate_symbol: str) -> torch.Tensor:
             [site["packed"]["log_prob"]],
         )
     return log_prob_sum
+
+
+class CloneMixin:
+    """
+    Mixin class that adds a ``.clone`` method to ``@dataclasses.dataclass`` decorated classes
+    whose fields are ``torch.Tensor`` values or dictionaries of tensors.
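+
+    For illustration, a minimal (hypothetical) dataclass using this mixin could look like::
+
+        @dataclass(frozen=True, eq=False)
+        class Batch(CloneMixin):  # hypothetical container, for illustration only
+            values: torch.Tensor
+            extras: dict  # a dictionary of tensors
+
+        batch = Batch(values=torch.zeros(3), extras={"a": torch.ones(2)})
+        copy = batch.clone()  # each tensor (and dict entry) is cloned field by field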
+ """ + + def clone(self): + retval = dict() + for field_desc in fields(self): + field_name, value = field_desc.name, getattr(self, field_desc.name) + if isinstance(value, dict): + retval[field_name] = dict() + for key in value: + retval[field_name][key] = value[key].clone() + else: + retval[field_name] = value.clone() + return self.__class__(**retval) diff --git a/tests/infer/test_predictive.py b/tests/infer/test_predictive.py index 319a1196dd..ca155ed2fd 100644 --- a/tests/infer/test_predictive.py +++ b/tests/infer/test_predictive.py @@ -1,6 +1,8 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 +import logging + import pytest import torch @@ -8,8 +10,9 @@ import pyro.distributions as dist import pyro.optim as optim import pyro.poutine as poutine -from pyro.infer import SVI, Predictive, Trace_ELBO, WeighedPredictive +from pyro.infer import SVI, MHResampler, Predictive, Trace_ELBO, WeighedPredictive from pyro.infer.autoguide import AutoDelta, AutoDiagonalNormal +from pyro.ops.stats import quantile, weighed_quantile from tests.common import assert_close @@ -39,9 +42,18 @@ def beta_guide(num_trials): pyro.sample("phi", phi_posterior) -@pytest.mark.parametrize("predictive", [Predictive, WeighedPredictive]) +@pytest.mark.parametrize( + "predictive, num_svi_steps, test_unweighed_convergence", + [ + (Predictive, 5000, None), + (WeighedPredictive, 5000, True), + (WeighedPredictive, 1000, False), + ], +) @pytest.mark.parametrize("parallel", [False, True]) -def test_posterior_predictive_svi_manual_guide(parallel, predictive): +def test_posterior_predictive_svi_manual_guide( + parallel, predictive, num_svi_steps, test_unweighed_convergence +): true_probs = torch.ones(5) * 0.7 num_trials = ( torch.ones(5) * 400 @@ -51,9 +63,7 @@ def test_posterior_predictive_svi_manual_guide(parallel, predictive): conditioned_model = poutine.condition(model, data={"obs": num_success}) elbo = Trace_ELBO(num_particles=100, vectorize_particles=True) svi = SVI(conditioned_model, beta_guide, optim.Adam(dict(lr=3.0)), elbo) - for i in range( - 5000 - ): # Increased to 5000 from 1000 in order for guide optimization to converge + for i in range(num_svi_steps): svi.step(num_trials) posterior_predictive = predictive( model, @@ -70,10 +80,53 @@ def test_posterior_predictive_svi_manual_guide(parallel, predictive): ) marginal_return_vals = weighed_samples.samples["_RETURN"] assert marginal_return_vals.shape[:1] == weighed_samples.log_weights.shape - # Weights should be uniform as the guide has the same distribution as the model - assert weighed_samples.log_weights.std() < 0.6 - # Effective sample size should be close to actual number of samples taken from the guide - assert weighed_samples.get_ESS() > 0.8 * num_samples + # Resample weighed samples + resampler = MHResampler(posterior_predictive) + num_mh_steps = 10 + for mh_step_count in range(num_mh_steps): + resampled_weighed_samples = resampler( + num_trials, model_guide=conditioned_model + ) + resampled_marginal_return_vals = resampled_weighed_samples.samples["_RETURN"] + # Calculate CDF quantiles + quantile_test_point = 0.95 + quantile_test_point_value = quantile( + marginal_return_vals, [quantile_test_point] + )[0] + weighed_quantile_test_point_value = weighed_quantile( + marginal_return_vals, [quantile_test_point], weighed_samples.log_weights + )[0] + resampled_quantile_test_point_value = quantile( + resampled_marginal_return_vals, [quantile_test_point] + )[0] + logging.info( + "Unweighed quantile at test point is: " + 
+            "Unweighed quantile at test point is: " + str(quantile_test_point_value)
+        )
+        logging.info(
+            "Weighed quantile at test point is: "
+            + str(weighed_quantile_test_point_value)
+        )
+        logging.info(
+            "Resampled quantile at test point is: "
+            + str(resampled_quantile_test_point_value)
+        )
+        # Weighed and resampled quantiles should match
+        assert_close(
+            weighed_quantile_test_point_value,
+            resampled_quantile_test_point_value,
+            rtol=0.01,
+        )
+        if test_unweighed_convergence:
+            # Weights should be uniform as the guide has the same distribution as the model
+            assert weighed_samples.log_weights.std() < 0.6
+            # Effective sample size should be close to actual number of samples taken from the guide
+            assert weighed_samples.get_ESS() > 0.8 * num_samples
+            # Weighed and unweighed quantiles should match if guide converged to true model
+            assert_close(
+                quantile_test_point_value,
+                resampled_quantile_test_point_value,
+                rtol=0.01,
+            )
     assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 280, rtol=0.1)
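For reference, the acceptance step implemented in ``MHResampler.forward`` can be read directly off the
log-weights returned by :class:`WeighedPredictive`: the proposal is the ``guide`` itself, so the guide terms
cancel and :math:`A(x', x)` reduces to a ratio of importance weights. A minimal illustrative sketch (not part
of the patch; the helper name ``mh_accept_mask`` and its arguments are hypothetical)::

    import torch

    def mh_accept_mask(new_log_weights, current_log_weights):
        # A(x', x) = min(1, w(x') / w(x)) with w = model_prob / guide_prob;
        # computed in log space, mirroring the clamp/exp used in MHResampler.forward.
        prob = torch.clamp(new_log_weights - current_log_weights, max=0.0).exp()
        return torch.rand(*prob.shape) <= prob

    # Hypothetical usage: True entries take the candidate sample, the rest keep the current one.
    accept = mh_accept_mask(torch.randn(1000), torch.randn(1000))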