Commit

Adapt MCMC refactoring and ArviZ integration to the new sampler interface.
Co-authored-by: Seth Axen <seth.axen@gmail.com>
Co-authored-by: janfb <j.f.boelts@gmail.com>
janfb and sethaxen committed Jul 14, 2022
1 parent fb445b6 commit b666411
Showing 3 changed files with 42 additions and 11 deletions.
23 changes: 22 additions & 1 deletion sbi/inference/posteriors/mcmc_posterior.py
@@ -13,12 +13,17 @@
from torch import multiprocessing as mp
from tqdm.auto import tqdm

# TODO: fix imports.
from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.samplers.mcmc import (
IterateParameters,
Slice,
SliceSamplerSerial,
SliceSamplerVectorized,
prior_init,
proposal_init,
resample_given_potential_fn,
sir,
sir_init,
slice_np_parallized,
)
@@ -47,6 +52,7 @@ def __init__(
init_strategy_parameters: Dict[str, Any] = {},
init_strategy_num_candidates: Optional[int] = None,
num_workers: int = 1,
param_name: str = "theta",
device: Optional[str] = None,
x_shape: Optional[torch.Size] = None,
):
@@ -81,6 +87,9 @@ def __init__(
locations in `init_strategy=sir` (deprecated, use
init_strategy_parameters instead).
num_workers: Number of CPU cores used to parallelize MCMC.
param_name: Name of the sampled parameters used internally. When sampling
with `mcmc_method` of `slice`, `hmc`, or `nuts`, this name is used in
the sampler returned by `self.posterior_sampler` after sampling.
device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
`potential_fn.device` is used.
x_shape: Shape of a single simulator output. If passed, it is used to check
@@ -102,6 +111,8 @@ def __init__(
self.init_strategy = init_strategy
self.init_strategy_parameters = init_strategy_parameters
self.num_workers = num_workers
self.param_name = param_name
self._posterior_sampler = None

if init_strategy_num_candidates is not None:
warn(
@@ -130,6 +141,11 @@ def mcmc_method(self, method: str) -> None:
"""See `set_mcmc_method`."""
self.set_mcmc_method(method)

@property
def posterior_sampler(self):
"""Returns sampler created by `sample` when `sample_with='mcmc'`."""
return self._posterior_sampler

def set_mcmc_method(self, method: str) -> "NeuralPosterior":
"""Sets sampling method to for MCMC and returns `NeuralPosterior`.
@@ -441,6 +457,7 @@ def _slice_np_mcmc(

num_chains, dim_samples = initial_params.shape

# TODO: fix typo in function name.
samples = slice_np_parallized(
potential_function,
initial_params,
@@ -455,6 +472,10 @@ def _slice_np_mcmc(
# Save sample as potential next init (if init_strategy == 'latest_sample').
self._mcmc_init_params = samples[:, -1, :].reshape(num_chains, dim_samples)

# TODO: adapt slice_np_parallelized to return the sampler object.
# Save posterior sampler.
# self._posterior_sampler = posterior_sampler

samples = samples.reshape(-1, dim_samples)[:num_samples, :]
assert samples.shape[0] == num_samples

@@ -494,7 +515,7 @@ def _pyro_mcmc(
kernel=kernels[mcmc_method](potential_fn=potential_function),
num_samples=(thin * num_samples) // num_chains + num_chains,
warmup_steps=warmup_steps,
initial_params={"": initial_params},
initial_params={self.param_name: initial_params},
num_chains=num_chains,
mp_context="spawn",
disable_progbar=not show_progress_bars,
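The changes above wire the new `posterior_sampler` hook into `MCMCPosterior`: after a call to `sample` with one of the Pyro-based methods (`slice`, `hmc`, `nuts`), the fitted `pyro.infer.mcmc.MCMC` object is retained, with chains stored under the site name given by `param_name` (default "theta"). This is what the ArviZ integration in the commit title builds on. A minimal usage sketch, not part of the diff; variable names are hypothetical, and a potential function, transform, prior, and observation `x_o` are assumed to exist:

import arviz as az

from sbi.inference import MCMCPosterior

# Sample with Pyro's NUTS kernel; `potential_fn`, `transform`, `prior`,
# and `x_o` are assumed to be set up as in the test file below.
posterior = MCMCPosterior(
    potential_fn, theta_transform=transform, proposal=prior, method="nuts"
)
samples = posterior.sample((1000,), x=x_o)

# After sampling, the fitted Pyro MCMC object is exposed; its chains are
# keyed by `param_name` ("theta"), which ArviZ reads out directly.
inference_data = az.from_pyro(posterior.posterior_sampler)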
1 change: 1 addition & 0 deletions sbi/samplers/mcmc/__init__.py
@@ -7,6 +7,7 @@
from sbi.samplers.mcmc.slice import Slice
from sbi.samplers.mcmc.slice_numpy import (
SliceSampler,
SliceSamplerSerial,
SliceSamplerVectorized,
slice_np_parallized,
)
29 changes: 19 additions & 10 deletions tests/posterior_sampler_test.py
@@ -4,19 +4,25 @@
from __future__ import annotations

import pytest
from pyro.infer.mcmc import MCMC
from torch import eye, zeros
from torch.distributions import MultivariateNormal

from sbi import utils as utils
from sbi.inference import SNL, prepare_for_sbi, simulate_for_sbi
from sbi.inference import (
SNL,
MCMCPosterior,
likelihood_estimator_based_potential,
prepare_for_sbi,
simulate_for_sbi,
)
from sbi.samplers.mcmc import SliceSamplerSerial, SliceSamplerVectorized
from sbi.simulators.linear_gaussian import diagonal_linear_gaussian


from sbi.mcmc import SliceSamplerVectorized, SliceSamplerSerial
from pyro.infer.mcmc import MCMC

@pytest.mark.parametrize(
"sampling_method", (
"sampling_method",
(
"slice_np",
"slice_np_vectorized",
"slice",
@@ -47,14 +53,17 @@ def test_api_posterior_sampler_set(sampling_method: str, set_seed):
theta, x = simulate_for_sbi(
simulator, prior, num_simulations, simulation_batch_size=10
)
_ = inference.append_simulations(theta, x).train(max_num_epochs=5)
posterior = inference.build_posterior(
sample_with="mcmc", mcmc_method=sampling_method
).set_default_x(x_o)
estimator = inference.append_simulations(theta, x).train(max_num_epochs=5)
potential_fn, transform = likelihood_estimator_based_potential(
estimator, prior, x_o
)
posterior = MCMCPosterior(
potential_fn, theta_transform=transform, method=sampling_method, proposal=prior
)

assert posterior.posterior_sampler is None
posterior.sample(
sample_shape=(num_samples,num_chains),
sample_shape=(num_samples, num_chains),
x=x_o,
mcmc_parameters={
"thin": 3,
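The test diff is truncated right after the `sample` call. A hedged sketch of the assertion that plausibly follows, inferred from the imports (`MCMC`, `SliceSamplerSerial`, `SliceSamplerVectorized`) and from `posterior_sampler` being `None` before sampling; since saving the sampler for the numpy slice methods is still a TODO in `_slice_np_mcmc` above, only the Pyro branch is spelled out:

# Not part of the visible diff: a plausible continuation of the test.
# For the Pyro-based kernels the sampler should now be set; the numpy
# slice samplers are left open, matching the TODO in `_slice_np_mcmc`.
if sampling_method in ("slice", "hmc", "nuts"):
    assert isinstance(posterior.posterior_sampler, MCMC)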
