From b66641115a29a44b3c9592c8a2b84dbacf5a9fa4 Mon Sep 17 00:00:00 2001 From: janfb Date: Tue, 25 Jan 2022 17:38:47 +0100 Subject: [PATCH] adapt mcmc refactoring and arviz integration to new sampler interface. Co-authored-by: Seth Axen Co-authored-by: janfb --- sbi/inference/posteriors/mcmc_posterior.py | 23 ++++++++++++++++- sbi/samplers/mcmc/__init__.py | 1 + tests/posterior_sampler_test.py | 29 ++++++++++++++-------- 3 files changed, 42 insertions(+), 11 deletions(-) diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index 1c9dd5889..acabedc4e 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -13,12 +13,17 @@ from torch import multiprocessing as mp from tqdm.auto import tqdm +# TODO: fix imports. from sbi.inference.posteriors.base_posterior import NeuralPosterior from sbi.samplers.mcmc import ( IterateParameters, Slice, + SliceSamplerSerial, + SliceSamplerVectorized, + prior_init, proposal_init, resample_given_potential_fn, + sir, sir_init, slice_np_parallized, ) @@ -47,6 +52,7 @@ def __init__( init_strategy_parameters: Dict[str, Any] = {}, init_strategy_num_candidates: Optional[int] = None, num_workers: int = 1, + param_name: str = "theta", device: Optional[str] = None, x_shape: Optional[torch.Size] = None, ): @@ -81,6 +87,9 @@ def __init__( locations in `init_strategy=sir` (deprecated, use init_strategy_parameters instead). num_workers: number of cpu cores used to parallelize mcmc + param_name: Name of the sampled parameters used internally. When sampling + with `mcmc_method` of `slice`, `hmc`, or `nuts`, this name is used in + the sampler returned by `self.posterior_sampler` after sampling. device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None, `potential_fn.device` is used. x_shape: Shape of a single simulator output. If passed, it is used to check @@ -102,6 +111,8 @@ def __init__( self.init_strategy = init_strategy self.init_strategy_parameters = init_strategy_parameters self.num_workers = num_workers + self.param_name = param_name + self._posterior_sampler = None if init_strategy_num_candidates is not None: warn( @@ -130,6 +141,11 @@ def mcmc_method(self, method: str) -> None: """See `set_mcmc_method`.""" self.set_mcmc_method(method) + @property + def posterior_sampler(self): + """Returns sampler created by `sample` when `sample_with='mcmc'`.""" + return self._posterior_sampler + def set_mcmc_method(self, method: str) -> "NeuralPosterior": """Sets sampling method to for MCMC and returns `NeuralPosterior`. @@ -441,6 +457,7 @@ def _slice_np_mcmc( num_chains, dim_samples = initial_params.shape + # TODO: fix typo in function name. samples = slice_np_parallized( potential_function, initial_params, @@ -455,6 +472,10 @@ def _slice_np_mcmc( # Save sample as potential next init (if init_strategy == 'latest_sample'). self._mcmc_init_params = samples[:, -1, :].reshape(num_chains, dim_samples) + # TODO: adapt slice_np_parallelized to return the sampler object. + # Save posterior sampler. + # self._posterior_sampler = posterior_sampler + samples = samples.reshape(-1, dim_samples)[:num_samples, :] assert samples.shape[0] == num_samples @@ -494,7 +515,7 @@ def _pyro_mcmc( kernel=kernels[mcmc_method](potential_fn=potential_function), num_samples=(thin * num_samples) // num_chains + num_chains, warmup_steps=warmup_steps, - initial_params={"": initial_params}, + initial_params={self.param_name: initial_params}, num_chains=num_chains, mp_context="spawn", disable_progbar=not show_progress_bars, diff --git a/sbi/samplers/mcmc/__init__.py b/sbi/samplers/mcmc/__init__.py index 125cc5f86..7b4f64392 100644 --- a/sbi/samplers/mcmc/__init__.py +++ b/sbi/samplers/mcmc/__init__.py @@ -7,6 +7,7 @@ from sbi.samplers.mcmc.slice import Slice from sbi.samplers.mcmc.slice_numpy import ( SliceSampler, + SliceSamplerSerial, SliceSamplerVectorized, slice_np_parallized, ) diff --git a/tests/posterior_sampler_test.py b/tests/posterior_sampler_test.py index 14eebdcef..b397ee23f 100644 --- a/tests/posterior_sampler_test.py +++ b/tests/posterior_sampler_test.py @@ -4,19 +4,25 @@ from __future__ import annotations import pytest +from pyro.infer.mcmc import MCMC from torch import eye, zeros from torch.distributions import MultivariateNormal from sbi import utils as utils -from sbi.inference import SNL, prepare_for_sbi, simulate_for_sbi +from sbi.inference import ( + SNL, + MCMCPosterior, + likelihood_estimator_based_potential, + prepare_for_sbi, + simulate_for_sbi, +) +from sbi.samplers.mcmc import SliceSamplerSerial, SliceSamplerVectorized from sbi.simulators.linear_gaussian import diagonal_linear_gaussian -from sbi.mcmc import SliceSamplerVectorized, SliceSamplerSerial -from pyro.infer.mcmc import MCMC - @pytest.mark.parametrize( - "sampling_method", ( + "sampling_method", + ( "slice_np", "slice_np_vectorized", "slice", @@ -47,14 +53,17 @@ def test_api_posterior_sampler_set(sampling_method: str, set_seed): theta, x = simulate_for_sbi( simulator, prior, num_simulations, simulation_batch_size=10 ) - _ = inference.append_simulations(theta, x).train(max_num_epochs=5) - posterior = inference.build_posterior( - sample_with="mcmc", mcmc_method=sampling_method - ).set_default_x(x_o) + estimator = inference.append_simulations(theta, x).train(max_num_epochs=5) + potential_fn, transform = likelihood_estimator_based_potential( + estimator, prior, x_o + ) + posterior = MCMCPosterior( + potential_fn, theta_transform=transform, method=sampling_method, proposal=prior + ) assert posterior.posterior_sampler is None posterior.sample( - sample_shape=(num_samples,num_chains), + sample_shape=(num_samples, num_chains), x=x_o, mcmc_parameters={ "thin": 3,