Skip to content

Commit

Permalink
fix: batched mcmc sampling reshaping
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Aug 2, 2024
1 parent b275448 commit 295ff42
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 29 deletions.
64 changes: 46 additions & 18 deletions sbi/inference/posteriors/mcmc_posterior.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import warnings
from copy import deepcopy
from functools import partial
from math import ceil
Expand Down Expand Up @@ -387,7 +387,18 @@ def sample_batched(
sample_shape: Desired shape of samples that are drawn from the posterior
given every observation.
x: A batch of observations, of shape `(batch_dim, event_shape_x)`.
`batch_dim` corresponds to the number of observations to be drawn.
`batch_dim` corresponds to the number of observations to be
drawn.
method: Method used for MCMC sampling, e.g., "slice_np_vectorized".
thin: The thinning factor for the chain, default 1 (no thinning).
warmup_steps: The initial number of samples to discard.
num_chains: The number of chains used for each `x` passed in the batch.
init_strategy: The initialisation strategy for chains.
init_strategy_parameters: Dictionary of keyword arguments passed to
the init strategy.
num_workers: number of cpu cores used to parallelize initial
parameter generation and mcmc sampling.
mp_context: Multiprocessing start method, either `"fork"` or `"spawn"`
show_progress_bars: Whether to show sampling progress monitor.
Returns:
Expand All @@ -412,6 +423,16 @@ def sample_batched(
method == "slice_np_vectorized"
), "Batched sampling only supported for vectorized samplers!"

# warn if num_chains is larger than num requested samples
if num_chains > torch.Size(sample_shape).numel():
warnings.warn(
f"""Passed num_chains {num_chains} is larger than the number of
requested samples {torch.Size(sample_shape).numel()}, resetting
it to {torch.Size(sample_shape).numel()}.""",
stacklevel=2,
)
num_chains = torch.Size(sample_shape).numel()

# custom shape handling to make sure to match the batch size of x and theta
# without unnecessary combinations.
if len(x.shape) == 1:
Expand Down Expand Up @@ -455,22 +476,29 @@ def sample_batched(
show_progress_bars=show_progress_bars,
)

samples = self.theta_transform.inv(transformed_samples)
sample_shape_len = len(sample_shape)
# The MCMC sampler returns the samples per chain, of shape
# (num_samples, num_chains_extended, *input_shape). We return the samples as `
# (*sample_shape, x_batch_size, *input_shape). This means we want to combine
# all the chains that belong to the same x. However, using
# samples.reshape(*sample_shape,batch_size,-1) does not combine the samples in
# the right order, since this mixes samples that belong to different `x`.
# This is a workaround to reshape the samples in the right order.
return samples.reshape((batch_size, *sample_shape, -1)).permute( # type: ignore
tuple(range(1, sample_shape_len + 1))
+ (
0,
-1,
)
)
# (num_chains_extended, samples_per_chain, *input_shape)
samples_per_chain: Tensor = self.theta_transform.inv(transformed_samples) # type: ignore
dim_theta = samples_per_chain.shape[-1]
# We need to collect samples for each x from the respective chains.
# However, using samples.reshape(*sample_shape, batch_size, dim_theta)
# does not combine the samples in the right order, since this mixes
# samples that belong to different `x`. The following permute is a
# workaround to reshape the samples in the right order.
samples_per_x = samples_per_chain.reshape((
batch_size,
# We are flattening the sample shape here using -1 because we might have
# generated more samples than requested (more chains, or multiple of
# chains not matching sample_shape)
-1,
dim_theta,
)).permute(1, 0, -1)

# Shape is now (-1, batch_size, dim_theta)
# We can now select the number of requested samples
samples = samples_per_x[: torch.Size(sample_shape).numel()]
# and reshape into (*sample_shape, batch_size, dim_theta)
samples = samples.reshape((*sample_shape, batch_size, dim_theta))
return samples

def _build_mcmc_init_fn(
self,
Expand Down
39 changes: 28 additions & 11 deletions tests/posterior_nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,19 +121,32 @@ def test_batched_sample_log_prob_with_different_x(
@pytest.mark.parametrize("snlre_method", [SNLE_A, SNRE_A, SNRE_B, SNRE_C, SNPE_C])
@pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2))
@pytest.mark.parametrize("init_strategy", ["proposal", "resample"])
@pytest.mark.parametrize(
"sample_shape",
(
(5,), # less than num_chains
(4, 2), # 2D batch
(15,), # not divisible by num_chains
),
)
def test_batched_mcmc_sample_log_prob_with_different_x(
snlre_method: type, x_o_batch_dim: bool, mcmc_params_fast: dict, init_strategy: str
snlre_method: type,
x_o_batch_dim: bool,
mcmc_params_fast: dict,
init_strategy: str,
sample_shape: torch.Size,
):
num_dim = 2
num_simulations = 1000
num_simulations = 100
num_chains = 10

prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
simulator = diagonal_linear_gaussian

inference = snlre_method(prior=prior)
theta = prior.sample((num_simulations,))
x = simulator(theta)
inference.append_simulations(theta, x).train(max_num_epochs=3)
inference.append_simulations(theta, x).train(max_num_epochs=2)

x_o = ones(num_dim) if x_o_batch_dim == 0 else ones(x_o_batch_dim, num_dim)

Expand All @@ -144,16 +157,16 @@ def test_batched_mcmc_sample_log_prob_with_different_x(
)

samples = posterior.sample_batched(
(10,),
sample_shape,
x_o,
init_strategy=init_strategy,
num_chains=2,
num_chains=num_chains,
)

assert (
samples.shape == (10, x_o_batch_dim, num_dim)
samples.shape == (*sample_shape, x_o_batch_dim, num_dim)
if x_o_batch_dim > 0
else (10, num_dim)
else (*sample_shape, num_dim)
), "Sample shape wrong"

if x_o_batch_dim > 1:
Expand All @@ -167,14 +180,18 @@ def test_batched_mcmc_sample_log_prob_with_different_x(
)

x_o = torch.stack([0.5 * ones(num_dim), -0.5 * ones(num_dim)], dim=0)
# test with multiple chains to test whether correct chains are concatenated.
samples = posterior.sample_batched((1000,), x_o, num_chains=2, warmup_steps=500)
# test with multiple chains to test whether correct chains are
# concatenated.
sample_shape = (1000,) # use enough samples for accuracy comparison
samples = posterior.sample_batched(
sample_shape, x_o, num_chains=num_chains, warmup_steps=500
)

samples_separate1 = posterior.sample(
(1000,), x_o[0], num_chains=2, warmup_steps=500
sample_shape, x_o[0], num_chains=num_chains, warmup_steps=500
)
samples_separate2 = posterior.sample(
(1000,), x_o[1], num_chains=2, warmup_steps=500
sample_shape, x_o[1], num_chains=num_chains, warmup_steps=500
)

# Check if means are approx. same
Expand Down

0 comments on commit 295ff42

Please sign in to comment.