diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py
index bfb7f6280..7a9d9460a 100644
--- a/sbi/inference/posteriors/mcmc_posterior.py
+++ b/sbi/inference/posteriors/mcmc_posterior.py
@@ -1,6 +1,6 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see
-
+import warnings
 from copy import deepcopy
 from functools import partial
 from math import ceil
@@ -387,7 +387,18 @@ def sample_batched(
             sample_shape: Desired shape of samples that are drawn from the posterior
                 given every observation.
             x: A batch of observations, of shape `(batch_dim, event_shape_x)`.
-                `batch_dim` corresponds to the number of observations to be drawn.
+                `batch_dim` corresponds to the number of observations to be
+                drawn.
+            method: Method used for MCMC sampling, e.g., "slice_np_vectorized".
+            thin: The thinning factor for the chain, default 1 (no thinning).
+            warmup_steps: The initial number of samples to discard.
+            num_chains: The number of chains used for each `x` passed in the batch.
+            init_strategy: The initialization strategy for chains.
+            init_strategy_parameters: Dictionary of keyword arguments passed to
+                the init strategy.
+            num_workers: Number of CPU cores used to parallelize initial
+                parameter generation and MCMC sampling.
+            mp_context: Multiprocessing start method, either `"fork"` or `"spawn"`.
             show_progress_bars: Whether to show sampling progress monitor.
 
         Returns:
@@ -412,6 +423,16 @@ def sample_batched(
             method == "slice_np_vectorized"
         ), "Batched sampling only supported for vectorized samplers!"
 
+        # Warn if num_chains is larger than the number of requested samples.
+        if num_chains > torch.Size(sample_shape).numel():
+            warnings.warn(
+                f"""Passed num_chains {num_chains} is larger than the number of
+                requested samples {torch.Size(sample_shape).numel()}, resetting
+                it to {torch.Size(sample_shape).numel()}.""",
+                stacklevel=2,
+            )
+            num_chains = torch.Size(sample_shape).numel()
+
         # custom shape handling to make sure to match the batch size of x and theta
         # without unnecessary combinations.
         if len(x.shape) == 1:
@@ -455,22 +476,29 @@ def sample_batched(
             show_progress_bars=show_progress_bars,
         )
 
-        samples = self.theta_transform.inv(transformed_samples)
-        sample_shape_len = len(sample_shape)
-        # The MCMC sampler returns the samples per chain, of shape
-        # (num_samples, num_chains_extended, *input_shape). We return the samples as `
-        # (*sample_shape, x_batch_size, *input_shape). This means we want to combine
-        # all the chains that belong to the same x. However, using
-        # samples.reshape(*sample_shape,batch_size,-1) does not combine the samples in
-        # the right order, since this mixes samples that belong to different `x`.
-        # This is a workaround to reshape the samples in the right order.
-        return samples.reshape((batch_size, *sample_shape, -1)).permute(  # type: ignore
-            tuple(range(1, sample_shape_len + 1))
-            + (
-                0,
-                -1,
-            )
-        )
+        # (num_chains_extended, samples_per_chain, *input_shape)
+        samples_per_chain: Tensor = self.theta_transform.inv(transformed_samples)  # type: ignore
+        dim_theta = samples_per_chain.shape[-1]
+        # We need to collect samples for each x from the respective chains.
+        # However, using samples.reshape(*sample_shape, batch_size, dim_theta)
+        # does not combine the samples in the right order, since this mixes
+        # samples that belong to different `x`. The following permute is a
+        # workaround to reshape the samples in the right order.
+        samples_per_x = samples_per_chain.reshape((
+            batch_size,
+            # We flatten the sample shape using -1 because we might have
+            # generated more samples than requested (more chains, or a
+            # multiple of chains not matching sample_shape).
+            -1,
+            dim_theta,
+        )).permute(1, 0, -1)
+
+        # Shape is now (-1, batch_size, dim_theta).
+        # We can now select the number of requested samples
+        samples = samples_per_x[: torch.Size(sample_shape).numel()]
+        # and reshape into (*sample_shape, batch_size, dim_theta).
+        samples = samples.reshape((*sample_shape, batch_size, dim_theta))
+        return samples
 
     def _build_mcmc_init_fn(
         self,
diff --git a/tests/posterior_nn_test.py b/tests/posterior_nn_test.py
index 7a98c1020..69fba316d 100644
--- a/tests/posterior_nn_test.py
+++ b/tests/posterior_nn_test.py
@@ -121,11 +121,24 @@ def test_batched_sample_log_prob_with_different_x(
 @pytest.mark.parametrize("snlre_method", [SNLE_A, SNRE_A, SNRE_B, SNRE_C, SNPE_C])
 @pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2))
 @pytest.mark.parametrize("init_strategy", ["proposal", "resample"])
+@pytest.mark.parametrize(
+    "sample_shape",
+    (
+        (5,),  # fewer samples than num_chains
+        (4, 2),  # 2D batch
+        (15,),  # not divisible by num_chains
+    ),
+)
 def test_batched_mcmc_sample_log_prob_with_different_x(
-    snlre_method: type, x_o_batch_dim: bool, mcmc_params_fast: dict, init_strategy: str
+    snlre_method: type,
+    x_o_batch_dim: int,
+    mcmc_params_fast: dict,
+    init_strategy: str,
+    sample_shape: torch.Size,
 ):
     num_dim = 2
-    num_simulations = 1000
+    num_simulations = 100
+    num_chains = 10
 
     prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
     simulator = diagonal_linear_gaussian
@@ -133,7 +146,7 @@ def test_batched_mcmc_sample_log_prob_with_different_x(
     inference = snlre_method(prior=prior)
     theta = prior.sample((num_simulations,))
     x = simulator(theta)
-    inference.append_simulations(theta, x).train(max_num_epochs=3)
+    inference.append_simulations(theta, x).train(max_num_epochs=2)
 
     x_o = ones(num_dim) if x_o_batch_dim == 0 else ones(x_o_batch_dim, num_dim)
 
@@ -144,16 +157,16 @@ def test_batched_mcmc_sample_log_prob_with_different_x(
     )
 
     samples = posterior.sample_batched(
-        (10,),
+        sample_shape,
         x_o,
         init_strategy=init_strategy,
-        num_chains=2,
+        num_chains=num_chains,
     )
 
     assert (
-        samples.shape == (10, x_o_batch_dim, num_dim)
+        samples.shape == (*sample_shape, x_o_batch_dim, num_dim)
         if x_o_batch_dim > 0
-        else (10, num_dim)
+        else samples.shape == (*sample_shape, num_dim)
     ), "Sample shape wrong"
 
     if x_o_batch_dim > 1:
@@ -167,14 +180,18 @@ def test_batched_mcmc_sample_log_prob_with_different_x(
     )
 
     x_o = torch.stack([0.5 * ones(num_dim), -0.5 * ones(num_dim)], dim=0)
-    # test with multiple chains to test whether correct chains are concatenated.
-    samples = posterior.sample_batched((1000,), x_o, num_chains=2, warmup_steps=500)
+    # Test with multiple chains to check whether the correct chains are
+    # concatenated.
+    sample_shape = (1000,)  # use enough samples for accuracy comparison
+    samples = posterior.sample_batched(
+        sample_shape, x_o, num_chains=num_chains, warmup_steps=500
+    )
 
     samples_separate1 = posterior.sample(
-        (1000,), x_o[0], num_chains=2, warmup_steps=500
+        sample_shape, x_o[0], num_chains=num_chains, warmup_steps=500
     )
     samples_separate2 = posterior.sample(
-        (1000,), x_o[1], num_chains=2, warmup_steps=500
+        sample_shape, x_o[1], num_chains=num_chains, warmup_steps=500
     )
 
     # Check if means are approx. same
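Note for reviewers (a sketch, not part of the patch): the chain-to-`x` reordering above can be sanity-checked in isolation. The snippet below mimics the new logic under the assumption that the vectorized sampler returns chains contiguously per `x` (chains `0..num_chains-1` belong to the first `x`, and so on), which is what `reshape((batch_size, -1, dim_theta))` relies on. All sizes (`batch_size`, `num_chains`, `draws_per_chain`) are made up for illustration.

```python
import torch

# Illustrative sizes, not taken from the PR.
batch_size, num_chains, draws_per_chain, dim_theta = 3, 4, 7, 2
sample_shape = (5, 2)  # 10 requested samples <= 4 * 7 = 28 available per x

# Fake MCMC output of shape (num_chains_extended, draws_per_chain, dim_theta),
# assuming chains are contiguous per x: chain c belongs to x index c // num_chains.
# Encode the owning x index into every entry so the ordering is checkable.
x_idx = torch.arange(batch_size).repeat_interleave(num_chains)
samples_per_chain = (
    x_idx.view(-1, 1, 1).expand(-1, draws_per_chain, dim_theta).float()
)

# Same reordering as in the patch: group all chains of one x,
# then move the flattened sample dimension to the front.
samples_per_x = samples_per_chain.reshape((batch_size, -1, dim_theta)).permute(1, 0, -1)

# Truncate to the requested number of samples and restore sample_shape.
samples = samples_per_x[: torch.Size(sample_shape).numel()]
samples = samples.reshape((*sample_shape, batch_size, dim_theta))

# Every sample in batch slot b must originate from the chains of x index b.
for b in range(batch_size):
    assert torch.all(samples[..., b, :] == b), "samples mixed across different x!"
print(samples.shape)  # torch.Size([5, 2, 3, 2])
```

The truncation via `samples_per_x[: torch.Size(sample_shape).numel()]` is what makes sample shapes that are not divisible by `num_chains` work (the `(15,)` case in the new test): per the comment in the patch, the chains may generate more draws than requested, and the surplus is simply dropped.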