diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index 0db66b1cb..4aaf1385e 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -121,6 +121,17 @@ def sample( """See child classes for docstring.""" pass + @abstractmethod + def sample_batched( + self, + sample_shape: Shape, + x: Tensor, + max_sampling_batch_size: int = 10_000, + show_progress_bars: bool = True, + ) -> Tensor: + """See child classes for docstring.""" + pass + @property def default_x(self) -> Optional[Tensor]: """Return default x used by `.sample(), .log_prob` as conditioning context.""" diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index fb20a580e..9f90cfb45 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -16,7 +16,7 @@ reshape_to_batch_event, reshape_to_sample_batch_event, ) -from sbi.samplers.rejection.rejection import accept_reject_sample +from sbi.samplers.rejection import rejection from sbi.sbi_types import Shape from sbi.utils import check_prior, within_support from sbi.utils.torchutils import ensure_theta_batched @@ -123,7 +123,51 @@ def sample( f"`.build_posterior(sample_with={sample_with}).`" ) - samples = accept_reject_sample( + samples = rejection.accept_reject_sample( + proposal=self.posterior_estimator, + accept_reject_fn=lambda theta: within_support(self.prior, theta), + num_samples=num_samples, + show_progress_bars=show_progress_bars, + max_sampling_batch_size=max_sampling_batch_size, + proposal_sampling_kwargs={"condition": x}, + alternative_method="build_posterior(..., sample_with='mcmc')", + )[0] + + return samples[:, 0] # Remove batch dimension. + + def sample_batched( + self, + sample_shape: Shape, + x: Tensor, + max_sampling_batch_size: int = 10_000, + show_progress_bars: bool = True, + ) -> Tensor: + r"""Given a batch of observations [x_1, ..., x_B] this function samples from + posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized) + manner. + + Args: + sample_shape: Desired shape of samples that are drawn from the posterior + given every observation. + x: A batch of observations, of shape `(batch_dim, event_shape_x)`. + `batch_dim` corresponds to the number of observations to be drawn. + max_sampling_batch_size: Maximum batch size for rejection sampling. + show_progress_bars: Whether to show sampling progress monitor. + + Returns: + Samples from the posteriors of shape (*sample_shape, B, *input_shape) + """ + num_samples = torch.Size(sample_shape).numel() + condition_shape = self.posterior_estimator.condition_shape + x = reshape_to_batch_event(x, event_shape=condition_shape) + + max_sampling_batch_size = ( + self.max_sampling_batch_size + if max_sampling_batch_size is None + else max_sampling_batch_size + ) + + samples = rejection.accept_reject_sample( proposal=self.posterior_estimator, accept_reject_fn=lambda theta: within_support(self.prior, theta), num_samples=num_samples, @@ -210,6 +254,81 @@ def log_prob( return masked_log_prob - log_factor + def log_prob_batched( + self, + theta: Tensor, + x: Tensor, + norm_posterior: bool = True, + track_gradients: bool = False, + leakage_correction_params: Optional[dict] = None, + ) -> Tensor: + """Given a batch of observations [x_1, ..., x_B] and a batch of parameters \ + [$\theta_1$,..., $\theta_B$] this function evalautes the log-probabilities \ + of the posteriors $p(\theta_1|x_1)$, ..., $p(\theta_B|x_B)$ in a batched \ + (i.e. vectorized) manner. + + Args: + theta: Batch of parameters $\theta$ of shape \ + `(*sample_shape, batch_dim, *theta_shape)`. + x: Batch of observations $x$ of shape \ + `(batch_dim, *condition_shape)`. + norm_posterior: Whether to enforce a normalized posterior density. + Renormalization of the posterior is useful when some + probability falls out or leaks out of the prescribed prior support. + The normalizing factor is calculated via rejection sampling, so if you + need speedier but unnormalized log posterior estimates set here + `norm_posterior=False`. The returned log posterior is set to + -∞ outside of the prior support regardless of this setting. + track_gradients: Whether the returned tensor supports tracking gradients. + This can be helpful for e.g. sensitivity analysis, but increases memory + consumption. + leakage_correction_params: A `dict` of keyword arguments to override the + default values of `leakage_correction()`. Possible options are: + `num_rejection_samples`, `force_update`, `show_progress_bars`, and + `rejection_sampling_batch_size`. + These parameters only have an effect if `norm_posterior=True`. + + Returns: + `(len(θ), B)`-shaped log posterior probability $\\log p(\theta|x)$\\ for θ \ + in the support of the prior, -∞ (corresponding to 0 probability) outside. + """ + + theta = ensure_theta_batched(torch.as_tensor(theta)) + event_shape = self.posterior_estimator.input_shape + theta_density_estimator = reshape_to_sample_batch_event( + theta, event_shape, leading_is_sample=True + ) + x_density_estimator = reshape_to_batch_event( + x, event_shape=self.posterior_estimator.condition_shape + ) + + self.posterior_estimator.eval() + + with torch.set_grad_enabled(track_gradients): + # Evaluate on device, move back to cpu for comparison with prior. + unnorm_log_prob = self.posterior_estimator.log_prob( + theta_density_estimator, condition=x_density_estimator + ) + + # Force probability to be zero outside prior support. + in_prior_support = within_support(self.prior, theta) + + masked_log_prob = torch.where( + in_prior_support, + unnorm_log_prob, + torch.tensor(float("-inf"), dtype=torch.float32, device=self._device), + ) + + if leakage_correction_params is None: + leakage_correction_params = dict() # use defaults + log_factor = ( + log(self.leakage_correction(x=x, **leakage_correction_params)) + if norm_posterior + else 0 + ) + + return masked_log_prob - log_factor + @torch.no_grad() def leakage_correction( self, @@ -240,7 +359,7 @@ def leakage_correction( def acceptance_at(x: Tensor) -> Tensor: # [1:] to remove batch-dimension for `reshape_to_batch_event`. - return accept_reject_sample( + return rejection.accept_reject_sample( proposal=self.posterior_estimator, accept_reject_fn=lambda theta: within_support(self.prior, theta), num_samples=num_rejection_samples, diff --git a/sbi/inference/posteriors/ensemble_posterior.py b/sbi/inference/posteriors/ensemble_posterior.py index 72af02d88..b7eef8f7d 100644 --- a/sbi/inference/posteriors/ensemble_posterior.py +++ b/sbi/inference/posteriors/ensemble_posterior.py @@ -179,6 +179,29 @@ def sample( ) return torch.vstack(samples).reshape(*sample_shape, -1) + def sample_batched( + self, + sample_shape: Shape, + x: Tensor, + **kwargs, + ) -> Tensor: + num_samples = torch.Size(sample_shape).numel() + posterior_indices = torch.multinomial( + self._weights, num_samples, replacement=True + ) + samples = [] + for posterior_index, sample_size in torch.vstack( + posterior_indices.unique(return_counts=True) + ).T: + sample_shape_c = torch.Size((int(sample_size),)) + samples.append( + self.posteriors[posterior_index].sample_batched( + sample_shape_c, x=x, **kwargs + ) + ) + samples = torch.vstack(samples) + return samples.reshape(sample_shape + samples.shape[1:]) + def log_prob( self, theta: Tensor, diff --git a/sbi/inference/posteriors/importance_posterior.py b/sbi/inference/posteriors/importance_posterior.py index bbd4ce32f..b6e9721d1 100644 --- a/sbi/inference/posteriors/importance_posterior.py +++ b/sbi/inference/posteriors/importance_posterior.py @@ -194,6 +194,19 @@ def sample( else: raise NameError + def sample_batched( + self, + sample_shape: Shape, + x: Tensor, + max_sampling_batch_size: int = 10000, + show_progress_bars: bool = True, + ) -> Tensor: + raise NotImplementedError( + "Batched sampling is not implemented for ImportanceSamplingPosterior. \ + Alternatively you can use `sample` in a loop \ + [posterior.sample(theta, x_o) for x_o in x]." + ) + def _importance_sample( self, sample_shape: Shape = torch.Size(), diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index 5ef9f882a..65e9a0c25 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -352,8 +352,51 @@ def sample( raise NameError(f"The sampling method {method} is not implemented!") samples = self.theta_transform.inv(transformed_samples) + # NOTE: Currently MCMCPosteriors will require a single dimension for the + # parameter dimension. With recent ConditionalDensity(Ratio) estimators, we + # can have multiple dimensions for the parameter dimension. + samples = samples.reshape((*sample_shape, -1)) # type: ignore - return samples.reshape((*sample_shape, -1)) # type: ignore + return samples + + def sample_batched( + self, + sample_shape: Shape, + x: Tensor, + method: Optional[str] = None, + thin: Optional[int] = None, + warmup_steps: Optional[int] = None, + num_chains: Optional[int] = None, + init_strategy: Optional[str] = None, + init_strategy_parameters: Optional[Dict[str, Any]] = None, + num_workers: Optional[int] = None, + mp_context: Optional[str] = None, + show_progress_bars: bool = True, + ) -> Tensor: + r"""Given a batch of observations [x_1, ..., x_B] this function samples from + posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized) + manner. + + Check the `__init__()` method for a description of all arguments as well as + their default values. + + Args: + sample_shape: Desired shape of samples that are drawn from the posterior + given every observation. + x: A batch of observations, of shape `(batch_dim, event_shape_x)`. + `batch_dim` corresponds to the number of observations to be drawn. + show_progress_bars: Whether to show sampling progress monitor. + + Returns: + Samples from the posteriors of shape (*sample_shape, B, *input_shape) + """ + + # See #1176 for a discussion on the implementation of batched sampling. + raise NotImplementedError( + "Batched sampling is not implemented for MCMC posterior. \ + Alternatively you can use `sample` in a loop \ + [posterior.sample(theta, x_o) for x_o in x]." + ) def _build_mcmc_init_fn( self, diff --git a/sbi/inference/posteriors/rejection_posterior.py b/sbi/inference/posteriors/rejection_posterior.py index 6da838059..549942ce2 100644 --- a/sbi/inference/posteriors/rejection_posterior.py +++ b/sbi/inference/posteriors/rejection_posterior.py @@ -167,6 +167,19 @@ def sample( return samples.reshape((*sample_shape, -1)) + def sample_batched( + self, + sample_shape: Shape, + x: Tensor, + max_sampling_batch_size: int = 10000, + show_progress_bars: bool = True, + ) -> Tensor: + raise NotImplementedError( + "Batched sampling is not implemented for RejectionPosterior. \ + Alternatively you can use `sample` in a loop \ + [posterior.sample(theta, x_o) for x_o in x]." + ) + def map( self, x: Optional[Tensor] = None, diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py index 006ab543a..f75ac0a4b 100644 --- a/sbi/inference/posteriors/vi_posterior.py +++ b/sbi/inference/posteriors/vi_posterior.py @@ -296,6 +296,19 @@ def sample( samples = self.q.sample(torch.Size(sample_shape)) return samples.reshape((*sample_shape, samples.shape[-1])) + def sample_batched( + self, + sample_shape: Shape, + x: Tensor, + max_sampling_batch_size: int = 10000, + show_progress_bars: bool = True, + ) -> Tensor: + raise NotImplementedError( + "Batched sampling is not implemented for VIPosterior. \ + Alternatively you can use `sample` in a loop \ + [posterior.sample(theta, x_o) for x_o in x]." + ) + def log_prob( self, theta: Tensor, diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py index dd3774e6f..1295ff287 100644 --- a/sbi/inference/snpe/snpe_a.py +++ b/sbi/inference/snpe/snpe_a.py @@ -474,7 +474,8 @@ def sample(self, sample_shape: torch.Size, condition: Tensor, **kwargs) -> Tenso condition = condition.to(self._device) if not self._apply_correction: - return self._neural_net.sample(sample_shape, condition=condition) + samples = self._neural_net.sample(sample_shape, condition=condition) + return samples else: # When we want to sample from the approx. posterior, a proposal prior # \tilde{p} has already been observed. To analytically calculate the @@ -483,7 +484,12 @@ def sample(self, sample_shape: torch.Size, condition: Tensor, **kwargs) -> Tenso condition_ndim = len(self.condition_shape) batch_size = condition.shape[:-condition_ndim] batch_size = torch.Size(batch_size).numel() - return self._sample_approx_posterior_mog(num_samples, condition, batch_size) + samples = self._sample_approx_posterior_mog( + num_samples, condition, batch_size + ) + # NOTE: New batching convention: (batch_dim, sample_dim, *event_shape) + samples = samples.transpose(0, 1) + return samples def _sample_approx_posterior_mog( self, num_samples, x: Tensor, batch_size: int diff --git a/sbi/neural_nets/density_estimators/nflows_flow.py b/sbi/neural_nets/density_estimators/nflows_flow.py index 8d6aaba55..42aee9d47 100644 --- a/sbi/neural_nets/density_estimators/nflows_flow.py +++ b/sbi/neural_nets/density_estimators/nflows_flow.py @@ -135,6 +135,8 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor: num_samples = torch.Size(sample_shape).numel() samples = self.net.sample(num_samples, context=condition) + # Change from Nflows' convention of (batch_dim, sample_dim, *event_shape) to + # (sample_dim, batch_dim, *event_shape) (PyTorch + SBI). samples = samples.transpose(0, 1) return samples.reshape((*sample_shape, condition_batch_dim, *self.input_shape)) diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 5c78cd8dd..5c1d5ffc5 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -236,7 +236,8 @@ def accept_reject_sample( `rejection_sample()`. Warn if not empty. Returns: - Accepted samples and acceptance rate as scalar Tensor. + Accepted samples of shape `(sample_dim, batch_dim, *event_shape)`, and + acceptance rates for each observation. """ if kwargs: @@ -253,39 +254,61 @@ def accept_reject_sample( total=num_samples, desc=f"Drawing {num_samples} posterior samples", ) + if proposal_sampling_kwargs is None: + proposal_sampling_kwargs = {} - num_sampled_total, num_remaining = 0, num_samples - accepted, acceptance_rate = [], float("Nan") + num_remaining = num_samples + + # NOTE: We might want to change this to a more general approach in the future. + # Currently limited to a single "batch_dim" for the condition. + # But this would require giving the method the condition_shape explicitly... + if "condition" in proposal_sampling_kwargs: + num_xos = proposal_sampling_kwargs["condition"].shape[0] + else: + num_xos = 1 + + accepted = [[] for _ in range(num_xos)] + acceptance_rate = torch.full((num_xos,), float("Nan")) leakage_warning_raised = False # Ruff suggestion - if proposal_sampling_kwargs is None: - proposal_sampling_kwargs = {} # To cover cases with few samples without leakage: sampling_batch_size = min(num_samples, max_sampling_batch_size) + num_sampled_total = torch.zeros(num_xos) + num_samples_possible = 0 while num_remaining > 0: # Sample and reject. candidates = proposal.sample( (sampling_batch_size,), # type: ignore **proposal_sampling_kwargs, ) - # SNPE-style rejection-sampling when the proposal is the neural net. are_accepted = accept_reject_fn(candidates) - samples = candidates[are_accepted] - accepted.append(samples) + # Reshape necessary in certain cases which do not follow the shape conventions + # of the "DensityEstimator" class. + are_accepted = are_accepted.reshape(sampling_batch_size, num_xos) + candidates_to_reject = candidates.reshape( + sampling_batch_size, num_xos, *candidates.shape[candidates.ndim - 1 :] + ) + + for i in range(num_xos): + accepted[i].append(candidates_to_reject[are_accepted[:, i], i]) # Update. # Note: For any condition of shape (*batch_shape, *condition_shape), the - # samples will be of shape(*batch_shape, sampling_batch_size, d) and hence work - # in dim = -2. - num_sampled_total += sampling_batch_size - num_remaining -= samples.shape[-2] - pbar.update(samples.shape[-2]) + # samples will be of shape(sampling_batch_size,*batch_shape, *event_shape) + # and hence work in dim = 0. + num_accepted = are_accepted.sum(dim=0) + num_sampled_total += num_accepted.to(num_sampled_total.device) + num_samples_possible += sampling_batch_size + min_num_accepted = num_accepted.min().item() + num_remaining -= min_num_accepted + pbar.update(min_num_accepted) # To avoid endless sampling when leakage is high, we raise a warning if the # acceptance rate is too low after the first 1_000 samples. - acceptance_rate = (num_samples - num_remaining) / num_sampled_total + acceptance_rate = num_sampled_total / num_samples_possible + min_acceptance_rate = acceptance_rate.min().item() # For remaining iterations (leakage or many samples) continue # sampling with fixed batch size, reduced in cased the number @@ -293,20 +316,21 @@ def accept_reject_sample( # by zero if acceptance rate is zero. sampling_batch_size = min( max_sampling_batch_size, - max(int(1.5 * num_remaining / max(acceptance_rate, 1e-12)), 100), + max(int(1.5 * num_remaining / max(min_acceptance_rate, 1e-12)), 100), ) if ( - num_sampled_total > 1000 - and acceptance_rate < warn_acceptance + num_sampled_total.min().item() > 1000 + and min_acceptance_rate < warn_acceptance and not leakage_warning_raised ): if sample_for_correction_factor: + idx_min = acceptance_rate.argmin().item() logging.warning( f"""Drawing samples from posterior to estimate the normalizing constant for `log_prob()`. However, only - {acceptance_rate:.3%} posterior samples are within the - prior support. It may take a long time to collect the - remaining {num_remaining} samples. + {min_acceptance_rate:.3%} posterior samples are within the + prior support (for condition {idx_min}). It may take a long time + to collect the remaining {num_remaining} samples. Consider interrupting (Ctrl-C) and either basing the estimate of the normalizing constant on fewer samples (by calling `posterior.leakage_correction(x_o, @@ -318,7 +342,7 @@ def accept_reject_sample( result in an unnormalized `log_prob()`.""" ) else: - warn_msg = f"""Only {acceptance_rate:.3%} proposal samples are + warn_msg = f"""Only {min_acceptance_rate:.3%} proposal samples are accepted. It may take a long time to collect the remaining {num_remaining} samples. """ if alternative_method is not None: @@ -331,9 +355,13 @@ def accept_reject_sample( pbar.close() # When in case of leakage a batch size was used there could be too many samples. - samples = torch.cat(accepted, dim=-2)[..., :num_samples, :] + samples = [torch.cat(accepted[i], dim=0)[:num_samples] for i in range(num_xos)] + samples = torch.stack(samples, dim=1) + samples = samples.reshape(num_samples, *candidates.shape[1:]) assert ( - samples.shape[-2] == num_samples + samples.shape[0] == num_samples ), "Number of accepted samples must match required samples." - return samples, as_tensor(acceptance_rate) + # NOTE: Restriction prior does currently require a float as return for the + # acceptance rate, which is why we for now also return the minimum acceptance rate. + return samples, as_tensor(min_acceptance_rate) diff --git a/tests/density_estimator_test.py b/tests/density_estimator_test.py index 35fb0d946..4fd0cd794 100644 --- a/tests/density_estimator_test.py +++ b/tests/density_estimator_test.py @@ -299,14 +299,21 @@ def test_correctness_of_density_estimator_log_prob( build_zuko_nsf, build_zuko_sospf, build_zuko_unaf, - build_categoricalmassestimator, - build_mnle, + pytest.param( + build_categoricalmassestimator, + marks=pytest.mark.xfail(reason='see issue #1172'), + ), + pytest.param(build_mnle, marks=pytest.mark.xfail(reason='see issue #1172')), ), ) @pytest.mark.parametrize("input_event_shape", ((1,), (4,))) @pytest.mark.parametrize("condition_event_shape", ((1,), (7,))) +@pytest.mark.parametrize("sample_shape", ((1000,), (500, 2))) def test_correctness_of_batched_vs_seperate_sample_and_log_prob( - density_estimator_build_fn, input_event_shape, condition_event_shape + density_estimator_build_fn, + input_event_shape, + condition_event_shape, + sample_shape, ): input_sample_dim = 2 batch_dim = 2 @@ -318,7 +325,9 @@ def test_correctness_of_batched_vs_seperate_sample_and_log_prob( input_sample_dim, ) # Batched vs separate sampling - samples = density_estimator.sample((1000,), condition=condition) + samples = density_estimator.sample(sample_shape, condition=condition) + samples = samples.reshape(-1, batch_dim, *input_event_shape) # Flat for comp. + samples_separate1 = density_estimator.sample( (1000,), condition=condition[0][None, ...] ) diff --git a/tests/ensemble_test.py b/tests/ensemble_test.py index 8ec819eb3..bdd9fd5e0 100644 --- a/tests/ensemble_test.py +++ b/tests/ensemble_test.py @@ -139,3 +139,13 @@ def simulator(theta): # test individual log_prob and map posterior.log_prob(samples, individually=True) + + # Test sample_batched + x_o_batch_dim = 2 + if isinstance(inferer, (SNLE_A, SNRE_A)): + # TODO: Implement batched sampling for MCMC methods + return + else: + samples = posterior.sample_batched((10,), ones(x_o_batch_dim, num_dim)) + + assert samples.shape == (10, x_o_batch_dim, num_dim), "Sample shape wrong" diff --git a/tests/posterior_nn_test.py b/tests/posterior_nn_test.py index 33f4c29e4..f634c1fd2 100644 --- a/tests/posterior_nn_test.py +++ b/tests/posterior_nn_test.py @@ -8,8 +8,12 @@ from torch.distributions import MultivariateNormal from sbi.inference import ( + SNLE_A, SNPE_A, SNPE_C, + SNRE_A, + SNRE_B, + SNRE_C, DirectPosterior, simulate_for_sbi, ) @@ -49,3 +53,70 @@ def test_log_prob_with_different_x(snpe_method: type, x_o_batch_dim: bool): ).set_default_x(x_o) samples = posterior.sample((10,)) _ = posterior.log_prob(samples) + + +@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) +@pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2)) +def test_batched_sample_log_prob_with_different_x( + snpe_method: type, x_o_batch_dim: bool +): + num_dim = 2 + + prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim)) + simulator = diagonal_linear_gaussian + + inference = snpe_method(prior=prior) + theta, x = simulate_for_sbi(simulator, prior, 1000) + posterior_estimator = inference.append_simulations(theta, x).train(max_num_epochs=3) + + x_o = ones(num_dim) if x_o_batch_dim == 0 else ones(x_o_batch_dim, num_dim) + + posterior = DirectPosterior(posterior_estimator=posterior_estimator, prior=prior) + + samples = posterior.sample_batched((10,), x_o) + batched_log_probs = posterior.log_prob_batched(samples, x_o) + + assert ( + samples.shape == (10, x_o_batch_dim, num_dim) + if x_o_batch_dim > 0 + else (10, num_dim) + ), "Sample shape wrong" + assert batched_log_probs.shape == (10, max(x_o_batch_dim, 1)), "logprob shape wrong" + + +@pytest.mark.mcmc +@pytest.mark.parametrize( + "snlre_method", + [ + pytest.param(SNLE_A, marks=pytest.mark.xfail(raises=NotImplementedError)), + pytest.param(SNRE_A, marks=pytest.mark.xfail(raises=NotImplementedError)), + pytest.param(SNRE_B, marks=pytest.mark.xfail(raises=NotImplementedError)), + pytest.param(SNRE_C, marks=pytest.mark.xfail(raises=NotImplementedError)), + ], +) +@pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2)) +def test_batched_mcmc_sample_log_prob_with_different_x( + snlre_method: type, x_o_batch_dim: bool, mcmc_params_fast: dict +): + num_dim = 2 + + prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim)) + simulator = diagonal_linear_gaussian + + inference = snlre_method(prior=prior) + theta, x = simulate_for_sbi(simulator, prior, 1000) + _ = inference.append_simulations(theta, x).train(max_num_epochs=3) + + x_o = ones(num_dim) if x_o_batch_dim == 0 else ones(x_o_batch_dim, num_dim) + + posterior = inference.build_posterior( + mcmc_method="slice_np_vectorized", mcmc_parameters=mcmc_params_fast + ) + + samples = posterior.sample_batched((10,), x_o) + + assert ( + samples.shape == (10, x_o_batch_dim, num_dim) + if x_o_batch_dim > 0 + else (10, num_dim) + ), "Sample shape wrong" diff --git a/tests/test_utils.py b/tests/test_utils.py index b6730e65b..a1cea1e07 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -246,6 +246,17 @@ def sample( return self.potential_fn.posterior.sample(sample_shape) + def sample_batched( + self, + sample_shape: Shape, + x: Tensor, + max_sampling_batch_size: int = 10000, + show_progress_bars: bool = True, + ) -> Tensor: + raise NotImplementedError( + "Batched sampling is not supported for TractablePosterior." + ) + def log_prob( self, theta: Tensor,