From 17c534303343bd6306ea8e45fd4085a929ba42c2 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Mon, 29 Apr 2024 09:04:20 +0200 Subject: [PATCH 01/71] Base estimator class --- sbi/neural_nets/density_estimators/base.py | 136 ++++++++++++++++----- 1 file changed, 105 insertions(+), 31 deletions(-) diff --git a/sbi/neural_nets/density_estimators/base.py b/sbi/neural_nets/density_estimators/base.py index 252c850bc..4967b6df6 100644 --- a/sbi/neural_nets/density_estimators/base.py +++ b/sbi/neural_nets/density_estimators/base.py @@ -1,10 +1,113 @@ +from abc import ABC, abstractmethod from typing import Optional, Tuple import torch from torch import Tensor, nn -class DensityEstimator(nn.Module): +class Estimator(nn.Module, ABC): + r"""Base class for estimators i.e. neural nets estimating a certain quantity that + characterizes a distribution. This for example can be: + - Conditional density estimator of the posterior $p(\theta|x)$. + - Conditional density estimator of the likelihood $p(x|\theta)$. + - Estimator of the density ratio $p(x|\theta)/p(x)$. + - and more ... + """ + + def __init__(self, input_shape: torch.Size, condition_shape: torch.Size) -> None: + r"""Base class for estimators. + + Args: + input_shape: Event shape of the input at which the density is being + evaluated (and which is also the event_shape of samples). + condition_shape: Shape of the condition. If not provided, it will assume a + 1D input. + """ + super().__init__() + self._input_shape = torch.Size(input_shape) + self._condition_shape = torch.Size(condition_shape) + + @property + def input_shape(self) -> torch.Size: + r"""Return the input shape.""" + return self._input_shape + + @property + def condition_shape(self) -> torch.Size: + r"""Return the condition shape.""" + return self._condition_shape + + @abstractmethod + def loss(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: + r"""Return the loss for training the estimator. + + Args: + input: Inputs to evaluate the loss on of shape + `(batch_dim, *input_event_shape)`. + condition: Conditions of shape `(batch_dim, *event_shape_condition)`. + + Returns: + Loss of shape (batch_dim,) + """ + pass + + def _check_condition_shape(self, condition: Tensor): + r"""This method checks whether the condition has the correct shape. + + Args: + condition: Conditions of shape `(batch_dim, *event_shape_condition)`. + + Raises: + ValueError: If the condition has a dimensionality that does not match + the expected input dimensionality. + ValueError: If the shape of the condition does not match the expected + input dimensionality. + """ + if len(condition.shape) < len(self.condition_shape): + raise ValueError( + f"Dimensionality of condition is to small and does not match the\ + expected input dimensionality {len(self.condition_shape)}, as provided\ + by condition_shape." + ) + else: + condition_shape = condition.shape[-len(self.condition_shape) :] + if tuple(condition_shape) != tuple(self.condition_shape): + raise ValueError( + f"Shape of condition {tuple(condition_shape)} does not match the \ + expected input dimensionality {tuple(self.condition_shape)}, as \ + provided by condition_shape. Please reshape it accordingly." + ) + + def _check_input_shape(self, input: Tensor): + r"""This method checks whether the input has the correct shape. + + Args: + input: Inputs to evaluate the log probability on of shape + `(sample_dim_input, batch_dim_input, *event_shape_input)`. + + Raises: + ValueError: If the input has a dimensionality that does not match + the expected input dimensionality. + ValueError: If the shape of the input does not match the expected + input dimensionality. + """ + if len(input.shape) < len(self.input_shape): + raise ValueError( + f"Dimensionality of input is to small and does not match the expected \ + input dimensionality {len(self.input_shape)}, as provided by \ + input_shape." + ) + else: + input_shape = input.shape[-len(self.input_shape) :] + if tuple(input_shape) != tuple(self.input_shape): + raise ValueError( + f"Shape of input {tuple(input_shape)} does not match the expected \ + input dimensionality {tuple(self.input_shape)}, as provided by \ + input_shape. Please reshape it accordingly." + ) + + +class DensityEstimator(Estimator): r"""Base class for density estimators. The density estimator class is a wrapper around neural networks that @@ -31,10 +134,8 @@ def __init__( condition_shape: Shape of the condition. If not provided, it will assume a 1D input. """ - super().__init__() + super().__init__(input_shape, condition_shape) self.net = net - self.input_shape = input_shape - self.condition_shape = condition_shape @property def embedding_net(self) -> Optional[nn.Module]: @@ -108,30 +209,3 @@ def sample_and_log_prob( samples = self.sample(sample_shape, condition, **kwargs) log_probs = self.log_prob(samples, condition, **kwargs) return samples, log_probs - - def _check_condition_shape(self, condition: Tensor): - r"""This method checks whether the condition has the correct shape. - - Args: - condition: Conditions of shape `(batch_dim, *event_shape_condition)`. - - Raises: - ValueError: If the condition has a dimensionality that does not match - the expected input dimensionality. - ValueError: If the shape of the condition does not match the expected - input dimensionality. - """ - if len(condition.shape) < len(self.condition_shape): - raise ValueError( - f"Dimensionality of condition is to small and does not match the\ - expected input dimensionality {len(self.condition_shape)}, as provided\ - by condition_shape." - ) - else: - condition_shape = condition.shape[-len(self.condition_shape) :] - if tuple(condition_shape) != tuple(self.condition_shape): - raise ValueError( - f"Shape of condition {tuple(condition_shape)} does not match the \ - expected input dimensionality {tuple(self.condition_shape)}, as \ - provided by condition_shape. Please reshape it accordingly." - ) From 705e9df84cf4b75ff8b0e92c232592c958163cce Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Fri, 3 May 2024 16:14:43 +0200 Subject: [PATCH 02/71] intermediate commit --- sbi/inference/posteriors/base_posterior.py | 12 +++++ sbi/inference/posteriors/direct_posterior.py | 51 ++++++++++++++++++++ sbi/samplers/rejection/rejection.py | 19 ++++++-- 3 files changed, 78 insertions(+), 4 deletions(-) diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index 0db66b1cb..f5b9cf62e 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -121,6 +121,18 @@ def sample( """See child classes for docstring.""" pass + @abstractmethod + def amortized_sample( + self, + sample_shape: Shape = torch.Size(), + x: Optional[Tensor] = None, + show_progress_bars: bool = True, + mcmc_method: Optional[str] = None, + mcmc_parameters: Optional[Dict[str, Any]] = None, + ) -> Tensor: + """See child classes for docstring.""" + pass + @property def default_x(self) -> Optional[Tensor]: """Return default x used by `.sample(), .log_prob` as conditioning context.""" diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index fb20a580e..388ed53d0 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -135,6 +135,57 @@ def sample( return samples + def amortized_sample( + self, + sample_shape: Shape = torch.Size(), + x: Optional[Tensor] = None, + max_sampling_batch_size: int = 10_000, + sample_with: Optional[str] = None, + show_progress_bars: bool = True, + ) -> Tensor: + r"""Return samples from posterior $p(\theta|x)$ given multiple observations. + + Args: + sample_shape: Desired shape of samples that are drawn from the posterior + given every observation. + x: A batch of observations, of shape `(batch_dim, event_shape_x)`. + `batch_dim` corresponds to the number of observations to be drawn. + sample_with: This argument only exists to keep backward-compatibility with + `sbi` v0.17.2 or older. If it is set, we instantly raise an error. + show_progress_bars: Whether to show sampling progress monitor. + """ + + num_samples = torch.Size(sample_shape).numel() + # x = self._x_else_default_x(x) + x = reshape_to_batch_event( + x, event_shape=self.posterior_estimator.condition_shape + ) + + max_sampling_batch_size = ( + self.max_sampling_batch_size + if max_sampling_batch_size is None + else max_sampling_batch_size + ) + + if sample_with is not None: + raise ValueError( + f"You set `sample_with={sample_with}`. As of sbi v0.18.0, setting " + f"`sample_with` is no longer supported. You have to rerun " + f"`.build_posterior(sample_with={sample_with}).`" + ) + + samples = accept_reject_sample( + proposal=self.posterior_estimator, + accept_reject_fn=lambda theta: within_support(self.prior, theta), + num_samples=num_samples, + show_progress_bars=show_progress_bars, + max_sampling_batch_size=max_sampling_batch_size, + proposal_sampling_kwargs={"condition": x}, + alternative_method="build_posterior(..., sample_with='mcmc')", + )[0] + + return samples + def log_prob( self, theta: Tensor, diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 5c78cd8dd..791ffcd98 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -253,8 +253,11 @@ def accept_reject_sample( total=num_samples, desc=f"Drawing {num_samples} posterior samples", ) + print("Bullas;fjkdsafj;dlsjfldsaj") num_sampled_total, num_remaining = 0, num_samples + num_xo = proposal_sampling_kwargs["condition"].shape[0] + accepted_every_obs = [torch.tensor((0, 2)) for _ in range(num_xo)] accepted, acceptance_rate = [], float("Nan") leakage_warning_raised = False # Ruff suggestion @@ -272,15 +275,23 @@ def accept_reject_sample( # SNPE-style rejection-sampling when the proposal is the neural net. are_accepted = accept_reject_fn(candidates) - samples = candidates[are_accepted] - accepted.append(samples) + print("are_accepted", are_accepted.shape) + for obs_index in range(num_xo): + accepted = candidates[are_accepted[:, obs_index], obs_index] + print("accepted", accepted.shape) + print("accepted_every_obs[obs_index]", accepted_every_obs[obs_index].shape) + here = accepted_every_obs[obs_index] + print("here", here.shape) + print("acc", accepted.shape) + accepted_every_obs[obs_index] = torch.cat([accepted_every_obs[obs_index], accepted], dim=0) + lowest_num_accepted = min(len(s) for s in accepted_every_obs) + num_remaining = num_samples - lowest_num_accepted # Update. # Note: For any condition of shape (*batch_shape, *condition_shape), the # samples will be of shape(*batch_shape, sampling_batch_size, d) and hence work # in dim = -2. num_sampled_total += sampling_batch_size - num_remaining -= samples.shape[-2] pbar.update(samples.shape[-2]) # To avoid endless sampling when leakage is high, we raise a warning if the @@ -331,7 +342,7 @@ def accept_reject_sample( pbar.close() # When in case of leakage a batch size was used there could be too many samples. - samples = torch.cat(accepted, dim=-2)[..., :num_samples, :] + samples = samples[..., :num_samples, :] assert ( samples.shape[-2] == num_samples ), "Number of accepted samples must match required samples." From 07b53cdb29afe00504f5a90b9adf678723551e24 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Fri, 3 May 2024 17:05:01 +0200 Subject: [PATCH 03/71] make autoreload work --- sbi/inference/posteriors/direct_posterior.py | 4 ++-- sbi/samplers/rejection/rejection.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index 388ed53d0..5928c52b7 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -16,7 +16,7 @@ reshape_to_batch_event, reshape_to_sample_batch_event, ) -from sbi.samplers.rejection.rejection import accept_reject_sample +from sbi.samplers.rejection import rejection from sbi.sbi_types import Shape from sbi.utils import check_prior, within_support from sbi.utils.torchutils import ensure_theta_batched @@ -123,7 +123,7 @@ def sample( f"`.build_posterior(sample_with={sample_with}).`" ) - samples = accept_reject_sample( + samples = rejection.accept_reject_sample( proposal=self.posterior_estimator, accept_reject_fn=lambda theta: within_support(self.prior, theta), num_samples=num_samples, diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 791ffcd98..149446844 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -253,7 +253,6 @@ def accept_reject_sample( total=num_samples, desc=f"Drawing {num_samples} posterior samples", ) - print("Bullas;fjkdsafj;dlsjfldsaj") num_sampled_total, num_remaining = 0, num_samples num_xo = proposal_sampling_kwargs["condition"].shape[0] From dd02e227ad555bdbbc997a30ff8f03e58e9f5f4f Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Sun, 5 May 2024 17:09:56 +0200 Subject: [PATCH 04/71] `amortized_sample` works for MCMCPosterior --- sbi/inference/posteriors/direct_posterior.py | 4 +- sbi/inference/posteriors/mcmc_posterior.py | 102 +++++++++++++++++- .../potentials/likelihood_based_potential.py | 5 + sbi/samplers/rejection/rejection.py | 28 ++--- 4 files changed, 120 insertions(+), 19 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index 5928c52b7..ed131e218 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -133,7 +133,7 @@ def sample( alternative_method="build_posterior(..., sample_with='mcmc')", )[0] - return samples + return samples[:, 0] # Remove batch dimension. def amortized_sample( self, @@ -174,7 +174,7 @@ def amortized_sample( f"`.build_posterior(sample_with={sample_with}).`" ) - samples = accept_reject_sample( + samples = rejection.accept_reject_sample( proposal=self.posterior_estimator, accept_reject_fn=lambda theta: within_support(self.prior, theta), num_samples=num_samples, diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index 5ef9f882a..c47db3e7b 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -17,6 +17,7 @@ from torch import Tensor from torch import multiprocessing as mp from tqdm.auto import tqdm +import numpy as np from sbi.inference.posteriors.base_posterior import NeuralPosterior from sbi.inference.potentials.base_potential import BasePotential @@ -355,6 +356,83 @@ def sample( return samples.reshape((*sample_shape, -1)) # type: ignore + + def amortized_sample( + self, + sample_shape: Shape = torch.Size(), + x: Optional[Tensor] = None, + method: Optional[str] = None, + thin: Optional[int] = None, + warmup_steps: Optional[int] = None, + num_chains: Optional[int] = None, + init_strategy: Optional[str] = None, + init_strategy_parameters: Optional[Dict[str, Any]] = None, + num_workers: Optional[int] = None, + mp_context: Optional[str] = None, + show_progress_bars: bool = True, + ) -> Tensor: + r"""Return samples from posterior distribution $p(\theta|x)$ with MCMC. + + Check the `__init__()` method for a description of all arguments as well as + their default values. + + Args: + sample_shape: Desired shape of samples that are drawn from posterior. If + sample_shape is multidimensional we simply draw `sample_shape.numel()` + samples and then reshape into the desired shape. + show_progress_bars: Whether to show sampling progress monitor. + + Returns: + Samples from posterior. + """ + self.potential_fn.set_x(self._x_else_default_x(x)) + + # Replace arguments that were not passed with their default. + method = self.method if method is None else method + thin = self.thin if thin is None else thin + warmup_steps = self.warmup_steps if warmup_steps is None else warmup_steps + num_chains = self.num_chains if num_chains is None else num_chains + init_strategy = self.init_strategy if init_strategy is None else init_strategy + num_workers = self.num_workers if num_workers is None else num_workers + mp_context = self.mp_context if mp_context is None else mp_context + init_strategy_parameters = ( + self.init_strategy_parameters + if init_strategy_parameters is None + else init_strategy_parameters + ) + self.potential_ = self._prepare_potential(method) # type: ignore + + print("Getting initial params") + initial_params = self._get_initial_params( + init_strategy, # type: ignore + num_chains, # type: ignore + num_workers, + show_progress_bars, + **init_strategy_parameters, + ) + print("Finished init") + num_samples = torch.Size(sample_shape).numel() + + assert method == "slice_np_vectorized" + with torch.set_grad_enabled(False): + transformed_samples = self._slice_np_mcmc( + num_samples=num_samples, + potential_function=self.potential_, + initial_params=initial_params, + thin=thin, # type: ignore + warmup_steps=warmup_steps, # type: ignore + vectorized=(method == "slice_np_vectorized"), + num_workers=num_workers, + show_progress_bars=show_progress_bars, + ) + print("transformed_samples", transformed_samples.shape) + samples = self.theta_transform.inv(transformed_samples) + print("samples", samples.shape) + num_obs = 5 + + return samples.reshape((*sample_shape, num_obs, -1)) # type: ignore + + def _build_mcmc_init_fn( self, proposal: Any, @@ -507,10 +585,24 @@ def _slice_np_mcmc( else: SliceSamplerMultiChain = SliceSamplerVectorized + def multi_obs_potential(params): + # Params are of shape (num_chains * num_obs, event). + # We now reshape them to (num_chains, num_obs, event). + # params = np.reshape(params, (num_chains, num_obs, -1)) + # print("params", params.shape) + # print("potential_function", potential_function) + + # `all_potentials` is of shape (num_chains, num_obs). + all_potentials = potential_function(params) + return all_potentials.flatten() + + num_obs = 5 + initial_params = torch.concatenate([initial_params] * num_obs) + posterior_sampler = SliceSamplerMultiChain( init_params=tensor2numpy(initial_params), - log_prob_fn=potential_function, - num_chains=num_chains, + log_prob_fn=multi_obs_potential, + num_chains=num_chains * num_obs, thin=thin, verbose=show_progress_bars, num_workers=num_workers, @@ -519,7 +611,9 @@ def _slice_np_mcmc( warmup_ = warmup_steps * thin num_samples_ = ceil((num_samples * thin) / num_chains) # Run mcmc including warmup + print("Start run") samples = posterior_sampler.run(warmup_ + num_samples_) + print("Finish run") samples = samples[:, warmup_steps:, :] # discard warmup steps samples = torch.from_numpy(samples) # chains x samples x dim @@ -527,10 +621,10 @@ def _slice_np_mcmc( self._posterior_sampler = posterior_sampler # Save sample as potential next init (if init_strategy == 'latest_sample'). - self._mcmc_init_params = samples[:, -1, :].reshape(num_chains, dim_samples) + self._mcmc_init_params = samples[:, -1, :].reshape(num_chains, num_obs, dim_samples) # Collect samples from all chains. - samples = samples.reshape(-1, dim_samples)[:num_samples] + samples = samples.reshape(-1, num_obs, dim_samples)[:num_samples] return samples.type(torch.float32).to(self._device) diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py index eab36e91f..a9e94e7e2 100644 --- a/sbi/inference/potentials/likelihood_based_potential.py +++ b/sbi/inference/potentials/likelihood_based_potential.py @@ -123,6 +123,8 @@ def _log_likelihoods_over_trials( log_likelihood_trial_sum: log likelihood for each parameter, summed over all batch entries (iid trials) in `x`. """ + # print("x", x.shape) + # print("theta", theta.shape) # Shape of `x` is (iid_dim, *event_shape). x = reshape_to_sample_batch_event( x, event_shape=x.shape[1:], leading_is_sample=True @@ -146,6 +148,9 @@ def _log_likelihoods_over_trials( # `DensityEstimator.log_prob`. theta = reshape_to_batch_event(theta, event_shape=theta.shape[1:]) + # print("After reshape theta: ", theta.shape) + # print("After reshape x: ", x.shape) + # Calculate likelihood in one batch. with torch.set_grad_enabled(track_gradients): log_likelihood_trial_batch = estimator.log_prob(x, condition=theta) diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 149446844..ed9625b97 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -236,7 +236,8 @@ def accept_reject_sample( `rejection_sample()`. Warn if not empty. Returns: - Accepted samples and acceptance rate as scalar Tensor. + Accepted samples of shape `(sample_dim, batch_dim, *event_shape)`, and + worst-case acceptance rate as scalar Tensor. """ if kwargs: @@ -255,8 +256,12 @@ def accept_reject_sample( ) num_sampled_total, num_remaining = 0, num_samples - num_xo = proposal_sampling_kwargs["condition"].shape[0] - accepted_every_obs = [torch.tensor((0, 2)) for _ in range(num_xo)] + if "condition" in list(proposal_sampling_kwargs.keys()): + num_xo = proposal_sampling_kwargs["condition"].shape[0] + else: + num_xo = 1 + + accepted_every_obs = [torch.tensor(()) for _ in range(num_xo)] accepted, acceptance_rate = [], float("Nan") leakage_warning_raised = False # Ruff suggestion @@ -274,14 +279,8 @@ def accept_reject_sample( # SNPE-style rejection-sampling when the proposal is the neural net. are_accepted = accept_reject_fn(candidates) - print("are_accepted", are_accepted.shape) for obs_index in range(num_xo): accepted = candidates[are_accepted[:, obs_index], obs_index] - print("accepted", accepted.shape) - print("accepted_every_obs[obs_index]", accepted_every_obs[obs_index].shape) - here = accepted_every_obs[obs_index] - print("here", here.shape) - print("acc", accepted.shape) accepted_every_obs[obs_index] = torch.cat([accepted_every_obs[obs_index], accepted], dim=0) lowest_num_accepted = min(len(s) for s in accepted_every_obs) num_remaining = num_samples - lowest_num_accepted @@ -291,7 +290,7 @@ def accept_reject_sample( # samples will be of shape(*batch_shape, sampling_batch_size, d) and hence work # in dim = -2. num_sampled_total += sampling_batch_size - pbar.update(samples.shape[-2]) + pbar.update(num_samples - num_remaining) # To avoid endless sampling when leakage is high, we raise a warning if the # acceptance rate is too low after the first 1_000 samples. @@ -340,10 +339,13 @@ def accept_reject_sample( pbar.close() + for obs_index in range(num_xo): + accepted_every_obs[obs_index] = accepted_every_obs[obs_index][:num_samples] + + accepted_every_obs = torch.stack(accepted_every_obs) # When in case of leakage a batch size was used there could be too many samples. - samples = samples[..., :num_samples, :] assert ( - samples.shape[-2] == num_samples + accepted_every_obs.shape[-2] == num_samples ), "Number of accepted samples must match required samples." - return samples, as_tensor(acceptance_rate) + return torch.permute(accepted_every_obs, (1, 0, -1)), as_tensor(acceptance_rate) From 663185b4a59a386545dcff618b37c2216b2fff35 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 7 May 2024 08:27:39 +0200 Subject: [PATCH 05/71] fixes current bug! --- sbi/neural_nets/density_estimators/nflows_flow.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sbi/neural_nets/density_estimators/nflows_flow.py b/sbi/neural_nets/density_estimators/nflows_flow.py index a1de5355c..3f162493f 100644 --- a/sbi/neural_nets/density_estimators/nflows_flow.py +++ b/sbi/neural_nets/density_estimators/nflows_flow.py @@ -135,12 +135,14 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor: num_samples = torch.Size(sample_shape).numel() samples = self.net.sample(num_samples, context=condition) - - return samples.reshape(( - *sample_shape, - condition_batch_dim, - -1, - )) + samples = samples.transpose(0, 1) + return samples.reshape( + ( + *sample_shape, + condition_batch_dim, + ) + + self.input_shape + ) def sample_and_log_prob( self, sample_shape: torch.Size, condition: Tensor, **kwargs From df8899a5dd575a6062199c34f0e5a540a71f8e23 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 7 May 2024 10:44:15 +0200 Subject: [PATCH 06/71] Added tests --- .../density_estimators/nflows_flow.py | 8 +-- tests/density_estimator_test.py | 69 +++++++++++++++++++ 2 files changed, 70 insertions(+), 7 deletions(-) diff --git a/sbi/neural_nets/density_estimators/nflows_flow.py b/sbi/neural_nets/density_estimators/nflows_flow.py index 3f162493f..8d6aaba55 100644 --- a/sbi/neural_nets/density_estimators/nflows_flow.py +++ b/sbi/neural_nets/density_estimators/nflows_flow.py @@ -136,13 +136,7 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor: samples = self.net.sample(num_samples, context=condition) samples = samples.transpose(0, 1) - return samples.reshape( - ( - *sample_shape, - condition_batch_dim, - ) - + self.input_shape - ) + return samples.reshape((*sample_shape, condition_batch_dim, *self.input_shape)) def sample_and_log_prob( self, sample_shape: torch.Size, condition: Tensor, **kwargs diff --git a/tests/density_estimator_test.py b/tests/density_estimator_test.py index 7543a8241..35fb0d946 100644 --- a/tests/density_estimator_test.py +++ b/tests/density_estimator_test.py @@ -283,6 +283,75 @@ def test_correctness_of_density_estimator_log_prob( assert torch.allclose(log_probs[0, :], log_probs[1, :]) +@pytest.mark.parametrize( + "density_estimator_build_fn", + ( + build_mdn, + build_maf, + build_maf_rqs, + build_nsf, + build_zuko_bpf, + build_zuko_gf, + build_zuko_maf, + build_zuko_naf, + build_zuko_ncsf, + build_zuko_nice, + build_zuko_nsf, + build_zuko_sospf, + build_zuko_unaf, + build_categoricalmassestimator, + build_mnle, + ), +) +@pytest.mark.parametrize("input_event_shape", ((1,), (4,))) +@pytest.mark.parametrize("condition_event_shape", ((1,), (7,))) +def test_correctness_of_batched_vs_seperate_sample_and_log_prob( + density_estimator_build_fn, input_event_shape, condition_event_shape +): + input_sample_dim = 2 + batch_dim = 2 + density_estimator, inputs, condition = _build_density_estimator_and_tensors( + density_estimator_build_fn, + input_event_shape, + condition_event_shape, + batch_dim, + input_sample_dim, + ) + # Batched vs separate sampling + samples = density_estimator.sample((1000,), condition=condition) + samples_separate1 = density_estimator.sample( + (1000,), condition=condition[0][None, ...] + ) + samples_separate2 = density_estimator.sample( + (1000,), condition=condition[1][None, ...] + ) + + # Check if means are approx. same + samples_m = torch.mean(samples, dim=0, dtype=torch.float32) + samples_separate1_m = torch.mean(samples_separate1, dim=0, dtype=torch.float32) + samples_separate2_m = torch.mean(samples_separate2, dim=0, dtype=torch.float32) + samples_sep_m = torch.cat([samples_separate1_m, samples_separate2_m], dim=0) + + assert torch.allclose( + samples_m, samples_sep_m, atol=0.5, rtol=0.5 + ), "Batched sampling is not consistent with separate sampling." + + # Batched vs separate log_prob + log_probs = density_estimator.log_prob(inputs, condition=condition) + + log_probs_separate1 = density_estimator.log_prob( + inputs[:, :1], condition=condition[0][None, ...] + ) + log_probs_separate2 = density_estimator.log_prob( + inputs[:, 1:], condition=condition[1][None, ...] + ) + log_probs_sep = torch.hstack([log_probs_separate1, log_probs_separate2]) + + assert torch.allclose( + log_probs, log_probs_sep, atol=1e-2, rtol=1e-2 + ), "Batched log_prob is not consistent with separate log_prob." + + def _build_density_estimator_and_tensors( density_estimator_build_fn: str, input_event_shape: Tuple[int], From aa82aab59b9df46da458dbe66d5056cc854b3f5d Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 7 May 2024 17:25:11 +0200 Subject: [PATCH 07/71] batched_rejection_sampling --- sbi/inference/posteriors/direct_posterior.py | 29 +++++++++++++++++++ sbi/samplers/rejection/rejection.py | 30 +++++++++++++------- sbi/utils/sbiutils.py | 3 +- 3 files changed, 50 insertions(+), 12 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index fb20a580e..72733151f 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -135,6 +135,35 @@ def sample( return samples + def sample_batched( + self, + sample_shape: Shape, + x: Tensor, + max_sampling_batch_size: int = 10_000, + show_progress_bars: bool = True, + ) -> Tensor: + num_samples = torch.Size(sample_shape).numel() + condition_shape = self.posterior_estimator.condition_shape + x = reshape_to_batch_event(x, event_shape=condition_shape) + print(x.shape) + max_sampling_batch_size = ( + self.max_sampling_batch_size + if max_sampling_batch_size is None + else max_sampling_batch_size + ) + + samples = accept_reject_sample( + proposal=self.posterior_estimator, + accept_reject_fn=lambda theta: within_support(self.prior, theta), + num_samples=num_samples, + show_progress_bars=show_progress_bars, + max_sampling_batch_size=max_sampling_batch_size, + proposal_sampling_kwargs={"condition": x}, + alternative_method="build_posterior(..., sample_with='mcmc')", + )[0] + + return samples + def log_prob( self, theta: Tensor, diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 5c78cd8dd..43988e244 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -253,13 +253,18 @@ def accept_reject_sample( total=num_samples, desc=f"Drawing {num_samples} posterior samples", ) + if proposal_sampling_kwargs is None: + proposal_sampling_kwargs = {} num_sampled_total, num_remaining = 0, num_samples - accepted, acceptance_rate = [], float("Nan") + if "condition" in proposal_sampling_kwargs: + num_xos = proposal_sampling_kwargs["condition"].shape[0] + else: + num_xos = 1 + accepted, acceptance_rate = [[] for _ in range(num_xos)], float("Nan") + num_accepted = torch.zeros(num_xos) leakage_warning_raised = False # Ruff suggestion - if proposal_sampling_kwargs is None: - proposal_sampling_kwargs = {} # To cover cases with few samples without leakage: sampling_batch_size = min(num_samples, max_sampling_batch_size) @@ -272,16 +277,18 @@ def accept_reject_sample( # SNPE-style rejection-sampling when the proposal is the neural net. are_accepted = accept_reject_fn(candidates) - samples = candidates[are_accepted] - accepted.append(samples) + num_accepted += are_accepted.sum(dim=0) + + for i in range(num_xos): + accepted[i].append(candidates[:, i][are_accepted[:, i]]) # Update. # Note: For any condition of shape (*batch_shape, *condition_shape), the - # samples will be of shape(*batch_shape, sampling_batch_size, d) and hence work - # in dim = -2. + # samples will be of shape(sampling_batch_size,*batch_shape, *event_shape) + # and hence work in dim = 0. num_sampled_total += sampling_batch_size - num_remaining -= samples.shape[-2] - pbar.update(samples.shape[-2]) + num_remaining -= num_accepted.min().item() + pbar.update(num_accepted.mean().item()) # To avoid endless sampling when leakage is high, we raise a warning if the # acceptance rate is too low after the first 1_000 samples. @@ -331,9 +338,10 @@ def accept_reject_sample( pbar.close() # When in case of leakage a batch size was used there could be too many samples. - samples = torch.cat(accepted, dim=-2)[..., :num_samples, :] + samples = [torch.cat(accepted[i], dim=0)[:num_samples] for i in range(num_xos)] + samples = torch.stack(samples, dim=1) assert ( - samples.shape[-2] == num_samples + samples.shape[0] == num_samples ), "Number of accepted samples must match required samples." return samples, as_tensor(acceptance_rate) diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index 58146ccae..a2b0f102f 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -627,7 +627,8 @@ def within_support(distribution: Any, samples: Tensor) -> Tensor: try: sample_check = distribution.support.check(samples) if sample_check.shape == samples.shape: - sample_check = torch.all(sample_check, dim=-1) + # With new shapeing conventions we need dim=-2 + sample_check = torch.all(sample_check, dim=-2) return sample_check # Falling back to log prob method of either the NeuralPosterior's net, or of a From 00cdadeaa3fadf2d3f0d8e72a0dc985ac59add81 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Fri, 3 May 2024 16:14:43 +0200 Subject: [PATCH 08/71] intermediate commit --- sbi/inference/posteriors/base_posterior.py | 12 +++++ sbi/inference/posteriors/direct_posterior.py | 51 ++++++++++++++++++++ sbi/samplers/rejection/rejection.py | 19 ++++++-- 3 files changed, 78 insertions(+), 4 deletions(-) diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index 0db66b1cb..f5b9cf62e 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -121,6 +121,18 @@ def sample( """See child classes for docstring.""" pass + @abstractmethod + def amortized_sample( + self, + sample_shape: Shape = torch.Size(), + x: Optional[Tensor] = None, + show_progress_bars: bool = True, + mcmc_method: Optional[str] = None, + mcmc_parameters: Optional[Dict[str, Any]] = None, + ) -> Tensor: + """See child classes for docstring.""" + pass + @property def default_x(self) -> Optional[Tensor]: """Return default x used by `.sample(), .log_prob` as conditioning context.""" diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index fb20a580e..388ed53d0 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -135,6 +135,57 @@ def sample( return samples + def amortized_sample( + self, + sample_shape: Shape = torch.Size(), + x: Optional[Tensor] = None, + max_sampling_batch_size: int = 10_000, + sample_with: Optional[str] = None, + show_progress_bars: bool = True, + ) -> Tensor: + r"""Return samples from posterior $p(\theta|x)$ given multiple observations. + + Args: + sample_shape: Desired shape of samples that are drawn from the posterior + given every observation. + x: A batch of observations, of shape `(batch_dim, event_shape_x)`. + `batch_dim` corresponds to the number of observations to be drawn. + sample_with: This argument only exists to keep backward-compatibility with + `sbi` v0.17.2 or older. If it is set, we instantly raise an error. + show_progress_bars: Whether to show sampling progress monitor. + """ + + num_samples = torch.Size(sample_shape).numel() + # x = self._x_else_default_x(x) + x = reshape_to_batch_event( + x, event_shape=self.posterior_estimator.condition_shape + ) + + max_sampling_batch_size = ( + self.max_sampling_batch_size + if max_sampling_batch_size is None + else max_sampling_batch_size + ) + + if sample_with is not None: + raise ValueError( + f"You set `sample_with={sample_with}`. As of sbi v0.18.0, setting " + f"`sample_with` is no longer supported. You have to rerun " + f"`.build_posterior(sample_with={sample_with}).`" + ) + + samples = accept_reject_sample( + proposal=self.posterior_estimator, + accept_reject_fn=lambda theta: within_support(self.prior, theta), + num_samples=num_samples, + show_progress_bars=show_progress_bars, + max_sampling_batch_size=max_sampling_batch_size, + proposal_sampling_kwargs={"condition": x}, + alternative_method="build_posterior(..., sample_with='mcmc')", + )[0] + + return samples + def log_prob( self, theta: Tensor, diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 5c78cd8dd..791ffcd98 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -253,8 +253,11 @@ def accept_reject_sample( total=num_samples, desc=f"Drawing {num_samples} posterior samples", ) + print("Bullas;fjkdsafj;dlsjfldsaj") num_sampled_total, num_remaining = 0, num_samples + num_xo = proposal_sampling_kwargs["condition"].shape[0] + accepted_every_obs = [torch.tensor((0, 2)) for _ in range(num_xo)] accepted, acceptance_rate = [], float("Nan") leakage_warning_raised = False # Ruff suggestion @@ -272,15 +275,23 @@ def accept_reject_sample( # SNPE-style rejection-sampling when the proposal is the neural net. are_accepted = accept_reject_fn(candidates) - samples = candidates[are_accepted] - accepted.append(samples) + print("are_accepted", are_accepted.shape) + for obs_index in range(num_xo): + accepted = candidates[are_accepted[:, obs_index], obs_index] + print("accepted", accepted.shape) + print("accepted_every_obs[obs_index]", accepted_every_obs[obs_index].shape) + here = accepted_every_obs[obs_index] + print("here", here.shape) + print("acc", accepted.shape) + accepted_every_obs[obs_index] = torch.cat([accepted_every_obs[obs_index], accepted], dim=0) + lowest_num_accepted = min(len(s) for s in accepted_every_obs) + num_remaining = num_samples - lowest_num_accepted # Update. # Note: For any condition of shape (*batch_shape, *condition_shape), the # samples will be of shape(*batch_shape, sampling_batch_size, d) and hence work # in dim = -2. num_sampled_total += sampling_batch_size - num_remaining -= samples.shape[-2] pbar.update(samples.shape[-2]) # To avoid endless sampling when leakage is high, we raise a warning if the @@ -331,7 +342,7 @@ def accept_reject_sample( pbar.close() # When in case of leakage a batch size was used there could be too many samples. - samples = torch.cat(accepted, dim=-2)[..., :num_samples, :] + samples = samples[..., :num_samples, :] assert ( samples.shape[-2] == num_samples ), "Number of accepted samples must match required samples." From cb8e4d8186805ce461a00ed83fcd8dfcf6cb3242 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Fri, 3 May 2024 17:05:01 +0200 Subject: [PATCH 09/71] make autoreload work --- sbi/inference/posteriors/direct_posterior.py | 4 ++-- sbi/samplers/rejection/rejection.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index 388ed53d0..5928c52b7 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -16,7 +16,7 @@ reshape_to_batch_event, reshape_to_sample_batch_event, ) -from sbi.samplers.rejection.rejection import accept_reject_sample +from sbi.samplers.rejection import rejection from sbi.sbi_types import Shape from sbi.utils import check_prior, within_support from sbi.utils.torchutils import ensure_theta_batched @@ -123,7 +123,7 @@ def sample( f"`.build_posterior(sample_with={sample_with}).`" ) - samples = accept_reject_sample( + samples = rejection.accept_reject_sample( proposal=self.posterior_estimator, accept_reject_fn=lambda theta: within_support(self.prior, theta), num_samples=num_samples, diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 791ffcd98..149446844 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -253,7 +253,6 @@ def accept_reject_sample( total=num_samples, desc=f"Drawing {num_samples} posterior samples", ) - print("Bullas;fjkdsafj;dlsjfldsaj") num_sampled_total, num_remaining = 0, num_samples num_xo = proposal_sampling_kwargs["condition"].shape[0] From d64557fd7e8863ac0e9965150a38f7f2568a9bf0 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Sun, 5 May 2024 17:09:56 +0200 Subject: [PATCH 10/71] `amortized_sample` works for MCMCPosterior --- sbi/inference/posteriors/direct_posterior.py | 4 +- sbi/inference/posteriors/mcmc_posterior.py | 102 +++++++++++++++++- .../potentials/likelihood_based_potential.py | 5 + sbi/samplers/rejection/rejection.py | 28 ++--- 4 files changed, 120 insertions(+), 19 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index 5928c52b7..ed131e218 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -133,7 +133,7 @@ def sample( alternative_method="build_posterior(..., sample_with='mcmc')", )[0] - return samples + return samples[:, 0] # Remove batch dimension. def amortized_sample( self, @@ -174,7 +174,7 @@ def amortized_sample( f"`.build_posterior(sample_with={sample_with}).`" ) - samples = accept_reject_sample( + samples = rejection.accept_reject_sample( proposal=self.posterior_estimator, accept_reject_fn=lambda theta: within_support(self.prior, theta), num_samples=num_samples, diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index 5ef9f882a..c47db3e7b 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -17,6 +17,7 @@ from torch import Tensor from torch import multiprocessing as mp from tqdm.auto import tqdm +import numpy as np from sbi.inference.posteriors.base_posterior import NeuralPosterior from sbi.inference.potentials.base_potential import BasePotential @@ -355,6 +356,83 @@ def sample( return samples.reshape((*sample_shape, -1)) # type: ignore + + def amortized_sample( + self, + sample_shape: Shape = torch.Size(), + x: Optional[Tensor] = None, + method: Optional[str] = None, + thin: Optional[int] = None, + warmup_steps: Optional[int] = None, + num_chains: Optional[int] = None, + init_strategy: Optional[str] = None, + init_strategy_parameters: Optional[Dict[str, Any]] = None, + num_workers: Optional[int] = None, + mp_context: Optional[str] = None, + show_progress_bars: bool = True, + ) -> Tensor: + r"""Return samples from posterior distribution $p(\theta|x)$ with MCMC. + + Check the `__init__()` method for a description of all arguments as well as + their default values. + + Args: + sample_shape: Desired shape of samples that are drawn from posterior. If + sample_shape is multidimensional we simply draw `sample_shape.numel()` + samples and then reshape into the desired shape. + show_progress_bars: Whether to show sampling progress monitor. + + Returns: + Samples from posterior. + """ + self.potential_fn.set_x(self._x_else_default_x(x)) + + # Replace arguments that were not passed with their default. + method = self.method if method is None else method + thin = self.thin if thin is None else thin + warmup_steps = self.warmup_steps if warmup_steps is None else warmup_steps + num_chains = self.num_chains if num_chains is None else num_chains + init_strategy = self.init_strategy if init_strategy is None else init_strategy + num_workers = self.num_workers if num_workers is None else num_workers + mp_context = self.mp_context if mp_context is None else mp_context + init_strategy_parameters = ( + self.init_strategy_parameters + if init_strategy_parameters is None + else init_strategy_parameters + ) + self.potential_ = self._prepare_potential(method) # type: ignore + + print("Getting initial params") + initial_params = self._get_initial_params( + init_strategy, # type: ignore + num_chains, # type: ignore + num_workers, + show_progress_bars, + **init_strategy_parameters, + ) + print("Finished init") + num_samples = torch.Size(sample_shape).numel() + + assert method == "slice_np_vectorized" + with torch.set_grad_enabled(False): + transformed_samples = self._slice_np_mcmc( + num_samples=num_samples, + potential_function=self.potential_, + initial_params=initial_params, + thin=thin, # type: ignore + warmup_steps=warmup_steps, # type: ignore + vectorized=(method == "slice_np_vectorized"), + num_workers=num_workers, + show_progress_bars=show_progress_bars, + ) + print("transformed_samples", transformed_samples.shape) + samples = self.theta_transform.inv(transformed_samples) + print("samples", samples.shape) + num_obs = 5 + + return samples.reshape((*sample_shape, num_obs, -1)) # type: ignore + + def _build_mcmc_init_fn( self, proposal: Any, @@ -507,10 +585,24 @@ def _slice_np_mcmc( else: SliceSamplerMultiChain = SliceSamplerVectorized + def multi_obs_potential(params): + # Params are of shape (num_chains * num_obs, event). + # We now reshape them to (num_chains, num_obs, event). + # params = np.reshape(params, (num_chains, num_obs, -1)) + # print("params", params.shape) + # print("potential_function", potential_function) + + # `all_potentials` is of shape (num_chains, num_obs). + all_potentials = potential_function(params) + return all_potentials.flatten() + + num_obs = 5 + initial_params = torch.concatenate([initial_params] * num_obs) + posterior_sampler = SliceSamplerMultiChain( init_params=tensor2numpy(initial_params), - log_prob_fn=potential_function, - num_chains=num_chains, + log_prob_fn=multi_obs_potential, + num_chains=num_chains * num_obs, thin=thin, verbose=show_progress_bars, num_workers=num_workers, @@ -519,7 +611,9 @@ def _slice_np_mcmc( warmup_ = warmup_steps * thin num_samples_ = ceil((num_samples * thin) / num_chains) # Run mcmc including warmup + print("Start run") samples = posterior_sampler.run(warmup_ + num_samples_) + print("Finish run") samples = samples[:, warmup_steps:, :] # discard warmup steps samples = torch.from_numpy(samples) # chains x samples x dim @@ -527,10 +621,10 @@ def _slice_np_mcmc( self._posterior_sampler = posterior_sampler # Save sample as potential next init (if init_strategy == 'latest_sample'). - self._mcmc_init_params = samples[:, -1, :].reshape(num_chains, dim_samples) + self._mcmc_init_params = samples[:, -1, :].reshape(num_chains, num_obs, dim_samples) # Collect samples from all chains. - samples = samples.reshape(-1, dim_samples)[:num_samples] + samples = samples.reshape(-1, num_obs, dim_samples)[:num_samples] return samples.type(torch.float32).to(self._device) diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py index eab36e91f..a9e94e7e2 100644 --- a/sbi/inference/potentials/likelihood_based_potential.py +++ b/sbi/inference/potentials/likelihood_based_potential.py @@ -123,6 +123,8 @@ def _log_likelihoods_over_trials( log_likelihood_trial_sum: log likelihood for each parameter, summed over all batch entries (iid trials) in `x`. """ + # print("x", x.shape) + # print("theta", theta.shape) # Shape of `x` is (iid_dim, *event_shape). x = reshape_to_sample_batch_event( x, event_shape=x.shape[1:], leading_is_sample=True @@ -146,6 +148,9 @@ def _log_likelihoods_over_trials( # `DensityEstimator.log_prob`. theta = reshape_to_batch_event(theta, event_shape=theta.shape[1:]) + # print("After reshape theta: ", theta.shape) + # print("After reshape x: ", x.shape) + # Calculate likelihood in one batch. with torch.set_grad_enabled(track_gradients): log_likelihood_trial_batch = estimator.log_prob(x, condition=theta) diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 149446844..ed9625b97 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -236,7 +236,8 @@ def accept_reject_sample( `rejection_sample()`. Warn if not empty. Returns: - Accepted samples and acceptance rate as scalar Tensor. + Accepted samples of shape `(sample_dim, batch_dim, *event_shape)`, and + worst-case acceptance rate as scalar Tensor. """ if kwargs: @@ -255,8 +256,12 @@ def accept_reject_sample( ) num_sampled_total, num_remaining = 0, num_samples - num_xo = proposal_sampling_kwargs["condition"].shape[0] - accepted_every_obs = [torch.tensor((0, 2)) for _ in range(num_xo)] + if "condition" in list(proposal_sampling_kwargs.keys()): + num_xo = proposal_sampling_kwargs["condition"].shape[0] + else: + num_xo = 1 + + accepted_every_obs = [torch.tensor(()) for _ in range(num_xo)] accepted, acceptance_rate = [], float("Nan") leakage_warning_raised = False # Ruff suggestion @@ -274,14 +279,8 @@ def accept_reject_sample( # SNPE-style rejection-sampling when the proposal is the neural net. are_accepted = accept_reject_fn(candidates) - print("are_accepted", are_accepted.shape) for obs_index in range(num_xo): accepted = candidates[are_accepted[:, obs_index], obs_index] - print("accepted", accepted.shape) - print("accepted_every_obs[obs_index]", accepted_every_obs[obs_index].shape) - here = accepted_every_obs[obs_index] - print("here", here.shape) - print("acc", accepted.shape) accepted_every_obs[obs_index] = torch.cat([accepted_every_obs[obs_index], accepted], dim=0) lowest_num_accepted = min(len(s) for s in accepted_every_obs) num_remaining = num_samples - lowest_num_accepted @@ -291,7 +290,7 @@ def accept_reject_sample( # samples will be of shape(*batch_shape, sampling_batch_size, d) and hence work # in dim = -2. num_sampled_total += sampling_batch_size - pbar.update(samples.shape[-2]) + pbar.update(num_samples - num_remaining) # To avoid endless sampling when leakage is high, we raise a warning if the # acceptance rate is too low after the first 1_000 samples. @@ -340,10 +339,13 @@ def accept_reject_sample( pbar.close() + for obs_index in range(num_xo): + accepted_every_obs[obs_index] = accepted_every_obs[obs_index][:num_samples] + + accepted_every_obs = torch.stack(accepted_every_obs) # When in case of leakage a batch size was used there could be too many samples. - samples = samples[..., :num_samples, :] assert ( - samples.shape[-2] == num_samples + accepted_every_obs.shape[-2] == num_samples ), "Number of accepted samples must match required samples." - return samples, as_tensor(acceptance_rate) + return torch.permute(accepted_every_obs, (1, 0, -1)), as_tensor(acceptance_rate) From e54a2fb67d1cf211657a9aeaa59418293ef39485 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 7 May 2024 18:39:58 +0200 Subject: [PATCH 11/71] Revert "Merge branch '990-add-sample_batched-and-log_prob_batched-to-posteriors' into amortizedsample" This reverts commit 07084e28fb586d43605dba6786d60c3e48ed96e5, reversing changes made to f16622d552e0dd69b17855bea9d672594e11d8ce. --- sbi/neural_nets/density_estimators/base.py | 136 +++++---------------- 1 file changed, 31 insertions(+), 105 deletions(-) diff --git a/sbi/neural_nets/density_estimators/base.py b/sbi/neural_nets/density_estimators/base.py index 6d56fbf64..b3b83567c 100644 --- a/sbi/neural_nets/density_estimators/base.py +++ b/sbi/neural_nets/density_estimators/base.py @@ -1,116 +1,13 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Apache License Version 2.0, see -from abc import ABC, abstractmethod from typing import Optional, Tuple import torch from torch import Tensor, nn -class Estimator(nn.Module, ABC): - r"""Base class for estimators i.e. neural nets estimating a certain quantity that - characterizes a distribution. This for example can be: - - Conditional density estimator of the posterior $p(\theta|x)$. - - Conditional density estimator of the likelihood $p(x|\theta)$. - - Estimator of the density ratio $p(x|\theta)/p(x)$. - - and more ... - """ - - def __init__(self, input_shape: torch.Size, condition_shape: torch.Size) -> None: - r"""Base class for estimators. - - Args: - input_shape: Event shape of the input at which the density is being - evaluated (and which is also the event_shape of samples). - condition_shape: Shape of the condition. If not provided, it will assume a - 1D input. - """ - super().__init__() - self._input_shape = torch.Size(input_shape) - self._condition_shape = torch.Size(condition_shape) - - @property - def input_shape(self) -> torch.Size: - r"""Return the input shape.""" - return self._input_shape - - @property - def condition_shape(self) -> torch.Size: - r"""Return the condition shape.""" - return self._condition_shape - - @abstractmethod - def loss(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: - r"""Return the loss for training the estimator. - - Args: - input: Inputs to evaluate the loss on of shape - `(batch_dim, *input_event_shape)`. - condition: Conditions of shape `(batch_dim, *event_shape_condition)`. - - Returns: - Loss of shape (batch_dim,) - """ - pass - - def _check_condition_shape(self, condition: Tensor): - r"""This method checks whether the condition has the correct shape. - - Args: - condition: Conditions of shape `(batch_dim, *event_shape_condition)`. - - Raises: - ValueError: If the condition has a dimensionality that does not match - the expected input dimensionality. - ValueError: If the shape of the condition does not match the expected - input dimensionality. - """ - if len(condition.shape) < len(self.condition_shape): - raise ValueError( - f"Dimensionality of condition is to small and does not match the\ - expected input dimensionality {len(self.condition_shape)}, as provided\ - by condition_shape." - ) - else: - condition_shape = condition.shape[-len(self.condition_shape) :] - if tuple(condition_shape) != tuple(self.condition_shape): - raise ValueError( - f"Shape of condition {tuple(condition_shape)} does not match the \ - expected input dimensionality {tuple(self.condition_shape)}, as \ - provided by condition_shape. Please reshape it accordingly." - ) - - def _check_input_shape(self, input: Tensor): - r"""This method checks whether the input has the correct shape. - - Args: - input: Inputs to evaluate the log probability on of shape - `(sample_dim_input, batch_dim_input, *event_shape_input)`. - - Raises: - ValueError: If the input has a dimensionality that does not match - the expected input dimensionality. - ValueError: If the shape of the input does not match the expected - input dimensionality. - """ - if len(input.shape) < len(self.input_shape): - raise ValueError( - f"Dimensionality of input is to small and does not match the expected \ - input dimensionality {len(self.input_shape)}, as provided by \ - input_shape." - ) - else: - input_shape = input.shape[-len(self.input_shape) :] - if tuple(input_shape) != tuple(self.input_shape): - raise ValueError( - f"Shape of input {tuple(input_shape)} does not match the expected \ - input dimensionality {tuple(self.input_shape)}, as provided by \ - input_shape. Please reshape it accordingly." - ) - - -class DensityEstimator(Estimator): +class DensityEstimator(nn.Module): r"""Base class for density estimators. The density estimator class is a wrapper around neural networks that @@ -137,8 +34,10 @@ def __init__( condition_shape: Shape of the condition. If not provided, it will assume a 1D input. """ - super().__init__(input_shape, condition_shape) + super().__init__() self.net = net + self.input_shape = input_shape + self.condition_shape = condition_shape @property def embedding_net(self) -> Optional[nn.Module]: @@ -212,3 +111,30 @@ def sample_and_log_prob( samples = self.sample(sample_shape, condition, **kwargs) log_probs = self.log_prob(samples, condition, **kwargs) return samples, log_probs + + def _check_condition_shape(self, condition: Tensor): + r"""This method checks whether the condition has the correct shape. + + Args: + condition: Conditions of shape `(batch_dim, *event_shape_condition)`. + + Raises: + ValueError: If the condition has a dimensionality that does not match + the expected input dimensionality. + ValueError: If the shape of the condition does not match the expected + input dimensionality. + """ + if len(condition.shape) < len(self.condition_shape): + raise ValueError( + f"Dimensionality of condition is to small and does not match the\ + expected input dimensionality {len(self.condition_shape)}, as provided\ + by condition_shape." + ) + else: + condition_shape = condition.shape[-len(self.condition_shape) :] + if tuple(condition_shape) != tuple(self.condition_shape): + raise ValueError( + f"Shape of condition {tuple(condition_shape)} does not match the \ + expected input dimensionality {tuple(self.condition_shape)}, as \ + provided by condition_shape. Please reshape it accordingly." + ) From cd808d5e6d0385cfc4b6410129c75e0ab1f1cb9d Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 7 May 2024 19:40:59 +0200 Subject: [PATCH 12/71] sample works, try log_prob_batched --- sbi/inference/posteriors/direct_posterior.py | 97 ++++++++++---------- sbi/samplers/rejection/rejection.py | 20 ++-- sbi/utils/sbiutils.py | 2 +- 3 files changed, 62 insertions(+), 57 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index ec11cb8f1..150c524f5 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -135,12 +135,11 @@ def sample( return samples[:, 0] # Remove batch dimension. - def amortized_sample( + def sample_batched( self, - sample_shape: Shape = torch.Size(), - x: Optional[Tensor] = None, + sample_shape: Shape, + x: Tensor, max_sampling_batch_size: int = 10_000, - sample_with: Optional[str] = None, show_progress_bars: bool = True, ) -> Tensor: r"""Return samples from posterior $p(\theta|x)$ given multiple observations. @@ -150,53 +149,13 @@ def amortized_sample( given every observation. x: A batch of observations, of shape `(batch_dim, event_shape_x)`. `batch_dim` corresponds to the number of observations to be drawn. - sample_with: This argument only exists to keep backward-compatibility with - `sbi` v0.17.2 or older. If it is set, we instantly raise an error. + max_sampling_batch_size: Maximum batch size for rejection sampling. show_progress_bars: Whether to show sampling progress monitor. """ - - num_samples = torch.Size(sample_shape).numel() - # x = self._x_else_default_x(x) - x = reshape_to_batch_event( - x, event_shape=self.posterior_estimator.condition_shape - ) - - max_sampling_batch_size = ( - self.max_sampling_batch_size - if max_sampling_batch_size is None - else max_sampling_batch_size - ) - - if sample_with is not None: - raise ValueError( - f"You set `sample_with={sample_with}`. As of sbi v0.18.0, setting " - f"`sample_with` is no longer supported. You have to rerun " - f"`.build_posterior(sample_with={sample_with}).`" - ) - - samples = rejection.accept_reject_sample( - proposal=self.posterior_estimator, - accept_reject_fn=lambda theta: within_support(self.prior, theta), - num_samples=num_samples, - show_progress_bars=show_progress_bars, - max_sampling_batch_size=max_sampling_batch_size, - proposal_sampling_kwargs={"condition": x}, - alternative_method="build_posterior(..., sample_with='mcmc')", - )[0] - - return samples - - def sample_batched( - self, - sample_shape: Shape, - x: Tensor, - max_sampling_batch_size: int = 10_000, - show_progress_bars: bool = True, - ) -> Tensor: num_samples = torch.Size(sample_shape).numel() condition_shape = self.posterior_estimator.condition_shape x = reshape_to_batch_event(x, event_shape=condition_shape) - print(x.shape) + max_sampling_batch_size = ( self.max_sampling_batch_size if max_sampling_batch_size is None @@ -290,6 +249,52 @@ def log_prob( return masked_log_prob - log_factor + def log_prob_batched( + self, + theta: Tensor, + x: Tensor, + norm_posterior: bool = True, + track_gradients: bool = False, + leakage_correction_params: Optional[dict] = None, + ) -> Tensor: + theta = ensure_theta_batched(torch.as_tensor(theta)) + theta_density_estimator = reshape_to_sample_batch_event( + theta, theta.shape[1:], leading_is_sample=True + ) + x_density_estimator = reshape_to_batch_event( + x, event_shape=self.posterior_estimator.condition_shape + ) + + self.posterior_estimator.eval() + + with torch.set_grad_enabled(track_gradients): + # Evaluate on device, move back to cpu for comparison with prior. + unnorm_log_prob = self.posterior_estimator.log_prob( + theta_density_estimator, condition=x_density_estimator + ) + # `log_prob` supports only a single observation (i.e. `batchsize==1`). + # We now remove this additional dimension. + unnorm_log_prob = unnorm_log_prob.squeeze(dim=1) + + # Force probability to be zero outside prior support. + in_prior_support = within_support(self.prior, theta) + + masked_log_prob = torch.where( + in_prior_support, + unnorm_log_prob, + torch.tensor(float("-inf"), dtype=torch.float32, device=self._device), + ) + + if leakage_correction_params is None: + leakage_correction_params = dict() # use defaults + log_factor = ( + log(self.leakage_correction(x=x, **leakage_correction_params)) + if norm_posterior + else 0 + ) + + return masked_log_prob - log_factor + @torch.no_grad() def leakage_correction( self, diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 63e87690b..ca8145129 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -179,9 +179,10 @@ def log_prob(self, theta: Tensor, **kwargs) -> Tensor: # When in case of leakage a batch size was used there could be too many samples. samples = torch.cat(accepted)[:num_samples] - assert ( - samples.shape[0] == num_samples - ), "Number of accepted samples must match required samples." + print(samples.shape) + # assert ( + # samples.shape[0] == num_samples + # ), "Number of accepted samples must match required samples." return samples, as_tensor(acceptance_rate) @@ -263,7 +264,6 @@ def accept_reject_sample( else: num_xos = 1 accepted, acceptance_rate = [[] for _ in range(num_xos)], float("Nan") - num_accepted = torch.zeros(num_xos) leakage_warning_raised = False # Ruff suggestion @@ -278,7 +278,7 @@ def accept_reject_sample( # SNPE-style rejection-sampling when the proposal is the neural net. are_accepted = accept_reject_fn(candidates) - num_accepted += are_accepted.sum(dim=0) + num_accepted = are_accepted.sum(dim=0).min().item() for i in range(num_xos): accepted[i].append(candidates[are_accepted[:, i], i]) @@ -288,8 +288,8 @@ def accept_reject_sample( # samples will be of shape(sampling_batch_size,*batch_shape, *event_shape) # and hence work in dim = 0. num_sampled_total += sampling_batch_size - num_remaining -= num_accepted.min().item() - pbar.update(num_accepted.mean().item()) + num_remaining -= num_accepted + pbar.update(num_accepted) # To avoid endless sampling when leakage is high, we raise a warning if the # acceptance rate is too low after the first 1_000 samples. @@ -341,8 +341,8 @@ def accept_reject_sample( # When in case of leakage a batch size was used there could be too many samples. samples = [torch.cat(accepted[i], dim=0)[:num_samples] for i in range(num_xos)] samples = torch.stack(samples, dim=1) - assert ( - samples.shape[0] == num_samples - ), "Number of accepted samples must match required samples." + # assert ( + # samples.shape[0] == num_samples + # ), "Number of accepted samples must match required samples." return samples, as_tensor(acceptance_rate) diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index a2b0f102f..70635fa44 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -628,7 +628,7 @@ def within_support(distribution: Any, samples: Tensor) -> Tensor: sample_check = distribution.support.check(samples) if sample_check.shape == samples.shape: # With new shapeing conventions we need dim=-2 - sample_check = torch.all(sample_check, dim=-2) + sample_check = torch.all(sample_check, dim=-1) return sample_check # Falling back to log prob method of either the NeuralPosterior's net, or of a From f54222443b81ed58a4169af47dda64bb0274b99e Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 7 May 2024 19:50:29 +0200 Subject: [PATCH 13/71] log_prob_batched works --- sbi/inference/posteriors/base_posterior.py | 9 ++++----- sbi/inference/posteriors/direct_posterior.py | 4 +++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index f5b9cf62e..4aaf1385e 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -122,13 +122,12 @@ def sample( pass @abstractmethod - def amortized_sample( + def sample_batched( self, - sample_shape: Shape = torch.Size(), - x: Optional[Tensor] = None, + sample_shape: Shape, + x: Tensor, + max_sampling_batch_size: int = 10_000, show_progress_bars: bool = True, - mcmc_method: Optional[str] = None, - mcmc_parameters: Optional[Dict[str, Any]] = None, ) -> Tensor: """See child classes for docstring.""" pass diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index 150c524f5..f4a0aa6e3 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -258,13 +258,15 @@ def log_prob_batched( leakage_correction_params: Optional[dict] = None, ) -> Tensor: theta = ensure_theta_batched(torch.as_tensor(theta)) + event_shape = self.posterior_estimator.input_shape theta_density_estimator = reshape_to_sample_batch_event( - theta, theta.shape[1:], leading_is_sample=True + theta, event_shape, leading_is_sample=True ) x_density_estimator = reshape_to_batch_event( x, event_shape=self.posterior_estimator.condition_shape ) + print(theta_density_estimator.shape, x_density_estimator.shape) self.posterior_estimator.eval() with torch.set_grad_enabled(track_gradients): From 48a1a285f709d270ce20e3e8e8977ed820589f7b Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 7 May 2024 20:00:19 +0200 Subject: [PATCH 14/71] abstract method implement for other methods --- sbi/inference/posteriors/ensemble_posterior.py | 9 +++++++++ sbi/inference/posteriors/importance_posterior.py | 11 +++++++++++ sbi/inference/posteriors/mcmc_posterior.py | 4 ++-- sbi/inference/posteriors/rejection_posterior.py | 11 +++++++++++ sbi/inference/posteriors/vi_posterior.py | 9 +++++++++ 5 files changed, 42 insertions(+), 2 deletions(-) diff --git a/sbi/inference/posteriors/ensemble_posterior.py b/sbi/inference/posteriors/ensemble_posterior.py index 72af02d88..abb2a1a3d 100644 --- a/sbi/inference/posteriors/ensemble_posterior.py +++ b/sbi/inference/posteriors/ensemble_posterior.py @@ -179,6 +179,15 @@ def sample( ) return torch.vstack(samples).reshape(*sample_shape, -1) + def sample_batched( + self, + sample_shape: Shape, + x: Tensor, + max_sampling_batch_size: int = 10000, + show_progress_bars: bool = True, + ) -> Tensor: + raise NotImplementedError("This method is not implemented yet.") + def log_prob( self, theta: Tensor, diff --git a/sbi/inference/posteriors/importance_posterior.py b/sbi/inference/posteriors/importance_posterior.py index bbd4ce32f..62b295d4d 100644 --- a/sbi/inference/posteriors/importance_posterior.py +++ b/sbi/inference/posteriors/importance_posterior.py @@ -194,6 +194,17 @@ def sample( else: raise NameError + def sample_batched( + self, + sample_shape: Shape, + x: Tensor, + max_sampling_batch_size: int = 10000, + show_progress_bars: bool = True, + ) -> Tensor: + raise NotImplementedError( + "Batched sampling is not implemented for ImportanceSamplingPosterior." + ) + def _importance_sample( self, sample_shape: Shape = torch.Size(), diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index 02a47563d..28aea0c7d 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -355,7 +355,7 @@ def sample( return samples.reshape((*sample_shape, -1)) # type: ignore - def amortized_sample( + def sample_batched( self, sample_shape: Shape = torch.Size(), x: Optional[Tensor] = None, @@ -425,7 +425,7 @@ def amortized_sample( ) print("transformed_samples", transformed_samples.shape) samples = self.theta_transform.inv(transformed_samples) - print("samples", samples.shape) + num_obs = 5 return samples.reshape((*sample_shape, num_obs, -1)) # type: ignore diff --git a/sbi/inference/posteriors/rejection_posterior.py b/sbi/inference/posteriors/rejection_posterior.py index 6da838059..5eb53497b 100644 --- a/sbi/inference/posteriors/rejection_posterior.py +++ b/sbi/inference/posteriors/rejection_posterior.py @@ -167,6 +167,17 @@ def sample( return samples.reshape((*sample_shape, -1)) + def sample_batched( + self, + sample_shape: Shape, + x: Tensor, + max_sampling_batch_size: int = 10000, + show_progress_bars: bool = True, + ) -> Tensor: + raise NotImplementedError( + "Batched sampling is not supported for rejection sampling." + ) + def map( self, x: Optional[Tensor] = None, diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py index 006ab543a..fd89a5654 100644 --- a/sbi/inference/posteriors/vi_posterior.py +++ b/sbi/inference/posteriors/vi_posterior.py @@ -296,6 +296,15 @@ def sample( samples = self.q.sample(torch.Size(sample_shape)) return samples.reshape((*sample_shape, samples.shape[-1])) + def sample_batched( + self, + sample_shape: Shape, + x: Tensor, + max_sampling_batch_size: int = 10000, + show_progress_bars: bool = True, + ) -> Tensor: + raise NotImplementedError("Batched sampling is not supported for VIPosterior.") + def log_prob( self, theta: Tensor, From 5a37330b8116a923637208d9e2adae9377694d6a Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 7 May 2024 20:47:18 +0200 Subject: [PATCH 15/71] temp fix mcmcposterior --- sbi/inference/posteriors/mcmc_posterior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index 28aea0c7d..f2211a210 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -593,7 +593,7 @@ def multi_obs_potential(params): all_potentials = potential_function(params) return all_potentials.flatten() - num_obs = 5 + num_obs = 1 # TODO This will fail for num_obs > 1 in embedding_net_test.py initial_params = torch.concatenate([initial_params] * num_obs) posterior_sampler = SliceSamplerMultiChain( From 2b23e42bce73be05d4f6bfbf1ba972c8df37980b Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 7 May 2024 21:43:48 +0200 Subject: [PATCH 16/71] meh for general use i.e. in the restriction prior we have to add some reshapes in rejection --- sbi/samplers/rejection/rejection.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index ca8145129..09df9480d 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -179,10 +179,10 @@ def log_prob(self, theta: Tensor, **kwargs) -> Tensor: # When in case of leakage a batch size was used there could be too many samples. samples = torch.cat(accepted)[:num_samples] - print(samples.shape) - # assert ( - # samples.shape[0] == num_samples - # ), "Number of accepted samples must match required samples." + # print(samples.shape) + assert ( + samples.shape[0] == num_samples + ), "Number of accepted samples must match required samples." return samples, as_tensor(acceptance_rate) @@ -275,13 +275,16 @@ def accept_reject_sample( (sampling_batch_size,), # type: ignore **proposal_sampling_kwargs, ) - # SNPE-style rejection-sampling when the proposal is the neural net. are_accepted = accept_reject_fn(candidates) + are_accepted = are_accepted.reshape(sampling_batch_size, num_xos) + candidates_to_reject = candidates.reshape( + sampling_batch_size, num_xos, *candidates.shape[1:] + ) num_accepted = are_accepted.sum(dim=0).min().item() - + # print(are_accepted.shape) for i in range(num_xos): - accepted[i].append(candidates[are_accepted[:, i], i]) + accepted[i].append(candidates_to_reject[are_accepted[:, i], i]) # Update. # Note: For any condition of shape (*batch_shape, *condition_shape), the @@ -341,8 +344,9 @@ def accept_reject_sample( # When in case of leakage a batch size was used there could be too many samples. samples = [torch.cat(accepted[i], dim=0)[:num_samples] for i in range(num_xos)] samples = torch.stack(samples, dim=1) - # assert ( - # samples.shape[0] == num_samples - # ), "Number of accepted samples must match required samples." + samples = samples.reshape(num_samples, *candidates.shape[1:]) + assert ( + samples.shape[0] == num_samples + ), "Number of accepted samples must match required samples." return samples, as_tensor(acceptance_rate) From 6362051d3cfb6b950a0ce03abaa36ec3393ba3ce Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 7 May 2024 21:57:52 +0200 Subject: [PATCH 17/71] ... test class --- tests/test_utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_utils.py b/tests/test_utils.py index b6730e65b..a1cea1e07 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -246,6 +246,17 @@ def sample( return self.potential_fn.posterior.sample(sample_shape) + def sample_batched( + self, + sample_shape: Shape, + x: Tensor, + max_sampling_batch_size: int = 10000, + show_progress_bars: bool = True, + ) -> Tensor: + raise NotImplementedError( + "Batched sampling is not supported for TractablePosterior." + ) + def log_prob( self, theta: Tensor, From 294609da953c2795aa299e5c87c774004d1a13de Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Wed, 8 May 2024 07:26:07 +0200 Subject: [PATCH 18/71] Revert "Base estimator class" This reverts commit 17c534303343bd6306ea8e45fd4085a929ba42c2. --- sbi/neural_nets/density_estimators/base.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/sbi/neural_nets/density_estimators/base.py b/sbi/neural_nets/density_estimators/base.py index b3b83567c..252c850bc 100644 --- a/sbi/neural_nets/density_estimators/base.py +++ b/sbi/neural_nets/density_estimators/base.py @@ -1,6 +1,3 @@ -# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed -# under the Apache License Version 2.0, see - from typing import Optional, Tuple import torch From 99abbb18c22f29ac5b9e492fa30bf1cdbeec1d30 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Wed, 8 May 2024 07:27:06 +0200 Subject: [PATCH 19/71] removing previous change --- sbi/utils/sbiutils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index 70635fa44..58146ccae 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -627,7 +627,6 @@ def within_support(distribution: Any, samples: Tensor) -> Tensor: try: sample_check = distribution.support.check(samples) if sample_check.shape == samples.shape: - # With new shapeing conventions we need dim=-2 sample_check = torch.all(sample_check, dim=-1) return sample_check From ef9e99c9eca76f3d7c30663feb6bb42beccac2d0 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Wed, 8 May 2024 07:33:04 +0200 Subject: [PATCH 20/71] removing some artifacts --- sbi/inference/potentials/likelihood_based_potential.py | 5 ----- sbi/samplers/rejection/rejection.py | 5 +++-- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py index a9e94e7e2..eab36e91f 100644 --- a/sbi/inference/potentials/likelihood_based_potential.py +++ b/sbi/inference/potentials/likelihood_based_potential.py @@ -123,8 +123,6 @@ def _log_likelihoods_over_trials( log_likelihood_trial_sum: log likelihood for each parameter, summed over all batch entries (iid trials) in `x`. """ - # print("x", x.shape) - # print("theta", theta.shape) # Shape of `x` is (iid_dim, *event_shape). x = reshape_to_sample_batch_event( x, event_shape=x.shape[1:], leading_is_sample=True @@ -148,9 +146,6 @@ def _log_likelihoods_over_trials( # `DensityEstimator.log_prob`. theta = reshape_to_batch_event(theta, event_shape=theta.shape[1:]) - # print("After reshape theta: ", theta.shape) - # print("After reshape x: ", x.shape) - # Calculate likelihood in one batch. with torch.set_grad_enabled(track_gradients): log_likelihood_trial_batch = estimator.log_prob(x, condition=theta) diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 09df9480d..125f55fd7 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -179,7 +179,7 @@ def log_prob(self, theta: Tensor, **kwargs) -> Tensor: # When in case of leakage a batch size was used there could be too many samples. samples = torch.cat(accepted)[:num_samples] - # print(samples.shape) + assert ( samples.shape[0] == num_samples ), "Number of accepted samples must match required samples." @@ -277,12 +277,13 @@ def accept_reject_sample( ) # SNPE-style rejection-sampling when the proposal is the neural net. are_accepted = accept_reject_fn(candidates) + # Reshape necessary in certain cases which do not follow the shape conventions + # of the "DensityEstimator" class. are_accepted = are_accepted.reshape(sampling_batch_size, num_xos) candidates_to_reject = candidates.reshape( sampling_batch_size, num_xos, *candidates.shape[1:] ) num_accepted = are_accepted.sum(dim=0).min().item() - # print(are_accepted.shape) for i in range(num_xos): accepted[i].append(candidates_to_reject[are_accepted[:, i], i]) From 5eb1007873a0fdacc980a243391beb4f9f6b8d4a Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Wed, 8 May 2024 07:41:17 +0200 Subject: [PATCH 21/71] revert wierd change --- sbi/neural_nets/density_estimators/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sbi/neural_nets/density_estimators/base.py b/sbi/neural_nets/density_estimators/base.py index 252c850bc..b3b83567c 100644 --- a/sbi/neural_nets/density_estimators/base.py +++ b/sbi/neural_nets/density_estimators/base.py @@ -1,3 +1,6 @@ +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Apache License Version 2.0, see + from typing import Optional, Tuple import torch From 82127ab2b403e889b3d0c2f0fa1968f17f2f52fe Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Wed, 8 May 2024 08:14:10 +0200 Subject: [PATCH 22/71] docs and tests --- sbi/inference/posteriors/direct_posterior.py | 33 +++++++++++++++++--- sbi/samplers/rejection/rejection.py | 3 +- tests/posterior_nn_test.py | 31 ++++++++++++++++++ 3 files changed, 62 insertions(+), 5 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index f4a0aa6e3..6c8c12579 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -257,6 +257,35 @@ def log_prob_batched( track_gradients: bool = False, leakage_correction_params: Optional[dict] = None, ) -> Tensor: + """Returns the log-probabilities of the posteriors $p(\theta_1|x_1),..., \ + p(\theta_B|x_B)$. + + Args: + theta: Batch of parameters $\theta$ of shape \ + `(*sample_shape, batch_dim, *theta_shape)`. + x: Batch of observations $x$ of shape \ + `(batch_dim, *condition_shape)`. + norm_posterior: Whether to enforce a normalized posterior density. + Renormalization of the posterior is useful when some + probability falls out or leaks out of the prescribed prior support. + The normalizing factor is calculated via rejection sampling, so if you + need speedier but unnormalized log posterior estimates set here + `norm_posterior=False`. The returned log posterior is set to + -∞ outside of the prior support regardless of this setting. + track_gradients: Whether the returned tensor supports tracking gradients. + This can be helpful for e.g. sensitivity analysis, but increases memory + consumption. + leakage_correction_params: A `dict` of keyword arguments to override the + default values of `leakage_correction()`. Possible options are: + `num_rejection_samples`, `force_update`, `show_progress_bars`, and + `rejection_sampling_batch_size`. + These parameters only have an effect if `norm_posterior=True`. + + Returns: + `(len(θ), B)`-shaped log posterior probability $\\log p(\theta|x)$\\ for θ \ + in the support of the prior, -∞ (corresponding to 0 probability) outside. + """ + theta = ensure_theta_batched(torch.as_tensor(theta)) event_shape = self.posterior_estimator.input_shape theta_density_estimator = reshape_to_sample_batch_event( @@ -266,7 +295,6 @@ def log_prob_batched( x, event_shape=self.posterior_estimator.condition_shape ) - print(theta_density_estimator.shape, x_density_estimator.shape) self.posterior_estimator.eval() with torch.set_grad_enabled(track_gradients): @@ -274,9 +302,6 @@ def log_prob_batched( unnorm_log_prob = self.posterior_estimator.log_prob( theta_density_estimator, condition=x_density_estimator ) - # `log_prob` supports only a single observation (i.e. `batchsize==1`). - # We now remove this additional dimension. - unnorm_log_prob = unnorm_log_prob.squeeze(dim=1) # Force probability to be zero outside prior support. in_prior_support = within_support(self.prior, theta) diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 125f55fd7..d9cd4094b 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -281,8 +281,9 @@ def accept_reject_sample( # of the "DensityEstimator" class. are_accepted = are_accepted.reshape(sampling_batch_size, num_xos) candidates_to_reject = candidates.reshape( - sampling_batch_size, num_xos, *candidates.shape[1:] + sampling_batch_size, num_xos, *candidates.shape[candidates.ndim - 1 :] ) + num_accepted = are_accepted.sum(dim=0).min().item() for i in range(num_xos): accepted[i].append(candidates_to_reject[are_accepted[:, i], i]) diff --git a/tests/posterior_nn_test.py b/tests/posterior_nn_test.py index 33f4c29e4..593660341 100644 --- a/tests/posterior_nn_test.py +++ b/tests/posterior_nn_test.py @@ -49,3 +49,34 @@ def test_log_prob_with_different_x(snpe_method: type, x_o_batch_dim: bool): ).set_default_x(x_o) samples = posterior.sample((10,)) _ = posterior.log_prob(samples) + + +@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) +@pytest.mark.parametrize( + "x_o_batch_dim", + ( + 0, + 1, + 2, + ), +) +def test_batched_sample_log_prob_with_different_x( + snpe_method: type, x_o_batch_dim: bool +): + num_dim = 2 + + prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim)) + simulator = diagonal_linear_gaussian + + inference = snpe_method(prior=prior) + theta, x = simulate_for_sbi(simulator, prior, 1000) + posterior_estimator = inference.append_simulations(theta, x).train(max_num_epochs=3) + + x_o = ones(num_dim) if x_o_batch_dim == 0 else ones(x_o_batch_dim, num_dim) + + posterior = DirectPosterior(posterior_estimator=posterior_estimator, prior=prior) + + samples = posterior.sample_batched((10,), x_o) + batched_log_probs = posterior.log_prob_batched(samples, x_o) + + assert batched_log_probs.shape == (10, max(x_o_batch_dim, 1)) From 41617a8646a3cac2cf15d18af18d1f4e5391dbc1 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 14 May 2024 09:09:30 +0200 Subject: [PATCH 23/71] MCMC sample_batched works but not log_prob batched --- .../posteriors/ensemble_posterior.py | 5 +- sbi/inference/posteriors/mcmc_posterior.py | 74 +++++++++++++------ tests/posterior_nn_test.py | 45 +++++++++++ 3 files changed, 100 insertions(+), 24 deletions(-) diff --git a/sbi/inference/posteriors/ensemble_posterior.py b/sbi/inference/posteriors/ensemble_posterior.py index abb2a1a3d..58e5cb53c 100644 --- a/sbi/inference/posteriors/ensemble_posterior.py +++ b/sbi/inference/posteriors/ensemble_posterior.py @@ -186,7 +186,10 @@ def sample_batched( max_sampling_batch_size: int = 10000, show_progress_bars: bool = True, ) -> Tensor: - raise NotImplementedError("This method is not implemented yet.") + raise NotImplementedError( + "Batched sampling is not implemented for \ + EnsemblePosterior." + ) def log_prob( self, diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index f2211a210..112152f30 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -207,6 +207,38 @@ def log_prob( theta.to(self._device), track_gradients=track_gradients ) + def log_prob_batched( + self, theta: Tensor, x: Tensor, track_gradients: bool = False + ) -> Tensor: + r"""Returns the log-probability of theta under the multiple posteriors. + + Args: + theta: Parameters $\theta$. + track_gradients: Whether the returned tensor supports tracking gradients. + This can be helpful for e.g. sensitivity analysis, but increases memory + consumption. + + Returns: + `len($\theta$)`-shaped log-probability. + """ + warn( + """`.log_prob()` is deprecated for methods that can only evaluate the + log-probability up to a normalizing constant. Use `.potential()` + instead.""", + stacklevel=2, + ) + warn("The log-probability is unnormalized!", stacklevel=2) + + self.potential_fn.set_x(x) + print(x.shape) + theta = ensure_theta_batched(torch.as_tensor(theta)) + print(theta.shape) + potential = self.potential_fn( + theta.to(self._device), track_gradients=track_gradients + ) + print(potential) + return potential + def sample( self, sample_shape: Shape = torch.Size(), @@ -350,15 +382,16 @@ def sample( ) else: raise NameError(f"The sampling method {method} is not implemented!") - + print(transformed_samples.shape) samples = self.theta_transform.inv(transformed_samples) + samples = samples.reshape((*sample_shape, -1)) # type: ignore return samples.reshape((*sample_shape, -1)) # type: ignore def sample_batched( self, - sample_shape: Shape = torch.Size(), - x: Optional[Tensor] = None, + sample_shape: Shape, + x: Tensor, method: Optional[str] = None, thin: Optional[int] = None, warmup_steps: Optional[int] = None, @@ -383,7 +416,8 @@ def sample_batched( Returns: Samples from posterior. """ - self.potential_fn.set_x(self._x_else_default_x(x)) + batch_size = x.shape[0] + self.potential_fn.set_x(x) # Replace arguments that were not passed with their default. method = self.method if method is None else method @@ -400,18 +434,21 @@ def sample_batched( ) self.potential_ = self._prepare_potential(method) # type: ignore - print("Getting initial params") + num_chains_extended = batch_size * num_chains initial_params = self._get_initial_params( init_strategy, # type: ignore - num_chains, # type: ignore + num_chains_extended, # type: ignore num_workers, show_progress_bars, **init_strategy_parameters, ) - print("Finished init") - num_samples = torch.Size(sample_shape).numel() - assert method == "slice_np_vectorized" + num_samples = torch.Size(sample_shape).numel() * batch_size + + assert ( + method == "slice_np_vectorized" + ), "Batched sampling only supported for vectorized samplers!" + with torch.set_grad_enabled(False): transformed_samples = self._slice_np_mcmc( num_samples=num_samples, @@ -423,12 +460,10 @@ def sample_batched( num_workers=num_workers, show_progress_bars=show_progress_bars, ) - print("transformed_samples", transformed_samples.shape) - samples = self.theta_transform.inv(transformed_samples) - num_obs = 5 + samples = self.theta_transform.inv(transformed_samples) - return samples.reshape((*sample_shape, num_obs, -1)) # type: ignore + return samples.reshape((*sample_shape, batch_size, -1)) # type: ignore def _build_mcmc_init_fn( self, @@ -586,20 +621,15 @@ def multi_obs_potential(params): # Params are of shape (num_chains * num_obs, event). # We now reshape them to (num_chains, num_obs, event). # params = np.reshape(params, (num_chains, num_obs, -1)) - # print("params", params.shape) - # print("potential_function", potential_function) # `all_potentials` is of shape (num_chains, num_obs). all_potentials = potential_function(params) return all_potentials.flatten() - num_obs = 1 # TODO This will fail for num_obs > 1 in embedding_net_test.py - initial_params = torch.concatenate([initial_params] * num_obs) - posterior_sampler = SliceSamplerMultiChain( init_params=tensor2numpy(initial_params), log_prob_fn=multi_obs_potential, - num_chains=num_chains * num_obs, + num_chains=num_chains, thin=thin, verbose=show_progress_bars, num_workers=num_workers, @@ -618,12 +648,10 @@ def multi_obs_potential(params): self._posterior_sampler = posterior_sampler # Save sample as potential next init (if init_strategy == 'latest_sample'). - self._mcmc_init_params = samples[:, -1, :].reshape( - num_chains, num_obs, dim_samples - ) + self._mcmc_init_params = samples[:, -1, :].reshape(num_chains, dim_samples) # Collect samples from all chains. - samples = samples.reshape(-1, num_obs, dim_samples)[:num_samples] + samples = samples.reshape(-1, dim_samples)[:num_samples] return samples.type(torch.float32).to(self._device) diff --git a/tests/posterior_nn_test.py b/tests/posterior_nn_test.py index 593660341..a7efef04c 100644 --- a/tests/posterior_nn_test.py +++ b/tests/posterior_nn_test.py @@ -8,6 +8,7 @@ from torch.distributions import MultivariateNormal from sbi.inference import ( + SNLE_A, SNPE_A, SNPE_C, DirectPosterior, @@ -79,4 +80,48 @@ def test_batched_sample_log_prob_with_different_x( samples = posterior.sample_batched((10,), x_o) batched_log_probs = posterior.log_prob_batched(samples, x_o) + assert ( + samples.shape == (10, x_o_batch_dim, num_dim) + if x_o_batch_dim > 0 + else (10, num_dim) + ) assert batched_log_probs.shape == (10, max(x_o_batch_dim, 1)) + + +@pytest.mark.mcmc +@pytest.mark.parametrize("snlre_method", [SNLE_A]) +@pytest.mark.parametrize( + "x_o_batch_dim", + ( + 0, + 1, + 2, + ), +) +def test_batched_mcmc_sample_log_prob_with_different_x( + snlre_method: type, x_o_batch_dim: bool, mcmc_params_fast: dict +): + num_dim = 2 + + prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim)) + simulator = diagonal_linear_gaussian + + inference = snlre_method(prior=prior) + theta, x = simulate_for_sbi(simulator, prior, 1000) + _ = inference.append_simulations(theta, x).train(max_num_epochs=3) + + x_o = ones(num_dim) if x_o_batch_dim == 0 else ones(x_o_batch_dim, num_dim) + + posterior = inference.build_posterior( + mcmc_method="slice_np_vectorized", mcmc_parameters=mcmc_params_fast + ) + + samples = posterior.sample_batched((10,), x_o) + # batched_log_probs = posterior.log_prob_batched(samples, x_o) + + assert ( + samples.shape == (10, x_o_batch_dim, num_dim) + if x_o_batch_dim > 0 + else (10, num_dim) + ) + # assert batched_log_probs.shape == (10, max(x_o_batch_dim, 1)) From 82951db009d0a7ff113e6c1ba0d95494a33b5252 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 14 May 2024 11:03:06 +0200 Subject: [PATCH 24/71] adding some docs --- sbi/inference/posteriors/direct_posterior.py | 13 +++-- .../posteriors/ensemble_posterior.py | 7 ++- .../posteriors/importance_posterior.py | 5 +- sbi/inference/posteriors/mcmc_posterior.py | 49 ++++--------------- .../posteriors/rejection_posterior.py | 4 +- sbi/inference/posteriors/vi_posterior.py | 6 ++- tests/posterior_nn_test.py | 5 +- 7 files changed, 40 insertions(+), 49 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index 6c8c12579..1c7ef691d 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -142,7 +142,9 @@ def sample_batched( max_sampling_batch_size: int = 10_000, show_progress_bars: bool = True, ) -> Tensor: - r"""Return samples from posterior $p(\theta|x)$ given multiple observations. + r"""Given a batch of observations [x_1, ..., x_B] this function samples from + posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized) + manner. Args: sample_shape: Desired shape of samples that are drawn from the posterior @@ -151,6 +153,9 @@ def sample_batched( `batch_dim` corresponds to the number of observations to be drawn. max_sampling_batch_size: Maximum batch size for rejection sampling. show_progress_bars: Whether to show sampling progress monitor. + + Returns: + Samples from the posteriors of shape (*sample_shape, B, *input_shape) """ num_samples = torch.Size(sample_shape).numel() condition_shape = self.posterior_estimator.condition_shape @@ -257,8 +262,10 @@ def log_prob_batched( track_gradients: bool = False, leakage_correction_params: Optional[dict] = None, ) -> Tensor: - """Returns the log-probabilities of the posteriors $p(\theta_1|x_1),..., \ - p(\theta_B|x_B)$. + """Given a batch of observations [x_1, ..., x_B] and a batch of parameters \ + [$\theta_1$,..., $\theta_B$] this function evalautes the log-probabilities \ + of the posterior $p(\theta_1|x_1)$, ..., $p(\theta_B|x_B)$ in a batched \ + (i.e. vectorized) manner. Args: theta: Batch of parameters $\theta$ of shape \ diff --git a/sbi/inference/posteriors/ensemble_posterior.py b/sbi/inference/posteriors/ensemble_posterior.py index 58e5cb53c..e895353c3 100644 --- a/sbi/inference/posteriors/ensemble_posterior.py +++ b/sbi/inference/posteriors/ensemble_posterior.py @@ -186,9 +186,12 @@ def sample_batched( max_sampling_batch_size: int = 10000, show_progress_bars: bool = True, ) -> Tensor: + # TODO Can be implemented in the future, for all base posterior that support + # batched sampling. raise NotImplementedError( - "Batched sampling is not implemented for \ - EnsemblePosterior." + "Batched sampling is not implemented for EnsemblePosterior. \ + Alternatively you can use `sample` in a loop \ + [posterior.sample(theta, x_o) for x_o in x]." ) def log_prob( diff --git a/sbi/inference/posteriors/importance_posterior.py b/sbi/inference/posteriors/importance_posterior.py index 62b295d4d..0b659c84e 100644 --- a/sbi/inference/posteriors/importance_posterior.py +++ b/sbi/inference/posteriors/importance_posterior.py @@ -201,8 +201,11 @@ def sample_batched( max_sampling_batch_size: int = 10000, show_progress_bars: bool = True, ) -> Tensor: + # TODO Can be implemented in the future. raise NotImplementedError( - "Batched sampling is not implemented for ImportanceSamplingPosterior." + "Batched sampling is not implemented for ImportanceSamplingPosterior. \ + Alternatively you can use `sample` in a loop \ + [posterior.sample(theta, x_o) for x_o in x]." ) def _importance_sample( diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index 112152f30..b0cd6d598 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -207,38 +207,6 @@ def log_prob( theta.to(self._device), track_gradients=track_gradients ) - def log_prob_batched( - self, theta: Tensor, x: Tensor, track_gradients: bool = False - ) -> Tensor: - r"""Returns the log-probability of theta under the multiple posteriors. - - Args: - theta: Parameters $\theta$. - track_gradients: Whether the returned tensor supports tracking gradients. - This can be helpful for e.g. sensitivity analysis, but increases memory - consumption. - - Returns: - `len($\theta$)`-shaped log-probability. - """ - warn( - """`.log_prob()` is deprecated for methods that can only evaluate the - log-probability up to a normalizing constant. Use `.potential()` - instead.""", - stacklevel=2, - ) - warn("The log-probability is unnormalized!", stacklevel=2) - - self.potential_fn.set_x(x) - print(x.shape) - theta = ensure_theta_batched(torch.as_tensor(theta)) - print(theta.shape) - potential = self.potential_fn( - theta.to(self._device), track_gradients=track_gradients - ) - print(potential) - return potential - def sample( self, sample_shape: Shape = torch.Size(), @@ -382,7 +350,7 @@ def sample( ) else: raise NameError(f"The sampling method {method} is not implemented!") - print(transformed_samples.shape) + samples = self.theta_transform.inv(transformed_samples) samples = samples.reshape((*sample_shape, -1)) # type: ignore @@ -402,19 +370,22 @@ def sample_batched( mp_context: Optional[str] = None, show_progress_bars: bool = True, ) -> Tensor: - r"""Return samples from posterior distribution $p(\theta|x)$ with MCMC. + r"""Given a batch of observations [x_1, ..., x_B] this function samples from + posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized) + manner. Check the `__init__()` method for a description of all arguments as well as their default values. Args: - sample_shape: Desired shape of samples that are drawn from posterior. If - sample_shape is multidimensional we simply draw `sample_shape.numel()` - samples and then reshape into the desired shape. + sample_shape: Desired shape of samples that are drawn from the posterior + given every observation. + x: A batch of observations, of shape `(batch_dim, event_shape_x)`. + `batch_dim` corresponds to the number of observations to be drawn. show_progress_bars: Whether to show sampling progress monitor. Returns: - Samples from posterior. + Samples from the posteriors of shape (*sample_shape, B, *input_shape) """ batch_size = x.shape[0] self.potential_fn.set_x(x) @@ -638,9 +609,7 @@ def multi_obs_potential(params): warmup_ = warmup_steps * thin num_samples_ = ceil((num_samples * thin) / num_chains) # Run mcmc including warmup - print("Start run") samples = posterior_sampler.run(warmup_ + num_samples_) - print("Finish run") samples = samples[:, warmup_steps:, :] # discard warmup steps samples = torch.from_numpy(samples) # chains x samples x dim diff --git a/sbi/inference/posteriors/rejection_posterior.py b/sbi/inference/posteriors/rejection_posterior.py index 5eb53497b..549942ce2 100644 --- a/sbi/inference/posteriors/rejection_posterior.py +++ b/sbi/inference/posteriors/rejection_posterior.py @@ -175,7 +175,9 @@ def sample_batched( show_progress_bars: bool = True, ) -> Tensor: raise NotImplementedError( - "Batched sampling is not supported for rejection sampling." + "Batched sampling is not implemented for RejectionPosterior. \ + Alternatively you can use `sample` in a loop \ + [posterior.sample(theta, x_o) for x_o in x]." ) def map( diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py index fd89a5654..b0e7bcf8f 100644 --- a/sbi/inference/posteriors/vi_posterior.py +++ b/sbi/inference/posteriors/vi_posterior.py @@ -303,7 +303,11 @@ def sample_batched( max_sampling_batch_size: int = 10000, show_progress_bars: bool = True, ) -> Tensor: - raise NotImplementedError("Batched sampling is not supported for VIPosterior.") + raise NotImplementedError( + "Batched sampling is not implemented for ImportanceSamplingPosterior. \ + Alternatively you can use `sample` in a loop \ + [posterior.sample(theta, x_o) for x_o in x]." + ) def log_prob( self, diff --git a/tests/posterior_nn_test.py b/tests/posterior_nn_test.py index a7efef04c..dfe02a05d 100644 --- a/tests/posterior_nn_test.py +++ b/tests/posterior_nn_test.py @@ -117,7 +117,10 @@ def test_batched_mcmc_sample_log_prob_with_different_x( ) samples = posterior.sample_batched((10,), x_o) - # batched_log_probs = posterior.log_prob_batched(samples, x_o) + print(x_o.shape) + print(samples.shape) + batched_log_probs = posterior.log_prob_batched(samples, x_o) + print(batched_log_probs.shape) assert ( samples.shape == (10, x_o_batch_dim, num_dim) From c5fac1d5f39b6d42e5f69e9ecad8e776f1a3cdd2 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 14 May 2024 11:06:01 +0200 Subject: [PATCH 25/71] batch_log_prob for MCMC requires at best changes for potential -> removed --- tests/posterior_nn_test.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/posterior_nn_test.py b/tests/posterior_nn_test.py index dfe02a05d..d005f4412 100644 --- a/tests/posterior_nn_test.py +++ b/tests/posterior_nn_test.py @@ -84,8 +84,8 @@ def test_batched_sample_log_prob_with_different_x( samples.shape == (10, x_o_batch_dim, num_dim) if x_o_batch_dim > 0 else (10, num_dim) - ) - assert batched_log_probs.shape == (10, max(x_o_batch_dim, 1)) + ), "Sample shape wrong" + assert batched_log_probs.shape == (10, max(x_o_batch_dim, 1)), "logprob shape wrong" @pytest.mark.mcmc @@ -117,14 +117,9 @@ def test_batched_mcmc_sample_log_prob_with_different_x( ) samples = posterior.sample_batched((10,), x_o) - print(x_o.shape) - print(samples.shape) - batched_log_probs = posterior.log_prob_batched(samples, x_o) - print(batched_log_probs.shape) assert ( samples.shape == (10, x_o_batch_dim, num_dim) if x_o_batch_dim > 0 else (10, num_dim) - ) - # assert batched_log_probs.shape == (10, max(x_o_batch_dim, 1)) + ), "Sampel shape wrong" From 0d8242255ff274c5c257a070aa6cd8084612de0b Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Fri, 3 May 2024 16:14:43 +0200 Subject: [PATCH 26/71] intermediate commit --- sbi/inference/posteriors/base_posterior.py | 12 +++++ sbi/inference/posteriors/direct_posterior.py | 51 ++++++++++++++++++++ sbi/samplers/rejection/rejection.py | 19 ++++++-- 3 files changed, 78 insertions(+), 4 deletions(-) diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index 0db66b1cb..f5b9cf62e 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -121,6 +121,18 @@ def sample( """See child classes for docstring.""" pass + @abstractmethod + def amortized_sample( + self, + sample_shape: Shape = torch.Size(), + x: Optional[Tensor] = None, + show_progress_bars: bool = True, + mcmc_method: Optional[str] = None, + mcmc_parameters: Optional[Dict[str, Any]] = None, + ) -> Tensor: + """See child classes for docstring.""" + pass + @property def default_x(self) -> Optional[Tensor]: """Return default x used by `.sample(), .log_prob` as conditioning context.""" diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index fb20a580e..388ed53d0 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -135,6 +135,57 @@ def sample( return samples + def amortized_sample( + self, + sample_shape: Shape = torch.Size(), + x: Optional[Tensor] = None, + max_sampling_batch_size: int = 10_000, + sample_with: Optional[str] = None, + show_progress_bars: bool = True, + ) -> Tensor: + r"""Return samples from posterior $p(\theta|x)$ given multiple observations. + + Args: + sample_shape: Desired shape of samples that are drawn from the posterior + given every observation. + x: A batch of observations, of shape `(batch_dim, event_shape_x)`. + `batch_dim` corresponds to the number of observations to be drawn. + sample_with: This argument only exists to keep backward-compatibility with + `sbi` v0.17.2 or older. If it is set, we instantly raise an error. + show_progress_bars: Whether to show sampling progress monitor. + """ + + num_samples = torch.Size(sample_shape).numel() + # x = self._x_else_default_x(x) + x = reshape_to_batch_event( + x, event_shape=self.posterior_estimator.condition_shape + ) + + max_sampling_batch_size = ( + self.max_sampling_batch_size + if max_sampling_batch_size is None + else max_sampling_batch_size + ) + + if sample_with is not None: + raise ValueError( + f"You set `sample_with={sample_with}`. As of sbi v0.18.0, setting " + f"`sample_with` is no longer supported. You have to rerun " + f"`.build_posterior(sample_with={sample_with}).`" + ) + + samples = accept_reject_sample( + proposal=self.posterior_estimator, + accept_reject_fn=lambda theta: within_support(self.prior, theta), + num_samples=num_samples, + show_progress_bars=show_progress_bars, + max_sampling_batch_size=max_sampling_batch_size, + proposal_sampling_kwargs={"condition": x}, + alternative_method="build_posterior(..., sample_with='mcmc')", + )[0] + + return samples + def log_prob( self, theta: Tensor, diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 5c78cd8dd..791ffcd98 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -253,8 +253,11 @@ def accept_reject_sample( total=num_samples, desc=f"Drawing {num_samples} posterior samples", ) + print("Bullas;fjkdsafj;dlsjfldsaj") num_sampled_total, num_remaining = 0, num_samples + num_xo = proposal_sampling_kwargs["condition"].shape[0] + accepted_every_obs = [torch.tensor((0, 2)) for _ in range(num_xo)] accepted, acceptance_rate = [], float("Nan") leakage_warning_raised = False # Ruff suggestion @@ -272,15 +275,23 @@ def accept_reject_sample( # SNPE-style rejection-sampling when the proposal is the neural net. are_accepted = accept_reject_fn(candidates) - samples = candidates[are_accepted] - accepted.append(samples) + print("are_accepted", are_accepted.shape) + for obs_index in range(num_xo): + accepted = candidates[are_accepted[:, obs_index], obs_index] + print("accepted", accepted.shape) + print("accepted_every_obs[obs_index]", accepted_every_obs[obs_index].shape) + here = accepted_every_obs[obs_index] + print("here", here.shape) + print("acc", accepted.shape) + accepted_every_obs[obs_index] = torch.cat([accepted_every_obs[obs_index], accepted], dim=0) + lowest_num_accepted = min(len(s) for s in accepted_every_obs) + num_remaining = num_samples - lowest_num_accepted # Update. # Note: For any condition of shape (*batch_shape, *condition_shape), the # samples will be of shape(*batch_shape, sampling_batch_size, d) and hence work # in dim = -2. num_sampled_total += sampling_batch_size - num_remaining -= samples.shape[-2] pbar.update(samples.shape[-2]) # To avoid endless sampling when leakage is high, we raise a warning if the @@ -331,7 +342,7 @@ def accept_reject_sample( pbar.close() # When in case of leakage a batch size was used there could be too many samples. - samples = torch.cat(accepted, dim=-2)[..., :num_samples, :] + samples = samples[..., :num_samples, :] assert ( samples.shape[-2] == num_samples ), "Number of accepted samples must match required samples." From 57cfde3a8c8079c719d2734c3738515440898fb5 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Fri, 3 May 2024 17:05:01 +0200 Subject: [PATCH 27/71] make autoreload work --- sbi/inference/posteriors/direct_posterior.py | 4 ++-- sbi/samplers/rejection/rejection.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index 388ed53d0..5928c52b7 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -16,7 +16,7 @@ reshape_to_batch_event, reshape_to_sample_batch_event, ) -from sbi.samplers.rejection.rejection import accept_reject_sample +from sbi.samplers.rejection import rejection from sbi.sbi_types import Shape from sbi.utils import check_prior, within_support from sbi.utils.torchutils import ensure_theta_batched @@ -123,7 +123,7 @@ def sample( f"`.build_posterior(sample_with={sample_with}).`" ) - samples = accept_reject_sample( + samples = rejection.accept_reject_sample( proposal=self.posterior_estimator, accept_reject_fn=lambda theta: within_support(self.prior, theta), num_samples=num_samples, diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 791ffcd98..149446844 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -253,7 +253,6 @@ def accept_reject_sample( total=num_samples, desc=f"Drawing {num_samples} posterior samples", ) - print("Bullas;fjkdsafj;dlsjfldsaj") num_sampled_total, num_remaining = 0, num_samples num_xo = proposal_sampling_kwargs["condition"].shape[0] From de5d647ff2d139dc26524b496de922d3c5cd487d Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Sun, 5 May 2024 17:09:56 +0200 Subject: [PATCH 28/71] `amortized_sample` works for MCMCPosterior --- sbi/inference/posteriors/direct_posterior.py | 4 +- sbi/inference/posteriors/mcmc_posterior.py | 102 +++++++++++++++++- .../potentials/likelihood_based_potential.py | 5 + sbi/samplers/rejection/rejection.py | 28 ++--- 4 files changed, 120 insertions(+), 19 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index 5928c52b7..ed131e218 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -133,7 +133,7 @@ def sample( alternative_method="build_posterior(..., sample_with='mcmc')", )[0] - return samples + return samples[:, 0] # Remove batch dimension. def amortized_sample( self, @@ -174,7 +174,7 @@ def amortized_sample( f"`.build_posterior(sample_with={sample_with}).`" ) - samples = accept_reject_sample( + samples = rejection.accept_reject_sample( proposal=self.posterior_estimator, accept_reject_fn=lambda theta: within_support(self.prior, theta), num_samples=num_samples, diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index 5ef9f882a..c47db3e7b 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -17,6 +17,7 @@ from torch import Tensor from torch import multiprocessing as mp from tqdm.auto import tqdm +import numpy as np from sbi.inference.posteriors.base_posterior import NeuralPosterior from sbi.inference.potentials.base_potential import BasePotential @@ -355,6 +356,83 @@ def sample( return samples.reshape((*sample_shape, -1)) # type: ignore + + def amortized_sample( + self, + sample_shape: Shape = torch.Size(), + x: Optional[Tensor] = None, + method: Optional[str] = None, + thin: Optional[int] = None, + warmup_steps: Optional[int] = None, + num_chains: Optional[int] = None, + init_strategy: Optional[str] = None, + init_strategy_parameters: Optional[Dict[str, Any]] = None, + num_workers: Optional[int] = None, + mp_context: Optional[str] = None, + show_progress_bars: bool = True, + ) -> Tensor: + r"""Return samples from posterior distribution $p(\theta|x)$ with MCMC. + + Check the `__init__()` method for a description of all arguments as well as + their default values. + + Args: + sample_shape: Desired shape of samples that are drawn from posterior. If + sample_shape is multidimensional we simply draw `sample_shape.numel()` + samples and then reshape into the desired shape. + show_progress_bars: Whether to show sampling progress monitor. + + Returns: + Samples from posterior. + """ + self.potential_fn.set_x(self._x_else_default_x(x)) + + # Replace arguments that were not passed with their default. + method = self.method if method is None else method + thin = self.thin if thin is None else thin + warmup_steps = self.warmup_steps if warmup_steps is None else warmup_steps + num_chains = self.num_chains if num_chains is None else num_chains + init_strategy = self.init_strategy if init_strategy is None else init_strategy + num_workers = self.num_workers if num_workers is None else num_workers + mp_context = self.mp_context if mp_context is None else mp_context + init_strategy_parameters = ( + self.init_strategy_parameters + if init_strategy_parameters is None + else init_strategy_parameters + ) + self.potential_ = self._prepare_potential(method) # type: ignore + + print("Getting initial params") + initial_params = self._get_initial_params( + init_strategy, # type: ignore + num_chains, # type: ignore + num_workers, + show_progress_bars, + **init_strategy_parameters, + ) + print("Finished init") + num_samples = torch.Size(sample_shape).numel() + + assert method == "slice_np_vectorized" + with torch.set_grad_enabled(False): + transformed_samples = self._slice_np_mcmc( + num_samples=num_samples, + potential_function=self.potential_, + initial_params=initial_params, + thin=thin, # type: ignore + warmup_steps=warmup_steps, # type: ignore + vectorized=(method == "slice_np_vectorized"), + num_workers=num_workers, + show_progress_bars=show_progress_bars, + ) + print("transformed_samples", transformed_samples.shape) + samples = self.theta_transform.inv(transformed_samples) + print("samples", samples.shape) + num_obs = 5 + + return samples.reshape((*sample_shape, num_obs, -1)) # type: ignore + + def _build_mcmc_init_fn( self, proposal: Any, @@ -507,10 +585,24 @@ def _slice_np_mcmc( else: SliceSamplerMultiChain = SliceSamplerVectorized + def multi_obs_potential(params): + # Params are of shape (num_chains * num_obs, event). + # We now reshape them to (num_chains, num_obs, event). + # params = np.reshape(params, (num_chains, num_obs, -1)) + # print("params", params.shape) + # print("potential_function", potential_function) + + # `all_potentials` is of shape (num_chains, num_obs). + all_potentials = potential_function(params) + return all_potentials.flatten() + + num_obs = 5 + initial_params = torch.concatenate([initial_params] * num_obs) + posterior_sampler = SliceSamplerMultiChain( init_params=tensor2numpy(initial_params), - log_prob_fn=potential_function, - num_chains=num_chains, + log_prob_fn=multi_obs_potential, + num_chains=num_chains * num_obs, thin=thin, verbose=show_progress_bars, num_workers=num_workers, @@ -519,7 +611,9 @@ def _slice_np_mcmc( warmup_ = warmup_steps * thin num_samples_ = ceil((num_samples * thin) / num_chains) # Run mcmc including warmup + print("Start run") samples = posterior_sampler.run(warmup_ + num_samples_) + print("Finish run") samples = samples[:, warmup_steps:, :] # discard warmup steps samples = torch.from_numpy(samples) # chains x samples x dim @@ -527,10 +621,10 @@ def _slice_np_mcmc( self._posterior_sampler = posterior_sampler # Save sample as potential next init (if init_strategy == 'latest_sample'). - self._mcmc_init_params = samples[:, -1, :].reshape(num_chains, dim_samples) + self._mcmc_init_params = samples[:, -1, :].reshape(num_chains, num_obs, dim_samples) # Collect samples from all chains. - samples = samples.reshape(-1, dim_samples)[:num_samples] + samples = samples.reshape(-1, num_obs, dim_samples)[:num_samples] return samples.type(torch.float32).to(self._device) diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py index eab36e91f..a9e94e7e2 100644 --- a/sbi/inference/potentials/likelihood_based_potential.py +++ b/sbi/inference/potentials/likelihood_based_potential.py @@ -123,6 +123,8 @@ def _log_likelihoods_over_trials( log_likelihood_trial_sum: log likelihood for each parameter, summed over all batch entries (iid trials) in `x`. """ + # print("x", x.shape) + # print("theta", theta.shape) # Shape of `x` is (iid_dim, *event_shape). x = reshape_to_sample_batch_event( x, event_shape=x.shape[1:], leading_is_sample=True @@ -146,6 +148,9 @@ def _log_likelihoods_over_trials( # `DensityEstimator.log_prob`. theta = reshape_to_batch_event(theta, event_shape=theta.shape[1:]) + # print("After reshape theta: ", theta.shape) + # print("After reshape x: ", x.shape) + # Calculate likelihood in one batch. with torch.set_grad_enabled(track_gradients): log_likelihood_trial_batch = estimator.log_prob(x, condition=theta) diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 149446844..ed9625b97 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -236,7 +236,8 @@ def accept_reject_sample( `rejection_sample()`. Warn if not empty. Returns: - Accepted samples and acceptance rate as scalar Tensor. + Accepted samples of shape `(sample_dim, batch_dim, *event_shape)`, and + worst-case acceptance rate as scalar Tensor. """ if kwargs: @@ -255,8 +256,12 @@ def accept_reject_sample( ) num_sampled_total, num_remaining = 0, num_samples - num_xo = proposal_sampling_kwargs["condition"].shape[0] - accepted_every_obs = [torch.tensor((0, 2)) for _ in range(num_xo)] + if "condition" in list(proposal_sampling_kwargs.keys()): + num_xo = proposal_sampling_kwargs["condition"].shape[0] + else: + num_xo = 1 + + accepted_every_obs = [torch.tensor(()) for _ in range(num_xo)] accepted, acceptance_rate = [], float("Nan") leakage_warning_raised = False # Ruff suggestion @@ -274,14 +279,8 @@ def accept_reject_sample( # SNPE-style rejection-sampling when the proposal is the neural net. are_accepted = accept_reject_fn(candidates) - print("are_accepted", are_accepted.shape) for obs_index in range(num_xo): accepted = candidates[are_accepted[:, obs_index], obs_index] - print("accepted", accepted.shape) - print("accepted_every_obs[obs_index]", accepted_every_obs[obs_index].shape) - here = accepted_every_obs[obs_index] - print("here", here.shape) - print("acc", accepted.shape) accepted_every_obs[obs_index] = torch.cat([accepted_every_obs[obs_index], accepted], dim=0) lowest_num_accepted = min(len(s) for s in accepted_every_obs) num_remaining = num_samples - lowest_num_accepted @@ -291,7 +290,7 @@ def accept_reject_sample( # samples will be of shape(*batch_shape, sampling_batch_size, d) and hence work # in dim = -2. num_sampled_total += sampling_batch_size - pbar.update(samples.shape[-2]) + pbar.update(num_samples - num_remaining) # To avoid endless sampling when leakage is high, we raise a warning if the # acceptance rate is too low after the first 1_000 samples. @@ -340,10 +339,13 @@ def accept_reject_sample( pbar.close() + for obs_index in range(num_xo): + accepted_every_obs[obs_index] = accepted_every_obs[obs_index][:num_samples] + + accepted_every_obs = torch.stack(accepted_every_obs) # When in case of leakage a batch size was used there could be too many samples. - samples = samples[..., :num_samples, :] assert ( - samples.shape[-2] == num_samples + accepted_every_obs.shape[-2] == num_samples ), "Number of accepted samples must match required samples." - return samples, as_tensor(acceptance_rate) + return torch.permute(accepted_every_obs, (1, 0, -1)), as_tensor(acceptance_rate) From f8b6604017f9e8d333e5e77022c41c2f43491266 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Fri, 3 May 2024 16:14:43 +0200 Subject: [PATCH 29/71] intermediate commit --- sbi/inference/posteriors/direct_posterior.py | 51 ++++++++++++++++++++ sbi/samplers/rejection/rejection.py | 9 ++-- 2 files changed, 56 insertions(+), 4 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index ed131e218..a54826b91 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -186,6 +186,57 @@ def amortized_sample( return samples + def amortized_sample( + self, + sample_shape: Shape = torch.Size(), + x: Optional[Tensor] = None, + max_sampling_batch_size: int = 10_000, + sample_with: Optional[str] = None, + show_progress_bars: bool = True, + ) -> Tensor: + r"""Return samples from posterior $p(\theta|x)$ given multiple observations. + + Args: + sample_shape: Desired shape of samples that are drawn from the posterior + given every observation. + x: A batch of observations, of shape `(batch_dim, event_shape_x)`. + `batch_dim` corresponds to the number of observations to be drawn. + sample_with: This argument only exists to keep backward-compatibility with + `sbi` v0.17.2 or older. If it is set, we instantly raise an error. + show_progress_bars: Whether to show sampling progress monitor. + """ + + num_samples = torch.Size(sample_shape).numel() + # x = self._x_else_default_x(x) + x = reshape_to_batch_event( + x, event_shape=self.posterior_estimator.condition_shape + ) + + max_sampling_batch_size = ( + self.max_sampling_batch_size + if max_sampling_batch_size is None + else max_sampling_batch_size + ) + + if sample_with is not None: + raise ValueError( + f"You set `sample_with={sample_with}`. As of sbi v0.18.0, setting " + f"`sample_with` is no longer supported. You have to rerun " + f"`.build_posterior(sample_with={sample_with}).`" + ) + + samples = accept_reject_sample( + proposal=self.posterior_estimator, + accept_reject_fn=lambda theta: within_support(self.prior, theta), + num_samples=num_samples, + show_progress_bars=show_progress_bars, + max_sampling_batch_size=max_sampling_batch_size, + proposal_sampling_kwargs={"condition": x}, + alternative_method="build_posterior(..., sample_with='mcmc')", + )[0] + + return samples + def log_prob( self, theta: Tensor, diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index ed9625b97..72737063f 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -254,6 +254,10 @@ def accept_reject_sample( total=num_samples, desc=f"Drawing {num_samples} posterior samples", ) + # Ruff suggestion + if proposal_sampling_kwargs is None: + proposal_sampling_kwargs = {} + num_sampled_total, num_remaining = 0, num_samples if "condition" in list(proposal_sampling_kwargs.keys()): @@ -264,10 +268,7 @@ def accept_reject_sample( accepted_every_obs = [torch.tensor(()) for _ in range(num_xo)] accepted, acceptance_rate = [], float("Nan") leakage_warning_raised = False - # Ruff suggestion - if proposal_sampling_kwargs is None: - proposal_sampling_kwargs = {} - + # To cover cases with few samples without leakage: sampling_batch_size = min(num_samples, max_sampling_batch_size) while num_remaining > 0: From 1dcf882bb335268fee6e6e3d9f748e486a1d4aa0 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Fri, 3 May 2024 17:05:01 +0200 Subject: [PATCH 30/71] make autoreload work --- sbi/inference/posteriors/direct_posterior.py | 53 +------------------- sbi/samplers/rejection/rejection.py | 8 +++ 2 files changed, 9 insertions(+), 52 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index a54826b91..c7f1b6620 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -133,57 +133,6 @@ def sample( alternative_method="build_posterior(..., sample_with='mcmc')", )[0] - return samples[:, 0] # Remove batch dimension. - - def amortized_sample( - self, - sample_shape: Shape = torch.Size(), - x: Optional[Tensor] = None, - max_sampling_batch_size: int = 10_000, - sample_with: Optional[str] = None, - show_progress_bars: bool = True, - ) -> Tensor: - r"""Return samples from posterior $p(\theta|x)$ given multiple observations. - - Args: - sample_shape: Desired shape of samples that are drawn from the posterior - given every observation. - x: A batch of observations, of shape `(batch_dim, event_shape_x)`. - `batch_dim` corresponds to the number of observations to be drawn. - sample_with: This argument only exists to keep backward-compatibility with - `sbi` v0.17.2 or older. If it is set, we instantly raise an error. - show_progress_bars: Whether to show sampling progress monitor. - """ - - num_samples = torch.Size(sample_shape).numel() - # x = self._x_else_default_x(x) - x = reshape_to_batch_event( - x, event_shape=self.posterior_estimator.condition_shape - ) - - max_sampling_batch_size = ( - self.max_sampling_batch_size - if max_sampling_batch_size is None - else max_sampling_batch_size - ) - - if sample_with is not None: - raise ValueError( - f"You set `sample_with={sample_with}`. As of sbi v0.18.0, setting " - f"`sample_with` is no longer supported. You have to rerun " - f"`.build_posterior(sample_with={sample_with}).`" - ) - - samples = rejection.accept_reject_sample( - proposal=self.posterior_estimator, - accept_reject_fn=lambda theta: within_support(self.prior, theta), - num_samples=num_samples, - show_progress_bars=show_progress_bars, - max_sampling_batch_size=max_sampling_batch_size, - proposal_sampling_kwargs={"condition": x}, - alternative_method="build_posterior(..., sample_with='mcmc')", - )[0] - return samples def amortized_sample( @@ -225,7 +174,7 @@ def amortized_sample( f"`.build_posterior(sample_with={sample_with}).`" ) - samples = accept_reject_sample( + samples = rejection.accept_reject_sample( proposal=self.posterior_estimator, accept_reject_fn=lambda theta: within_support(self.prior, theta), num_samples=num_samples, diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 72737063f..7ed68390c 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -254,10 +254,18 @@ def accept_reject_sample( total=num_samples, desc=f"Drawing {num_samples} posterior samples", ) + # Ruff suggestion if proposal_sampling_kwargs is None: proposal_sampling_kwargs = {} + num_sampled_total, num_remaining = 0, num_samples + num_xo = proposal_sampling_kwargs["condition"].shape[0] + accepted_every_obs = [torch.tensor((0, 2)) for _ in range(num_xo)] + accepted, acceptance_rate = [], float("Nan") + leakage_warning_raised = False + + num_sampled_total, num_remaining = 0, num_samples if "condition" in list(proposal_sampling_kwargs.keys()): From 5a31970f341044f142f76552fa68c421958c6ebd Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Sun, 5 May 2024 17:09:56 +0200 Subject: [PATCH 31/71] `amortized_sample` works for MCMCPosterior --- sbi/inference/posteriors/direct_posterior.py | 2 +- sbi/samplers/rejection/rejection.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index c7f1b6620..ed131e218 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -133,7 +133,7 @@ def sample( alternative_method="build_posterior(..., sample_with='mcmc')", )[0] - return samples + return samples[:, 0] # Remove batch dimension. def amortized_sample( self, diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 7ed68390c..e3c7b221a 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -260,8 +260,12 @@ def accept_reject_sample( proposal_sampling_kwargs = {} num_sampled_total, num_remaining = 0, num_samples - num_xo = proposal_sampling_kwargs["condition"].shape[0] - accepted_every_obs = [torch.tensor((0, 2)) for _ in range(num_xo)] + if "condition" in list(proposal_sampling_kwargs.keys()): + num_xo = proposal_sampling_kwargs["condition"].shape[0] + else: + num_xo = 1 + + accepted_every_obs = [torch.tensor(()) for _ in range(num_xo)] accepted, acceptance_rate = [], float("Nan") leakage_warning_raised = False From 871c4ded37d11dcf6e98677fb1fe7ab30752fa87 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Mon, 29 Apr 2024 09:04:20 +0200 Subject: [PATCH 32/71] Base estimator class --- sbi/neural_nets/density_estimators/base.py | 136 ++++++++++++++++----- 1 file changed, 105 insertions(+), 31 deletions(-) diff --git a/sbi/neural_nets/density_estimators/base.py b/sbi/neural_nets/density_estimators/base.py index b3b83567c..6d56fbf64 100644 --- a/sbi/neural_nets/density_estimators/base.py +++ b/sbi/neural_nets/density_estimators/base.py @@ -1,13 +1,116 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Apache License Version 2.0, see +from abc import ABC, abstractmethod from typing import Optional, Tuple import torch from torch import Tensor, nn -class DensityEstimator(nn.Module): +class Estimator(nn.Module, ABC): + r"""Base class for estimators i.e. neural nets estimating a certain quantity that + characterizes a distribution. This for example can be: + - Conditional density estimator of the posterior $p(\theta|x)$. + - Conditional density estimator of the likelihood $p(x|\theta)$. + - Estimator of the density ratio $p(x|\theta)/p(x)$. + - and more ... + """ + + def __init__(self, input_shape: torch.Size, condition_shape: torch.Size) -> None: + r"""Base class for estimators. + + Args: + input_shape: Event shape of the input at which the density is being + evaluated (and which is also the event_shape of samples). + condition_shape: Shape of the condition. If not provided, it will assume a + 1D input. + """ + super().__init__() + self._input_shape = torch.Size(input_shape) + self._condition_shape = torch.Size(condition_shape) + + @property + def input_shape(self) -> torch.Size: + r"""Return the input shape.""" + return self._input_shape + + @property + def condition_shape(self) -> torch.Size: + r"""Return the condition shape.""" + return self._condition_shape + + @abstractmethod + def loss(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: + r"""Return the loss for training the estimator. + + Args: + input: Inputs to evaluate the loss on of shape + `(batch_dim, *input_event_shape)`. + condition: Conditions of shape `(batch_dim, *event_shape_condition)`. + + Returns: + Loss of shape (batch_dim,) + """ + pass + + def _check_condition_shape(self, condition: Tensor): + r"""This method checks whether the condition has the correct shape. + + Args: + condition: Conditions of shape `(batch_dim, *event_shape_condition)`. + + Raises: + ValueError: If the condition has a dimensionality that does not match + the expected input dimensionality. + ValueError: If the shape of the condition does not match the expected + input dimensionality. + """ + if len(condition.shape) < len(self.condition_shape): + raise ValueError( + f"Dimensionality of condition is to small and does not match the\ + expected input dimensionality {len(self.condition_shape)}, as provided\ + by condition_shape." + ) + else: + condition_shape = condition.shape[-len(self.condition_shape) :] + if tuple(condition_shape) != tuple(self.condition_shape): + raise ValueError( + f"Shape of condition {tuple(condition_shape)} does not match the \ + expected input dimensionality {tuple(self.condition_shape)}, as \ + provided by condition_shape. Please reshape it accordingly." + ) + + def _check_input_shape(self, input: Tensor): + r"""This method checks whether the input has the correct shape. + + Args: + input: Inputs to evaluate the log probability on of shape + `(sample_dim_input, batch_dim_input, *event_shape_input)`. + + Raises: + ValueError: If the input has a dimensionality that does not match + the expected input dimensionality. + ValueError: If the shape of the input does not match the expected + input dimensionality. + """ + if len(input.shape) < len(self.input_shape): + raise ValueError( + f"Dimensionality of input is to small and does not match the expected \ + input dimensionality {len(self.input_shape)}, as provided by \ + input_shape." + ) + else: + input_shape = input.shape[-len(self.input_shape) :] + if tuple(input_shape) != tuple(self.input_shape): + raise ValueError( + f"Shape of input {tuple(input_shape)} does not match the expected \ + input dimensionality {tuple(self.input_shape)}, as provided by \ + input_shape. Please reshape it accordingly." + ) + + +class DensityEstimator(Estimator): r"""Base class for density estimators. The density estimator class is a wrapper around neural networks that @@ -34,10 +137,8 @@ def __init__( condition_shape: Shape of the condition. If not provided, it will assume a 1D input. """ - super().__init__() + super().__init__(input_shape, condition_shape) self.net = net - self.input_shape = input_shape - self.condition_shape = condition_shape @property def embedding_net(self) -> Optional[nn.Module]: @@ -111,30 +212,3 @@ def sample_and_log_prob( samples = self.sample(sample_shape, condition, **kwargs) log_probs = self.log_prob(samples, condition, **kwargs) return samples, log_probs - - def _check_condition_shape(self, condition: Tensor): - r"""This method checks whether the condition has the correct shape. - - Args: - condition: Conditions of shape `(batch_dim, *event_shape_condition)`. - - Raises: - ValueError: If the condition has a dimensionality that does not match - the expected input dimensionality. - ValueError: If the shape of the condition does not match the expected - input dimensionality. - """ - if len(condition.shape) < len(self.condition_shape): - raise ValueError( - f"Dimensionality of condition is to small and does not match the\ - expected input dimensionality {len(self.condition_shape)}, as provided\ - by condition_shape." - ) - else: - condition_shape = condition.shape[-len(self.condition_shape) :] - if tuple(condition_shape) != tuple(self.condition_shape): - raise ValueError( - f"Shape of condition {tuple(condition_shape)} does not match the \ - expected input dimensionality {tuple(self.condition_shape)}, as \ - provided by condition_shape. Please reshape it accordingly." - ) From f87d6b695e087e882ec5ac98da503311875e4530 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 7 May 2024 18:39:58 +0200 Subject: [PATCH 33/71] Revert "Merge branch '990-add-sample_batched-and-log_prob_batched-to-posteriors' into amortizedsample" This reverts commit 07084e28fb586d43605dba6786d60c3e48ed96e5, reversing changes made to f16622d552e0dd69b17855bea9d672594e11d8ce. --- sbi/neural_nets/density_estimators/base.py | 136 +++++---------------- 1 file changed, 31 insertions(+), 105 deletions(-) diff --git a/sbi/neural_nets/density_estimators/base.py b/sbi/neural_nets/density_estimators/base.py index 6d56fbf64..b3b83567c 100644 --- a/sbi/neural_nets/density_estimators/base.py +++ b/sbi/neural_nets/density_estimators/base.py @@ -1,116 +1,13 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Apache License Version 2.0, see -from abc import ABC, abstractmethod from typing import Optional, Tuple import torch from torch import Tensor, nn -class Estimator(nn.Module, ABC): - r"""Base class for estimators i.e. neural nets estimating a certain quantity that - characterizes a distribution. This for example can be: - - Conditional density estimator of the posterior $p(\theta|x)$. - - Conditional density estimator of the likelihood $p(x|\theta)$. - - Estimator of the density ratio $p(x|\theta)/p(x)$. - - and more ... - """ - - def __init__(self, input_shape: torch.Size, condition_shape: torch.Size) -> None: - r"""Base class for estimators. - - Args: - input_shape: Event shape of the input at which the density is being - evaluated (and which is also the event_shape of samples). - condition_shape: Shape of the condition. If not provided, it will assume a - 1D input. - """ - super().__init__() - self._input_shape = torch.Size(input_shape) - self._condition_shape = torch.Size(condition_shape) - - @property - def input_shape(self) -> torch.Size: - r"""Return the input shape.""" - return self._input_shape - - @property - def condition_shape(self) -> torch.Size: - r"""Return the condition shape.""" - return self._condition_shape - - @abstractmethod - def loss(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: - r"""Return the loss for training the estimator. - - Args: - input: Inputs to evaluate the loss on of shape - `(batch_dim, *input_event_shape)`. - condition: Conditions of shape `(batch_dim, *event_shape_condition)`. - - Returns: - Loss of shape (batch_dim,) - """ - pass - - def _check_condition_shape(self, condition: Tensor): - r"""This method checks whether the condition has the correct shape. - - Args: - condition: Conditions of shape `(batch_dim, *event_shape_condition)`. - - Raises: - ValueError: If the condition has a dimensionality that does not match - the expected input dimensionality. - ValueError: If the shape of the condition does not match the expected - input dimensionality. - """ - if len(condition.shape) < len(self.condition_shape): - raise ValueError( - f"Dimensionality of condition is to small and does not match the\ - expected input dimensionality {len(self.condition_shape)}, as provided\ - by condition_shape." - ) - else: - condition_shape = condition.shape[-len(self.condition_shape) :] - if tuple(condition_shape) != tuple(self.condition_shape): - raise ValueError( - f"Shape of condition {tuple(condition_shape)} does not match the \ - expected input dimensionality {tuple(self.condition_shape)}, as \ - provided by condition_shape. Please reshape it accordingly." - ) - - def _check_input_shape(self, input: Tensor): - r"""This method checks whether the input has the correct shape. - - Args: - input: Inputs to evaluate the log probability on of shape - `(sample_dim_input, batch_dim_input, *event_shape_input)`. - - Raises: - ValueError: If the input has a dimensionality that does not match - the expected input dimensionality. - ValueError: If the shape of the input does not match the expected - input dimensionality. - """ - if len(input.shape) < len(self.input_shape): - raise ValueError( - f"Dimensionality of input is to small and does not match the expected \ - input dimensionality {len(self.input_shape)}, as provided by \ - input_shape." - ) - else: - input_shape = input.shape[-len(self.input_shape) :] - if tuple(input_shape) != tuple(self.input_shape): - raise ValueError( - f"Shape of input {tuple(input_shape)} does not match the expected \ - input dimensionality {tuple(self.input_shape)}, as provided by \ - input_shape. Please reshape it accordingly." - ) - - -class DensityEstimator(Estimator): +class DensityEstimator(nn.Module): r"""Base class for density estimators. The density estimator class is a wrapper around neural networks that @@ -137,8 +34,10 @@ def __init__( condition_shape: Shape of the condition. If not provided, it will assume a 1D input. """ - super().__init__(input_shape, condition_shape) + super().__init__() self.net = net + self.input_shape = input_shape + self.condition_shape = condition_shape @property def embedding_net(self) -> Optional[nn.Module]: @@ -212,3 +111,30 @@ def sample_and_log_prob( samples = self.sample(sample_shape, condition, **kwargs) log_probs = self.log_prob(samples, condition, **kwargs) return samples, log_probs + + def _check_condition_shape(self, condition: Tensor): + r"""This method checks whether the condition has the correct shape. + + Args: + condition: Conditions of shape `(batch_dim, *event_shape_condition)`. + + Raises: + ValueError: If the condition has a dimensionality that does not match + the expected input dimensionality. + ValueError: If the shape of the condition does not match the expected + input dimensionality. + """ + if len(condition.shape) < len(self.condition_shape): + raise ValueError( + f"Dimensionality of condition is to small and does not match the\ + expected input dimensionality {len(self.condition_shape)}, as provided\ + by condition_shape." + ) + else: + condition_shape = condition.shape[-len(self.condition_shape) :] + if tuple(condition_shape) != tuple(self.condition_shape): + raise ValueError( + f"Shape of condition {tuple(condition_shape)} does not match the \ + expected input dimensionality {tuple(self.condition_shape)}, as \ + provided by condition_shape. Please reshape it accordingly." + ) From dbd01099f328d91b8e098273ae324979a98299eb Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 7 May 2024 08:27:39 +0200 Subject: [PATCH 34/71] fixes current bug! --- sbi/neural_nets/density_estimators/nflows_flow.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sbi/neural_nets/density_estimators/nflows_flow.py b/sbi/neural_nets/density_estimators/nflows_flow.py index 8d6aaba55..3f162493f 100644 --- a/sbi/neural_nets/density_estimators/nflows_flow.py +++ b/sbi/neural_nets/density_estimators/nflows_flow.py @@ -136,7 +136,13 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor: samples = self.net.sample(num_samples, context=condition) samples = samples.transpose(0, 1) - return samples.reshape((*sample_shape, condition_batch_dim, *self.input_shape)) + return samples.reshape( + ( + *sample_shape, + condition_batch_dim, + ) + + self.input_shape + ) def sample_and_log_prob( self, sample_shape: torch.Size, condition: Tensor, **kwargs From 264b6c4d28afc2e66af52663567f8e5ac80df1ad Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 7 May 2024 10:44:15 +0200 Subject: [PATCH 35/71] Added tests --- sbi/neural_nets/density_estimators/nflows_flow.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/sbi/neural_nets/density_estimators/nflows_flow.py b/sbi/neural_nets/density_estimators/nflows_flow.py index 3f162493f..8d6aaba55 100644 --- a/sbi/neural_nets/density_estimators/nflows_flow.py +++ b/sbi/neural_nets/density_estimators/nflows_flow.py @@ -136,13 +136,7 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor: samples = self.net.sample(num_samples, context=condition) samples = samples.transpose(0, 1) - return samples.reshape( - ( - *sample_shape, - condition_batch_dim, - ) - + self.input_shape - ) + return samples.reshape((*sample_shape, condition_batch_dim, *self.input_shape)) def sample_and_log_prob( self, sample_shape: torch.Size, condition: Tensor, **kwargs From 339b57b0435c10ba56232ae7d42d5c756c5dd202 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 7 May 2024 17:25:11 +0200 Subject: [PATCH 36/71] batched_rejection_sampling --- sbi/inference/posteriors/direct_posterior.py | 29 ++++++++++++++++++++ sbi/samplers/rejection/rejection.py | 4 +-- sbi/utils/sbiutils.py | 3 +- 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index ed131e218..db39801bf 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -186,6 +186,35 @@ def amortized_sample( return samples + def sample_batched( + self, + sample_shape: Shape, + x: Tensor, + max_sampling_batch_size: int = 10_000, + show_progress_bars: bool = True, + ) -> Tensor: + num_samples = torch.Size(sample_shape).numel() + condition_shape = self.posterior_estimator.condition_shape + x = reshape_to_batch_event(x, event_shape=condition_shape) + print(x.shape) + max_sampling_batch_size = ( + self.max_sampling_batch_size + if max_sampling_batch_size is None + else max_sampling_batch_size + ) + + samples = accept_reject_sample( + proposal=self.posterior_estimator, + accept_reject_fn=lambda theta: within_support(self.prior, theta), + num_samples=num_samples, + show_progress_bars=show_progress_bars, + max_sampling_batch_size=max_sampling_batch_size, + proposal_sampling_kwargs={"condition": x}, + alternative_method="build_posterior(..., sample_with='mcmc')", + )[0] + + return samples + def log_prob( self, theta: Tensor, diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index e3c7b221a..d21f52d12 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -300,8 +300,8 @@ def accept_reject_sample( # Update. # Note: For any condition of shape (*batch_shape, *condition_shape), the - # samples will be of shape(*batch_shape, sampling_batch_size, d) and hence work - # in dim = -2. + # samples will be of shape(sampling_batch_size,*batch_shape, *event_shape) + # and hence work in dim = 0. num_sampled_total += sampling_batch_size pbar.update(num_samples - num_remaining) diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index 58146ccae..a2b0f102f 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -627,7 +627,8 @@ def within_support(distribution: Any, samples: Tensor) -> Tensor: try: sample_check = distribution.support.check(samples) if sample_check.shape == samples.shape: - sample_check = torch.all(sample_check, dim=-1) + # With new shapeing conventions we need dim=-2 + sample_check = torch.all(sample_check, dim=-2) return sample_check # Falling back to log prob method of either the NeuralPosterior's net, or of a From 676c2719c5bc386586deb9af4e9f8ea2c374ea0a Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 7 May 2024 19:40:59 +0200 Subject: [PATCH 37/71] sample works, try log_prob_batched --- sbi/inference/posteriors/direct_posterior.py | 97 ++++++++++---------- sbi/samplers/rejection/rejection.py | 33 +++---- sbi/utils/sbiutils.py | 2 +- 3 files changed, 69 insertions(+), 63 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index db39801bf..9bb9b38d9 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -135,12 +135,11 @@ def sample( return samples[:, 0] # Remove batch dimension. - def amortized_sample( + def sample_batched( self, - sample_shape: Shape = torch.Size(), - x: Optional[Tensor] = None, + sample_shape: Shape, + x: Tensor, max_sampling_batch_size: int = 10_000, - sample_with: Optional[str] = None, show_progress_bars: bool = True, ) -> Tensor: r"""Return samples from posterior $p(\theta|x)$ given multiple observations. @@ -150,53 +149,13 @@ def amortized_sample( given every observation. x: A batch of observations, of shape `(batch_dim, event_shape_x)`. `batch_dim` corresponds to the number of observations to be drawn. - sample_with: This argument only exists to keep backward-compatibility with - `sbi` v0.17.2 or older. If it is set, we instantly raise an error. + max_sampling_batch_size: Maximum batch size for rejection sampling. show_progress_bars: Whether to show sampling progress monitor. """ - - num_samples = torch.Size(sample_shape).numel() - # x = self._x_else_default_x(x) - x = reshape_to_batch_event( - x, event_shape=self.posterior_estimator.condition_shape - ) - - max_sampling_batch_size = ( - self.max_sampling_batch_size - if max_sampling_batch_size is None - else max_sampling_batch_size - ) - - if sample_with is not None: - raise ValueError( - f"You set `sample_with={sample_with}`. As of sbi v0.18.0, setting " - f"`sample_with` is no longer supported. You have to rerun " - f"`.build_posterior(sample_with={sample_with}).`" - ) - - samples = rejection.accept_reject_sample( - proposal=self.posterior_estimator, - accept_reject_fn=lambda theta: within_support(self.prior, theta), - num_samples=num_samples, - show_progress_bars=show_progress_bars, - max_sampling_batch_size=max_sampling_batch_size, - proposal_sampling_kwargs={"condition": x}, - alternative_method="build_posterior(..., sample_with='mcmc')", - )[0] - - return samples - - def sample_batched( - self, - sample_shape: Shape, - x: Tensor, - max_sampling_batch_size: int = 10_000, - show_progress_bars: bool = True, - ) -> Tensor: num_samples = torch.Size(sample_shape).numel() condition_shape = self.posterior_estimator.condition_shape x = reshape_to_batch_event(x, event_shape=condition_shape) - print(x.shape) + max_sampling_batch_size = ( self.max_sampling_batch_size if max_sampling_batch_size is None @@ -290,6 +249,52 @@ def log_prob( return masked_log_prob - log_factor + def log_prob_batched( + self, + theta: Tensor, + x: Tensor, + norm_posterior: bool = True, + track_gradients: bool = False, + leakage_correction_params: Optional[dict] = None, + ) -> Tensor: + theta = ensure_theta_batched(torch.as_tensor(theta)) + theta_density_estimator = reshape_to_sample_batch_event( + theta, theta.shape[1:], leading_is_sample=True + ) + x_density_estimator = reshape_to_batch_event( + x, event_shape=self.posterior_estimator.condition_shape + ) + + self.posterior_estimator.eval() + + with torch.set_grad_enabled(track_gradients): + # Evaluate on device, move back to cpu for comparison with prior. + unnorm_log_prob = self.posterior_estimator.log_prob( + theta_density_estimator, condition=x_density_estimator + ) + # `log_prob` supports only a single observation (i.e. `batchsize==1`). + # We now remove this additional dimension. + unnorm_log_prob = unnorm_log_prob.squeeze(dim=1) + + # Force probability to be zero outside prior support. + in_prior_support = within_support(self.prior, theta) + + masked_log_prob = torch.where( + in_prior_support, + unnorm_log_prob, + torch.tensor(float("-inf"), dtype=torch.float32, device=self._device), + ) + + if leakage_correction_params is None: + leakage_correction_params = dict() # use defaults + log_factor = ( + log(self.leakage_correction(x=x, **leakage_correction_params)) + if norm_posterior + else 0 + ) + + return masked_log_prob - log_factor + @torch.no_grad() def leakage_correction( self, diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index d21f52d12..c0e65512b 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -179,9 +179,10 @@ def log_prob(self, theta: Tensor, **kwargs) -> Tensor: # When in case of leakage a batch size was used there could be too many samples. samples = torch.cat(accepted)[:num_samples] - assert ( - samples.shape[0] == num_samples - ), "Number of accepted samples must match required samples." + print(samples.shape) + # assert ( + # samples.shape[0] == num_samples + # ), "Number of accepted samples must match required samples." return samples, as_tensor(acceptance_rate) @@ -263,10 +264,8 @@ def accept_reject_sample( if "condition" in list(proposal_sampling_kwargs.keys()): num_xo = proposal_sampling_kwargs["condition"].shape[0] else: - num_xo = 1 - - accepted_every_obs = [torch.tensor(()) for _ in range(num_xo)] - accepted, acceptance_rate = [], float("Nan") + num_xos = 1 + accepted, acceptance_rate = [[] for _ in range(num_xos)], float("Nan") leakage_warning_raised = False @@ -292,18 +291,18 @@ def accept_reject_sample( # SNPE-style rejection-sampling when the proposal is the neural net. are_accepted = accept_reject_fn(candidates) - for obs_index in range(num_xo): - accepted = candidates[are_accepted[:, obs_index], obs_index] - accepted_every_obs[obs_index] = torch.cat([accepted_every_obs[obs_index], accepted], dim=0) - lowest_num_accepted = min(len(s) for s in accepted_every_obs) - num_remaining = num_samples - lowest_num_accepted + num_accepted = are_accepted.sum(dim=0).min().item() + + for i in range(num_xos): + accepted[i].append(candidates[are_accepted[:, i], i]) # Update. # Note: For any condition of shape (*batch_shape, *condition_shape), the # samples will be of shape(sampling_batch_size,*batch_shape, *event_shape) # and hence work in dim = 0. num_sampled_total += sampling_batch_size - pbar.update(num_samples - num_remaining) + num_remaining -= num_accepted + pbar.update(num_accepted) # To avoid endless sampling when leakage is high, we raise a warning if the # acceptance rate is too low after the first 1_000 samples. @@ -357,8 +356,10 @@ def accept_reject_sample( accepted_every_obs = torch.stack(accepted_every_obs) # When in case of leakage a batch size was used there could be too many samples. - assert ( - accepted_every_obs.shape[-2] == num_samples - ), "Number of accepted samples must match required samples." + samples = [torch.cat(accepted[i], dim=0)[:num_samples] for i in range(num_xos)] + samples = torch.stack(samples, dim=1) + # assert ( + # samples.shape[0] == num_samples + # ), "Number of accepted samples must match required samples." return torch.permute(accepted_every_obs, (1, 0, -1)), as_tensor(acceptance_rate) diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index a2b0f102f..70635fa44 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -628,7 +628,7 @@ def within_support(distribution: Any, samples: Tensor) -> Tensor: sample_check = distribution.support.check(samples) if sample_check.shape == samples.shape: # With new shapeing conventions we need dim=-2 - sample_check = torch.all(sample_check, dim=-2) + sample_check = torch.all(sample_check, dim=-1) return sample_check # Falling back to log prob method of either the NeuralPosterior's net, or of a From 7a8a84def5d716919268f2e0d1a80d20073c43f2 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 7 May 2024 19:50:29 +0200 Subject: [PATCH 38/71] log_prob_batched works --- sbi/inference/posteriors/base_posterior.py | 9 ++++----- sbi/inference/posteriors/direct_posterior.py | 4 +++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index f5b9cf62e..4aaf1385e 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -122,13 +122,12 @@ def sample( pass @abstractmethod - def amortized_sample( + def sample_batched( self, - sample_shape: Shape = torch.Size(), - x: Optional[Tensor] = None, + sample_shape: Shape, + x: Tensor, + max_sampling_batch_size: int = 10_000, show_progress_bars: bool = True, - mcmc_method: Optional[str] = None, - mcmc_parameters: Optional[Dict[str, Any]] = None, ) -> Tensor: """See child classes for docstring.""" pass diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index 9bb9b38d9..6d2a7f708 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -258,13 +258,15 @@ def log_prob_batched( leakage_correction_params: Optional[dict] = None, ) -> Tensor: theta = ensure_theta_batched(torch.as_tensor(theta)) + event_shape = self.posterior_estimator.input_shape theta_density_estimator = reshape_to_sample_batch_event( - theta, theta.shape[1:], leading_is_sample=True + theta, event_shape, leading_is_sample=True ) x_density_estimator = reshape_to_batch_event( x, event_shape=self.posterior_estimator.condition_shape ) + print(theta_density_estimator.shape, x_density_estimator.shape) self.posterior_estimator.eval() with torch.set_grad_enabled(track_gradients): From 5daab922ae386640908145337f8ad2d4f5bab717 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 7 May 2024 20:00:19 +0200 Subject: [PATCH 39/71] abstract method implement for other methods --- sbi/inference/posteriors/ensemble_posterior.py | 9 +++++++++ sbi/inference/posteriors/importance_posterior.py | 11 +++++++++++ sbi/inference/posteriors/mcmc_posterior.py | 5 ++--- sbi/inference/posteriors/rejection_posterior.py | 11 +++++++++++ sbi/inference/posteriors/vi_posterior.py | 9 +++++++++ 5 files changed, 42 insertions(+), 3 deletions(-) diff --git a/sbi/inference/posteriors/ensemble_posterior.py b/sbi/inference/posteriors/ensemble_posterior.py index 72af02d88..abb2a1a3d 100644 --- a/sbi/inference/posteriors/ensemble_posterior.py +++ b/sbi/inference/posteriors/ensemble_posterior.py @@ -179,6 +179,15 @@ def sample( ) return torch.vstack(samples).reshape(*sample_shape, -1) + def sample_batched( + self, + sample_shape: Shape, + x: Tensor, + max_sampling_batch_size: int = 10000, + show_progress_bars: bool = True, + ) -> Tensor: + raise NotImplementedError("This method is not implemented yet.") + def log_prob( self, theta: Tensor, diff --git a/sbi/inference/posteriors/importance_posterior.py b/sbi/inference/posteriors/importance_posterior.py index bbd4ce32f..62b295d4d 100644 --- a/sbi/inference/posteriors/importance_posterior.py +++ b/sbi/inference/posteriors/importance_posterior.py @@ -194,6 +194,17 @@ def sample( else: raise NameError + def sample_batched( + self, + sample_shape: Shape, + x: Tensor, + max_sampling_batch_size: int = 10000, + show_progress_bars: bool = True, + ) -> Tensor: + raise NotImplementedError( + "Batched sampling is not implemented for ImportanceSamplingPosterior." + ) + def _importance_sample( self, sample_shape: Shape = torch.Size(), diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index c47db3e7b..7a0613ff1 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -356,8 +356,7 @@ def sample( return samples.reshape((*sample_shape, -1)) # type: ignore - - def amortized_sample( + def sample_batched( self, sample_shape: Shape = torch.Size(), x: Optional[Tensor] = None, @@ -427,7 +426,7 @@ def amortized_sample( ) print("transformed_samples", transformed_samples.shape) samples = self.theta_transform.inv(transformed_samples) - print("samples", samples.shape) + num_obs = 5 return samples.reshape((*sample_shape, num_obs, -1)) # type: ignore diff --git a/sbi/inference/posteriors/rejection_posterior.py b/sbi/inference/posteriors/rejection_posterior.py index 6da838059..5eb53497b 100644 --- a/sbi/inference/posteriors/rejection_posterior.py +++ b/sbi/inference/posteriors/rejection_posterior.py @@ -167,6 +167,17 @@ def sample( return samples.reshape((*sample_shape, -1)) + def sample_batched( + self, + sample_shape: Shape, + x: Tensor, + max_sampling_batch_size: int = 10000, + show_progress_bars: bool = True, + ) -> Tensor: + raise NotImplementedError( + "Batched sampling is not supported for rejection sampling." + ) + def map( self, x: Optional[Tensor] = None, diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py index 006ab543a..fd89a5654 100644 --- a/sbi/inference/posteriors/vi_posterior.py +++ b/sbi/inference/posteriors/vi_posterior.py @@ -296,6 +296,15 @@ def sample( samples = self.q.sample(torch.Size(sample_shape)) return samples.reshape((*sample_shape, samples.shape[-1])) + def sample_batched( + self, + sample_shape: Shape, + x: Tensor, + max_sampling_batch_size: int = 10000, + show_progress_bars: bool = True, + ) -> Tensor: + raise NotImplementedError("Batched sampling is not supported for VIPosterior.") + def log_prob( self, theta: Tensor, From 40897a085d7b77573af41af76207261153c54bd4 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 7 May 2024 20:47:18 +0200 Subject: [PATCH 40/71] temp fix mcmcposterior --- sbi/inference/posteriors/mcmc_posterior.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index 7a0613ff1..192f8bed5 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -594,8 +594,8 @@ def multi_obs_potential(params): # `all_potentials` is of shape (num_chains, num_obs). all_potentials = potential_function(params) return all_potentials.flatten() - - num_obs = 5 + + num_obs = 1 # TODO This will fail for num_obs > 1 in embedding_net_test.py initial_params = torch.concatenate([initial_params] * num_obs) posterior_sampler = SliceSamplerMultiChain( From a2b7e32980741365a7300db3641b0989f36f473b Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 7 May 2024 21:43:48 +0200 Subject: [PATCH 41/71] meh for general use i.e. in the restriction prior we have to add some reshapes in rejection --- sbi/samplers/rejection/rejection.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index c0e65512b..3e855b9f7 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -179,10 +179,10 @@ def log_prob(self, theta: Tensor, **kwargs) -> Tensor: # When in case of leakage a batch size was used there could be too many samples. samples = torch.cat(accepted)[:num_samples] - print(samples.shape) - # assert ( - # samples.shape[0] == num_samples - # ), "Number of accepted samples must match required samples." + # print(samples.shape) + assert ( + samples.shape[0] == num_samples + ), "Number of accepted samples must match required samples." return samples, as_tensor(acceptance_rate) @@ -288,13 +288,16 @@ def accept_reject_sample( (sampling_batch_size,), # type: ignore **proposal_sampling_kwargs, ) - # SNPE-style rejection-sampling when the proposal is the neural net. are_accepted = accept_reject_fn(candidates) + are_accepted = are_accepted.reshape(sampling_batch_size, num_xos) + candidates_to_reject = candidates.reshape( + sampling_batch_size, num_xos, *candidates.shape[1:] + ) num_accepted = are_accepted.sum(dim=0).min().item() - + # print(are_accepted.shape) for i in range(num_xos): - accepted[i].append(candidates[are_accepted[:, i], i]) + accepted[i].append(candidates_to_reject[are_accepted[:, i], i]) # Update. # Note: For any condition of shape (*batch_shape, *condition_shape), the @@ -358,8 +361,9 @@ def accept_reject_sample( # When in case of leakage a batch size was used there could be too many samples. samples = [torch.cat(accepted[i], dim=0)[:num_samples] for i in range(num_xos)] samples = torch.stack(samples, dim=1) - # assert ( - # samples.shape[0] == num_samples - # ), "Number of accepted samples must match required samples." + samples = samples.reshape(num_samples, *candidates.shape[1:]) + assert ( + samples.shape[0] == num_samples + ), "Number of accepted samples must match required samples." return torch.permute(accepted_every_obs, (1, 0, -1)), as_tensor(acceptance_rate) From cb4d8ae7cc6df087aef8476add3ef9808b773949 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 7 May 2024 21:57:52 +0200 Subject: [PATCH 42/71] ... test class --- tests/test_utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_utils.py b/tests/test_utils.py index b6730e65b..a1cea1e07 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -246,6 +246,17 @@ def sample( return self.potential_fn.posterior.sample(sample_shape) + def sample_batched( + self, + sample_shape: Shape, + x: Tensor, + max_sampling_batch_size: int = 10000, + show_progress_bars: bool = True, + ) -> Tensor: + raise NotImplementedError( + "Batched sampling is not supported for TractablePosterior." + ) + def log_prob( self, theta: Tensor, From ab9b1e1955c36939a2e528893d780ca3d3ee0aac Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Wed, 8 May 2024 07:26:07 +0200 Subject: [PATCH 43/71] Revert "Base estimator class" This reverts commit 17c534303343bd6306ea8e45fd4085a929ba42c2. --- sbi/neural_nets/density_estimators/base.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/sbi/neural_nets/density_estimators/base.py b/sbi/neural_nets/density_estimators/base.py index b3b83567c..252c850bc 100644 --- a/sbi/neural_nets/density_estimators/base.py +++ b/sbi/neural_nets/density_estimators/base.py @@ -1,6 +1,3 @@ -# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed -# under the Apache License Version 2.0, see - from typing import Optional, Tuple import torch From d2b1a627b0e9748fee154dbfe25df0c92365d555 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Wed, 8 May 2024 07:27:06 +0200 Subject: [PATCH 44/71] removing previous change --- sbi/utils/sbiutils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index 70635fa44..58146ccae 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -627,7 +627,6 @@ def within_support(distribution: Any, samples: Tensor) -> Tensor: try: sample_check = distribution.support.check(samples) if sample_check.shape == samples.shape: - # With new shapeing conventions we need dim=-2 sample_check = torch.all(sample_check, dim=-1) return sample_check From a0c0c979f7a7bfef1206b8755881b69549671c83 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Wed, 8 May 2024 07:33:04 +0200 Subject: [PATCH 45/71] removing some artifacts --- sbi/inference/potentials/likelihood_based_potential.py | 5 ----- sbi/samplers/rejection/rejection.py | 5 +++-- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py index a9e94e7e2..eab36e91f 100644 --- a/sbi/inference/potentials/likelihood_based_potential.py +++ b/sbi/inference/potentials/likelihood_based_potential.py @@ -123,8 +123,6 @@ def _log_likelihoods_over_trials( log_likelihood_trial_sum: log likelihood for each parameter, summed over all batch entries (iid trials) in `x`. """ - # print("x", x.shape) - # print("theta", theta.shape) # Shape of `x` is (iid_dim, *event_shape). x = reshape_to_sample_batch_event( x, event_shape=x.shape[1:], leading_is_sample=True @@ -148,9 +146,6 @@ def _log_likelihoods_over_trials( # `DensityEstimator.log_prob`. theta = reshape_to_batch_event(theta, event_shape=theta.shape[1:]) - # print("After reshape theta: ", theta.shape) - # print("After reshape x: ", x.shape) - # Calculate likelihood in one batch. with torch.set_grad_enabled(track_gradients): log_likelihood_trial_batch = estimator.log_prob(x, condition=theta) diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 3e855b9f7..1ba6bcb05 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -179,7 +179,7 @@ def log_prob(self, theta: Tensor, **kwargs) -> Tensor: # When in case of leakage a batch size was used there could be too many samples. samples = torch.cat(accepted)[:num_samples] - # print(samples.shape) + assert ( samples.shape[0] == num_samples ), "Number of accepted samples must match required samples." @@ -290,12 +290,13 @@ def accept_reject_sample( ) # SNPE-style rejection-sampling when the proposal is the neural net. are_accepted = accept_reject_fn(candidates) + # Reshape necessary in certain cases which do not follow the shape conventions + # of the "DensityEstimator" class. are_accepted = are_accepted.reshape(sampling_batch_size, num_xos) candidates_to_reject = candidates.reshape( sampling_batch_size, num_xos, *candidates.shape[1:] ) num_accepted = are_accepted.sum(dim=0).min().item() - # print(are_accepted.shape) for i in range(num_xos): accepted[i].append(candidates_to_reject[are_accepted[:, i], i]) From 8fc5a46b30809371c04c1ae06c5350a07b9a46a6 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Wed, 8 May 2024 07:41:17 +0200 Subject: [PATCH 46/71] revert wierd change --- sbi/neural_nets/density_estimators/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sbi/neural_nets/density_estimators/base.py b/sbi/neural_nets/density_estimators/base.py index 252c850bc..b3b83567c 100644 --- a/sbi/neural_nets/density_estimators/base.py +++ b/sbi/neural_nets/density_estimators/base.py @@ -1,3 +1,6 @@ +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Apache License Version 2.0, see + from typing import Optional, Tuple import torch From 18c7d36515bfcb1c8fd7a2565e0ccf1ae805f130 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Wed, 8 May 2024 08:14:10 +0200 Subject: [PATCH 47/71] docs and tests --- sbi/inference/posteriors/direct_posterior.py | 33 +++++++++++++++++--- sbi/samplers/rejection/rejection.py | 3 +- tests/posterior_nn_test.py | 31 ++++++++++++++++++ 3 files changed, 62 insertions(+), 5 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index 6d2a7f708..a0eb6ba6a 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -257,6 +257,35 @@ def log_prob_batched( track_gradients: bool = False, leakage_correction_params: Optional[dict] = None, ) -> Tensor: + """Returns the log-probabilities of the posteriors $p(\theta_1|x_1),..., \ + p(\theta_B|x_B)$. + + Args: + theta: Batch of parameters $\theta$ of shape \ + `(*sample_shape, batch_dim, *theta_shape)`. + x: Batch of observations $x$ of shape \ + `(batch_dim, *condition_shape)`. + norm_posterior: Whether to enforce a normalized posterior density. + Renormalization of the posterior is useful when some + probability falls out or leaks out of the prescribed prior support. + The normalizing factor is calculated via rejection sampling, so if you + need speedier but unnormalized log posterior estimates set here + `norm_posterior=False`. The returned log posterior is set to + -∞ outside of the prior support regardless of this setting. + track_gradients: Whether the returned tensor supports tracking gradients. + This can be helpful for e.g. sensitivity analysis, but increases memory + consumption. + leakage_correction_params: A `dict` of keyword arguments to override the + default values of `leakage_correction()`. Possible options are: + `num_rejection_samples`, `force_update`, `show_progress_bars`, and + `rejection_sampling_batch_size`. + These parameters only have an effect if `norm_posterior=True`. + + Returns: + `(len(θ), B)`-shaped log posterior probability $\\log p(\theta|x)$\\ for θ \ + in the support of the prior, -∞ (corresponding to 0 probability) outside. + """ + theta = ensure_theta_batched(torch.as_tensor(theta)) event_shape = self.posterior_estimator.input_shape theta_density_estimator = reshape_to_sample_batch_event( @@ -266,7 +295,6 @@ def log_prob_batched( x, event_shape=self.posterior_estimator.condition_shape ) - print(theta_density_estimator.shape, x_density_estimator.shape) self.posterior_estimator.eval() with torch.set_grad_enabled(track_gradients): @@ -274,9 +302,6 @@ def log_prob_batched( unnorm_log_prob = self.posterior_estimator.log_prob( theta_density_estimator, condition=x_density_estimator ) - # `log_prob` supports only a single observation (i.e. `batchsize==1`). - # We now remove this additional dimension. - unnorm_log_prob = unnorm_log_prob.squeeze(dim=1) # Force probability to be zero outside prior support. in_prior_support = within_support(self.prior, theta) diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 1ba6bcb05..32a608923 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -294,8 +294,9 @@ def accept_reject_sample( # of the "DensityEstimator" class. are_accepted = are_accepted.reshape(sampling_batch_size, num_xos) candidates_to_reject = candidates.reshape( - sampling_batch_size, num_xos, *candidates.shape[1:] + sampling_batch_size, num_xos, *candidates.shape[candidates.ndim - 1 :] ) + num_accepted = are_accepted.sum(dim=0).min().item() for i in range(num_xos): accepted[i].append(candidates_to_reject[are_accepted[:, i], i]) diff --git a/tests/posterior_nn_test.py b/tests/posterior_nn_test.py index 33f4c29e4..593660341 100644 --- a/tests/posterior_nn_test.py +++ b/tests/posterior_nn_test.py @@ -49,3 +49,34 @@ def test_log_prob_with_different_x(snpe_method: type, x_o_batch_dim: bool): ).set_default_x(x_o) samples = posterior.sample((10,)) _ = posterior.log_prob(samples) + + +@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) +@pytest.mark.parametrize( + "x_o_batch_dim", + ( + 0, + 1, + 2, + ), +) +def test_batched_sample_log_prob_with_different_x( + snpe_method: type, x_o_batch_dim: bool +): + num_dim = 2 + + prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim)) + simulator = diagonal_linear_gaussian + + inference = snpe_method(prior=prior) + theta, x = simulate_for_sbi(simulator, prior, 1000) + posterior_estimator = inference.append_simulations(theta, x).train(max_num_epochs=3) + + x_o = ones(num_dim) if x_o_batch_dim == 0 else ones(x_o_batch_dim, num_dim) + + posterior = DirectPosterior(posterior_estimator=posterior_estimator, prior=prior) + + samples = posterior.sample_batched((10,), x_o) + batched_log_probs = posterior.log_prob_batched(samples, x_o) + + assert batched_log_probs.shape == (10, max(x_o_batch_dim, 1)) From 6ad6cb7e4b6cd55f6f0fe7be78fcc496143edb45 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 14 May 2024 09:09:30 +0200 Subject: [PATCH 48/71] MCMC sample_batched works but not log_prob batched --- .../posteriors/ensemble_posterior.py | 5 +- sbi/inference/posteriors/mcmc_posterior.py | 72 +++++++++++++------ tests/posterior_nn_test.py | 45 ++++++++++++ 3 files changed, 100 insertions(+), 22 deletions(-) diff --git a/sbi/inference/posteriors/ensemble_posterior.py b/sbi/inference/posteriors/ensemble_posterior.py index abb2a1a3d..58e5cb53c 100644 --- a/sbi/inference/posteriors/ensemble_posterior.py +++ b/sbi/inference/posteriors/ensemble_posterior.py @@ -186,7 +186,10 @@ def sample_batched( max_sampling_batch_size: int = 10000, show_progress_bars: bool = True, ) -> Tensor: - raise NotImplementedError("This method is not implemented yet.") + raise NotImplementedError( + "Batched sampling is not implemented for \ + EnsemblePosterior." + ) def log_prob( self, diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index 192f8bed5..416343e7a 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -208,6 +208,38 @@ def log_prob( theta.to(self._device), track_gradients=track_gradients ) + def log_prob_batched( + self, theta: Tensor, x: Tensor, track_gradients: bool = False + ) -> Tensor: + r"""Returns the log-probability of theta under the multiple posteriors. + + Args: + theta: Parameters $\theta$. + track_gradients: Whether the returned tensor supports tracking gradients. + This can be helpful for e.g. sensitivity analysis, but increases memory + consumption. + + Returns: + `len($\theta$)`-shaped log-probability. + """ + warn( + """`.log_prob()` is deprecated for methods that can only evaluate the + log-probability up to a normalizing constant. Use `.potential()` + instead.""", + stacklevel=2, + ) + warn("The log-probability is unnormalized!", stacklevel=2) + + self.potential_fn.set_x(x) + print(x.shape) + theta = ensure_theta_batched(torch.as_tensor(theta)) + print(theta.shape) + potential = self.potential_fn( + theta.to(self._device), track_gradients=track_gradients + ) + print(potential) + return potential + def sample( self, sample_shape: Shape = torch.Size(), @@ -351,15 +383,16 @@ def sample( ) else: raise NameError(f"The sampling method {method} is not implemented!") - + print(transformed_samples.shape) samples = self.theta_transform.inv(transformed_samples) + samples = samples.reshape((*sample_shape, -1)) # type: ignore return samples.reshape((*sample_shape, -1)) # type: ignore def sample_batched( self, - sample_shape: Shape = torch.Size(), - x: Optional[Tensor] = None, + sample_shape: Shape, + x: Tensor, method: Optional[str] = None, thin: Optional[int] = None, warmup_steps: Optional[int] = None, @@ -384,7 +417,8 @@ def sample_batched( Returns: Samples from posterior. """ - self.potential_fn.set_x(self._x_else_default_x(x)) + batch_size = x.shape[0] + self.potential_fn.set_x(x) # Replace arguments that were not passed with their default. method = self.method if method is None else method @@ -401,18 +435,21 @@ def sample_batched( ) self.potential_ = self._prepare_potential(method) # type: ignore - print("Getting initial params") + num_chains_extended = batch_size * num_chains initial_params = self._get_initial_params( init_strategy, # type: ignore - num_chains, # type: ignore + num_chains_extended, # type: ignore num_workers, show_progress_bars, **init_strategy_parameters, ) - print("Finished init") - num_samples = torch.Size(sample_shape).numel() - assert method == "slice_np_vectorized" + num_samples = torch.Size(sample_shape).numel() * batch_size + + assert ( + method == "slice_np_vectorized" + ), "Batched sampling only supported for vectorized samplers!" + with torch.set_grad_enabled(False): transformed_samples = self._slice_np_mcmc( num_samples=num_samples, @@ -424,12 +461,10 @@ def sample_batched( num_workers=num_workers, show_progress_bars=show_progress_bars, ) - print("transformed_samples", transformed_samples.shape) - samples = self.theta_transform.inv(transformed_samples) - num_obs = 5 + samples = self.theta_transform.inv(transformed_samples) - return samples.reshape((*sample_shape, num_obs, -1)) # type: ignore + return samples.reshape((*sample_shape, batch_size, -1)) # type: ignore def _build_mcmc_init_fn( @@ -588,20 +623,15 @@ def multi_obs_potential(params): # Params are of shape (num_chains * num_obs, event). # We now reshape them to (num_chains, num_obs, event). # params = np.reshape(params, (num_chains, num_obs, -1)) - # print("params", params.shape) - # print("potential_function", potential_function) # `all_potentials` is of shape (num_chains, num_obs). all_potentials = potential_function(params) return all_potentials.flatten() - num_obs = 1 # TODO This will fail for num_obs > 1 in embedding_net_test.py - initial_params = torch.concatenate([initial_params] * num_obs) - posterior_sampler = SliceSamplerMultiChain( init_params=tensor2numpy(initial_params), log_prob_fn=multi_obs_potential, - num_chains=num_chains * num_obs, + num_chains=num_chains, thin=thin, verbose=show_progress_bars, num_workers=num_workers, @@ -620,10 +650,10 @@ def multi_obs_potential(params): self._posterior_sampler = posterior_sampler # Save sample as potential next init (if init_strategy == 'latest_sample'). - self._mcmc_init_params = samples[:, -1, :].reshape(num_chains, num_obs, dim_samples) + self._mcmc_init_params = samples[:, -1, :].reshape(num_chains, dim_samples) # Collect samples from all chains. - samples = samples.reshape(-1, num_obs, dim_samples)[:num_samples] + samples = samples.reshape(-1, dim_samples)[:num_samples] return samples.type(torch.float32).to(self._device) diff --git a/tests/posterior_nn_test.py b/tests/posterior_nn_test.py index 593660341..a7efef04c 100644 --- a/tests/posterior_nn_test.py +++ b/tests/posterior_nn_test.py @@ -8,6 +8,7 @@ from torch.distributions import MultivariateNormal from sbi.inference import ( + SNLE_A, SNPE_A, SNPE_C, DirectPosterior, @@ -79,4 +80,48 @@ def test_batched_sample_log_prob_with_different_x( samples = posterior.sample_batched((10,), x_o) batched_log_probs = posterior.log_prob_batched(samples, x_o) + assert ( + samples.shape == (10, x_o_batch_dim, num_dim) + if x_o_batch_dim > 0 + else (10, num_dim) + ) assert batched_log_probs.shape == (10, max(x_o_batch_dim, 1)) + + +@pytest.mark.mcmc +@pytest.mark.parametrize("snlre_method", [SNLE_A]) +@pytest.mark.parametrize( + "x_o_batch_dim", + ( + 0, + 1, + 2, + ), +) +def test_batched_mcmc_sample_log_prob_with_different_x( + snlre_method: type, x_o_batch_dim: bool, mcmc_params_fast: dict +): + num_dim = 2 + + prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim)) + simulator = diagonal_linear_gaussian + + inference = snlre_method(prior=prior) + theta, x = simulate_for_sbi(simulator, prior, 1000) + _ = inference.append_simulations(theta, x).train(max_num_epochs=3) + + x_o = ones(num_dim) if x_o_batch_dim == 0 else ones(x_o_batch_dim, num_dim) + + posterior = inference.build_posterior( + mcmc_method="slice_np_vectorized", mcmc_parameters=mcmc_params_fast + ) + + samples = posterior.sample_batched((10,), x_o) + # batched_log_probs = posterior.log_prob_batched(samples, x_o) + + assert ( + samples.shape == (10, x_o_batch_dim, num_dim) + if x_o_batch_dim > 0 + else (10, num_dim) + ) + # assert batched_log_probs.shape == (10, max(x_o_batch_dim, 1)) From 03c10f361d4a37a195aee264e39b662364c2b638 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 14 May 2024 11:03:06 +0200 Subject: [PATCH 49/71] adding some docs --- sbi/inference/posteriors/direct_posterior.py | 13 +++-- .../posteriors/ensemble_posterior.py | 7 ++- .../posteriors/importance_posterior.py | 5 +- sbi/inference/posteriors/mcmc_posterior.py | 49 ++++--------------- .../posteriors/rejection_posterior.py | 4 +- sbi/inference/posteriors/vi_posterior.py | 6 ++- tests/posterior_nn_test.py | 5 +- 7 files changed, 40 insertions(+), 49 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index a0eb6ba6a..08e4b5a5f 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -142,7 +142,9 @@ def sample_batched( max_sampling_batch_size: int = 10_000, show_progress_bars: bool = True, ) -> Tensor: - r"""Return samples from posterior $p(\theta|x)$ given multiple observations. + r"""Given a batch of observations [x_1, ..., x_B] this function samples from + posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized) + manner. Args: sample_shape: Desired shape of samples that are drawn from the posterior @@ -151,6 +153,9 @@ def sample_batched( `batch_dim` corresponds to the number of observations to be drawn. max_sampling_batch_size: Maximum batch size for rejection sampling. show_progress_bars: Whether to show sampling progress monitor. + + Returns: + Samples from the posteriors of shape (*sample_shape, B, *input_shape) """ num_samples = torch.Size(sample_shape).numel() condition_shape = self.posterior_estimator.condition_shape @@ -257,8 +262,10 @@ def log_prob_batched( track_gradients: bool = False, leakage_correction_params: Optional[dict] = None, ) -> Tensor: - """Returns the log-probabilities of the posteriors $p(\theta_1|x_1),..., \ - p(\theta_B|x_B)$. + """Given a batch of observations [x_1, ..., x_B] and a batch of parameters \ + [$\theta_1$,..., $\theta_B$] this function evalautes the log-probabilities \ + of the posterior $p(\theta_1|x_1)$, ..., $p(\theta_B|x_B)$ in a batched \ + (i.e. vectorized) manner. Args: theta: Batch of parameters $\theta$ of shape \ diff --git a/sbi/inference/posteriors/ensemble_posterior.py b/sbi/inference/posteriors/ensemble_posterior.py index 58e5cb53c..e895353c3 100644 --- a/sbi/inference/posteriors/ensemble_posterior.py +++ b/sbi/inference/posteriors/ensemble_posterior.py @@ -186,9 +186,12 @@ def sample_batched( max_sampling_batch_size: int = 10000, show_progress_bars: bool = True, ) -> Tensor: + # TODO Can be implemented in the future, for all base posterior that support + # batched sampling. raise NotImplementedError( - "Batched sampling is not implemented for \ - EnsemblePosterior." + "Batched sampling is not implemented for EnsemblePosterior. \ + Alternatively you can use `sample` in a loop \ + [posterior.sample(theta, x_o) for x_o in x]." ) def log_prob( diff --git a/sbi/inference/posteriors/importance_posterior.py b/sbi/inference/posteriors/importance_posterior.py index 62b295d4d..0b659c84e 100644 --- a/sbi/inference/posteriors/importance_posterior.py +++ b/sbi/inference/posteriors/importance_posterior.py @@ -201,8 +201,11 @@ def sample_batched( max_sampling_batch_size: int = 10000, show_progress_bars: bool = True, ) -> Tensor: + # TODO Can be implemented in the future. raise NotImplementedError( - "Batched sampling is not implemented for ImportanceSamplingPosterior." + "Batched sampling is not implemented for ImportanceSamplingPosterior. \ + Alternatively you can use `sample` in a loop \ + [posterior.sample(theta, x_o) for x_o in x]." ) def _importance_sample( diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index 416343e7a..579831538 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -208,38 +208,6 @@ def log_prob( theta.to(self._device), track_gradients=track_gradients ) - def log_prob_batched( - self, theta: Tensor, x: Tensor, track_gradients: bool = False - ) -> Tensor: - r"""Returns the log-probability of theta under the multiple posteriors. - - Args: - theta: Parameters $\theta$. - track_gradients: Whether the returned tensor supports tracking gradients. - This can be helpful for e.g. sensitivity analysis, but increases memory - consumption. - - Returns: - `len($\theta$)`-shaped log-probability. - """ - warn( - """`.log_prob()` is deprecated for methods that can only evaluate the - log-probability up to a normalizing constant. Use `.potential()` - instead.""", - stacklevel=2, - ) - warn("The log-probability is unnormalized!", stacklevel=2) - - self.potential_fn.set_x(x) - print(x.shape) - theta = ensure_theta_batched(torch.as_tensor(theta)) - print(theta.shape) - potential = self.potential_fn( - theta.to(self._device), track_gradients=track_gradients - ) - print(potential) - return potential - def sample( self, sample_shape: Shape = torch.Size(), @@ -383,7 +351,7 @@ def sample( ) else: raise NameError(f"The sampling method {method} is not implemented!") - print(transformed_samples.shape) + samples = self.theta_transform.inv(transformed_samples) samples = samples.reshape((*sample_shape, -1)) # type: ignore @@ -403,19 +371,22 @@ def sample_batched( mp_context: Optional[str] = None, show_progress_bars: bool = True, ) -> Tensor: - r"""Return samples from posterior distribution $p(\theta|x)$ with MCMC. + r"""Given a batch of observations [x_1, ..., x_B] this function samples from + posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized) + manner. Check the `__init__()` method for a description of all arguments as well as their default values. Args: - sample_shape: Desired shape of samples that are drawn from posterior. If - sample_shape is multidimensional we simply draw `sample_shape.numel()` - samples and then reshape into the desired shape. + sample_shape: Desired shape of samples that are drawn from the posterior + given every observation. + x: A batch of observations, of shape `(batch_dim, event_shape_x)`. + `batch_dim` corresponds to the number of observations to be drawn. show_progress_bars: Whether to show sampling progress monitor. Returns: - Samples from posterior. + Samples from the posteriors of shape (*sample_shape, B, *input_shape) """ batch_size = x.shape[0] self.potential_fn.set_x(x) @@ -640,9 +611,7 @@ def multi_obs_potential(params): warmup_ = warmup_steps * thin num_samples_ = ceil((num_samples * thin) / num_chains) # Run mcmc including warmup - print("Start run") samples = posterior_sampler.run(warmup_ + num_samples_) - print("Finish run") samples = samples[:, warmup_steps:, :] # discard warmup steps samples = torch.from_numpy(samples) # chains x samples x dim diff --git a/sbi/inference/posteriors/rejection_posterior.py b/sbi/inference/posteriors/rejection_posterior.py index 5eb53497b..549942ce2 100644 --- a/sbi/inference/posteriors/rejection_posterior.py +++ b/sbi/inference/posteriors/rejection_posterior.py @@ -175,7 +175,9 @@ def sample_batched( show_progress_bars: bool = True, ) -> Tensor: raise NotImplementedError( - "Batched sampling is not supported for rejection sampling." + "Batched sampling is not implemented for RejectionPosterior. \ + Alternatively you can use `sample` in a loop \ + [posterior.sample(theta, x_o) for x_o in x]." ) def map( diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py index fd89a5654..b0e7bcf8f 100644 --- a/sbi/inference/posteriors/vi_posterior.py +++ b/sbi/inference/posteriors/vi_posterior.py @@ -303,7 +303,11 @@ def sample_batched( max_sampling_batch_size: int = 10000, show_progress_bars: bool = True, ) -> Tensor: - raise NotImplementedError("Batched sampling is not supported for VIPosterior.") + raise NotImplementedError( + "Batched sampling is not implemented for ImportanceSamplingPosterior. \ + Alternatively you can use `sample` in a loop \ + [posterior.sample(theta, x_o) for x_o in x]." + ) def log_prob( self, diff --git a/tests/posterior_nn_test.py b/tests/posterior_nn_test.py index a7efef04c..dfe02a05d 100644 --- a/tests/posterior_nn_test.py +++ b/tests/posterior_nn_test.py @@ -117,7 +117,10 @@ def test_batched_mcmc_sample_log_prob_with_different_x( ) samples = posterior.sample_batched((10,), x_o) - # batched_log_probs = posterior.log_prob_batched(samples, x_o) + print(x_o.shape) + print(samples.shape) + batched_log_probs = posterior.log_prob_batched(samples, x_o) + print(batched_log_probs.shape) assert ( samples.shape == (10, x_o_batch_dim, num_dim) From 24c4821efde84e3b8970c1ab79f2bf998a2e2ce7 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 14 May 2024 11:06:01 +0200 Subject: [PATCH 50/71] batch_log_prob for MCMC requires at best changes for potential -> removed --- tests/posterior_nn_test.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/posterior_nn_test.py b/tests/posterior_nn_test.py index dfe02a05d..d005f4412 100644 --- a/tests/posterior_nn_test.py +++ b/tests/posterior_nn_test.py @@ -84,8 +84,8 @@ def test_batched_sample_log_prob_with_different_x( samples.shape == (10, x_o_batch_dim, num_dim) if x_o_batch_dim > 0 else (10, num_dim) - ) - assert batched_log_probs.shape == (10, max(x_o_batch_dim, 1)) + ), "Sample shape wrong" + assert batched_log_probs.shape == (10, max(x_o_batch_dim, 1)), "logprob shape wrong" @pytest.mark.mcmc @@ -117,14 +117,9 @@ def test_batched_mcmc_sample_log_prob_with_different_x( ) samples = posterior.sample_batched((10,), x_o) - print(x_o.shape) - print(samples.shape) - batched_log_probs = posterior.log_prob_batched(samples, x_o) - print(batched_log_probs.shape) assert ( samples.shape == (10, x_o_batch_dim, num_dim) if x_o_batch_dim > 0 else (10, num_dim) - ) - # assert batched_log_probs.shape == (10, max(x_o_batch_dim, 1)) + ), "Sampel shape wrong" From a445a6c159cd10d0d8b4a54c2114a5a9f70054d9 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 11 Jun 2024 10:40:40 +0200 Subject: [PATCH 51/71] Fixing bug from rebase... --- sbi/samplers/rejection/rejection.py | 24 +----------------------- 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 08d4739f0..d9cd4094b 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -267,24 +267,6 @@ def accept_reject_sample( leakage_warning_raised = False # Ruff suggestion - num_sampled_total, num_remaining = 0, num_samples - if "condition" in list(proposal_sampling_kwargs.keys()): - num_xo = proposal_sampling_kwargs["condition"].shape[0] - else: - num_xos = 1 - accepted, acceptance_rate = [[] for _ in range(num_xos)], float("Nan") - leakage_warning_raised = False - - num_sampled_total, num_remaining = 0, num_samples - if "condition" in list(proposal_sampling_kwargs.keys()): - num_xo = proposal_sampling_kwargs["condition"].shape[0] - else: - num_xo = 1 - - accepted_every_obs = [torch.tensor(()) for _ in range(num_xo)] - accepted, acceptance_rate = [], float("Nan") - leakage_warning_raised = False - # To cover cases with few samples without leakage: sampling_batch_size = min(num_samples, max_sampling_batch_size) while num_remaining > 0: @@ -361,10 +343,6 @@ def accept_reject_sample( pbar.close() - for obs_index in range(num_xo): - accepted_every_obs[obs_index] = accepted_every_obs[obs_index][:num_samples] - - accepted_every_obs = torch.stack(accepted_every_obs) # When in case of leakage a batch size was used there could be too many samples. samples = [torch.cat(accepted[i], dim=0)[:num_samples] for i in range(num_xos)] samples = torch.stack(samples, dim=1) @@ -373,4 +351,4 @@ def accept_reject_sample( samples.shape[0] == num_samples ), "Number of accepted samples must match required samples." - return torch.permute(accepted_every_obs, (1, 0, -1)), as_tensor(acceptance_rate) + return samples, as_tensor(acceptance_rate) From 86767a1b430b6e31cd8c8be937fcb9f2322c7db6 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 11 Jun 2024 11:20:14 +0200 Subject: [PATCH 52/71] tracking all acceptance rates --- sbi/samplers/rejection/rejection.py | 37 +++++++++++++++++------------ 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index d9cd4094b..cb751b03a 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -179,7 +179,6 @@ def log_prob(self, theta: Tensor, **kwargs) -> Tensor: # When in case of leakage a batch size was used there could be too many samples. samples = torch.cat(accepted)[:num_samples] - assert ( samples.shape[0] == num_samples ), "Number of accepted samples must match required samples." @@ -238,7 +237,7 @@ def accept_reject_sample( Returns: Accepted samples of shape `(sample_dim, batch_dim, *event_shape)`, and - worst-case acceptance rate as scalar Tensor. + acceptance rates for each observation. """ if kwargs: @@ -258,17 +257,21 @@ def accept_reject_sample( if proposal_sampling_kwargs is None: proposal_sampling_kwargs = {} - num_sampled_total, num_remaining = 0, num_samples + num_remaining = num_samples if "condition" in proposal_sampling_kwargs: num_xos = proposal_sampling_kwargs["condition"].shape[0] else: num_xos = 1 - accepted, acceptance_rate = [[] for _ in range(num_xos)], float("Nan") + + accepted = [[] for _ in range(num_xos)] + acceptance_rate = torch.full((num_xos,), float("Nan")) leakage_warning_raised = False # Ruff suggestion # To cover cases with few samples without leakage: sampling_batch_size = min(num_samples, max_sampling_batch_size) + num_sampled_total = torch.zeros(num_xos) + num_samples_possible = 0 while num_remaining > 0: # Sample and reject. candidates = proposal.sample( @@ -284,7 +287,6 @@ def accept_reject_sample( sampling_batch_size, num_xos, *candidates.shape[candidates.ndim - 1 :] ) - num_accepted = are_accepted.sum(dim=0).min().item() for i in range(num_xos): accepted[i].append(candidates_to_reject[are_accepted[:, i], i]) @@ -292,13 +294,17 @@ def accept_reject_sample( # Note: For any condition of shape (*batch_shape, *condition_shape), the # samples will be of shape(sampling_batch_size,*batch_shape, *event_shape) # and hence work in dim = 0. - num_sampled_total += sampling_batch_size - num_remaining -= num_accepted - pbar.update(num_accepted) + num_accepted = are_accepted.sum(dim=0) + num_sampled_total += num_accepted + num_samples_possible += sampling_batch_size + min_num_accepted = num_accepted.min().item() + num_remaining -= min_num_accepted + pbar.update(min_num_accepted) # To avoid endless sampling when leakage is high, we raise a warning if the # acceptance rate is too low after the first 1_000 samples. - acceptance_rate = (num_samples - num_remaining) / num_sampled_total + acceptance_rate = num_sampled_total / num_samples_possible + min_acceptance_rate = acceptance_rate.min().item() # For remaining iterations (leakage or many samples) continue # sampling with fixed batch size, reduced in cased the number @@ -306,20 +312,21 @@ def accept_reject_sample( # by zero if acceptance rate is zero. sampling_batch_size = min( max_sampling_batch_size, - max(int(1.5 * num_remaining / max(acceptance_rate, 1e-12)), 100), + max(int(1.5 * num_remaining / max(min_acceptance_rate, 1e-12)), 100), ) if ( num_sampled_total > 1000 - and acceptance_rate < warn_acceptance + and min_acceptance_rate < warn_acceptance and not leakage_warning_raised ): if sample_for_correction_factor: + idx_min = acceptance_rate.argmin().item() logging.warning( f"""Drawing samples from posterior to estimate the normalizing constant for `log_prob()`. However, only - {acceptance_rate:.3%} posterior samples are within the - prior support. It may take a long time to collect the - remaining {num_remaining} samples. + {min_acceptance_rate:.3%} posterior samples are within the + prior support (for condition {idx_min}). It may take a long time + to collect the remaining {num_remaining} samples. Consider interrupting (Ctrl-C) and either basing the estimate of the normalizing constant on fewer samples (by calling `posterior.leakage_correction(x_o, @@ -331,7 +338,7 @@ def accept_reject_sample( result in an unnormalized `log_prob()`.""" ) else: - warn_msg = f"""Only {acceptance_rate:.3%} proposal samples are + warn_msg = f"""Only {min_acceptance_rate:.3%} proposal samples are accepted. It may take a long time to collect the remaining {num_remaining} samples. """ if alternative_method is not None: From 9502af34af076edab982ff2ded233eabc0553705 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 11 Jun 2024 11:24:17 +0200 Subject: [PATCH 53/71] Comment on NFlows --- sbi/neural_nets/density_estimators/nflows_flow.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sbi/neural_nets/density_estimators/nflows_flow.py b/sbi/neural_nets/density_estimators/nflows_flow.py index 8d6aaba55..42aee9d47 100644 --- a/sbi/neural_nets/density_estimators/nflows_flow.py +++ b/sbi/neural_nets/density_estimators/nflows_flow.py @@ -135,6 +135,8 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor: num_samples = torch.Size(sample_shape).numel() samples = self.net.sample(num_samples, context=condition) + # Change from Nflows' convention of (batch_dim, sample_dim, *event_shape) to + # (sample_dim, batch_dim, *event_shape) (PyTorch + SBI). samples = samples.transpose(0, 1) return samples.reshape((*sample_shape, condition_batch_dim, *self.input_shape)) From c80e6ffecc51c82419bf123060e484031e6e38be Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 11 Jun 2024 12:01:05 +0200 Subject: [PATCH 54/71] Also testing SNRE batched sampling, Need to test ensemble implementation --- .../posteriors/ensemble_posterior.py | 24 ++++++++++++------- sbi/inference/posteriors/vi_posterior.py | 2 +- sbi/samplers/rejection/rejection.py | 2 +- tests/posterior_nn_test.py | 7 ++++-- 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/sbi/inference/posteriors/ensemble_posterior.py b/sbi/inference/posteriors/ensemble_posterior.py index e895353c3..0e55ca130 100644 --- a/sbi/inference/posteriors/ensemble_posterior.py +++ b/sbi/inference/posteriors/ensemble_posterior.py @@ -183,16 +183,24 @@ def sample_batched( self, sample_shape: Shape, x: Tensor, - max_sampling_batch_size: int = 10000, - show_progress_bars: bool = True, + **kwargs, ) -> Tensor: - # TODO Can be implemented in the future, for all base posterior that support - # batched sampling. - raise NotImplementedError( - "Batched sampling is not implemented for EnsemblePosterior. \ - Alternatively you can use `sample` in a loop \ - [posterior.sample(theta, x_o) for x_o in x]." + num_samples = torch.Size(sample_shape).numel() + posterior_indizes = torch.multinomial( + self._weights, num_samples, replacement=True ) + samples = [] + for posterior_index, sample_size in torch.vstack( + posterior_indizes.unique(return_counts=True) + ).T: + sample_shape_c = torch.Size((int(sample_size),)) + samples.append( + self.posteriors[posterior_index].sample_batched( + sample_shape_c, x=x, **kwargs + ) + ) + samples = torch.vstack(samples) + return samples.reshape(sample_shape + samples.shape[1:]) def log_prob( self, diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py index b0e7bcf8f..f75ac0a4b 100644 --- a/sbi/inference/posteriors/vi_posterior.py +++ b/sbi/inference/posteriors/vi_posterior.py @@ -304,7 +304,7 @@ def sample_batched( show_progress_bars: bool = True, ) -> Tensor: raise NotImplementedError( - "Batched sampling is not implemented for ImportanceSamplingPosterior. \ + "Batched sampling is not implemented for VIPosterior. \ Alternatively you can use `sample` in a loop \ [posterior.sample(theta, x_o) for x_o in x]." ) diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index cb751b03a..e29fe51db 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -358,4 +358,4 @@ def accept_reject_sample( samples.shape[0] == num_samples ), "Number of accepted samples must match required samples." - return samples, as_tensor(acceptance_rate) + return samples, as_tensor(min_acceptance_rate) diff --git a/tests/posterior_nn_test.py b/tests/posterior_nn_test.py index d005f4412..a86fcebce 100644 --- a/tests/posterior_nn_test.py +++ b/tests/posterior_nn_test.py @@ -11,6 +11,9 @@ SNLE_A, SNPE_A, SNPE_C, + SNRE_A, + SNRE_B, + SNRE_C, DirectPosterior, simulate_for_sbi, ) @@ -89,7 +92,7 @@ def test_batched_sample_log_prob_with_different_x( @pytest.mark.mcmc -@pytest.mark.parametrize("snlre_method", [SNLE_A]) +@pytest.mark.parametrize("snlre_method", [SNLE_A, SNRE_A, SNRE_B, SNRE_C]) @pytest.mark.parametrize( "x_o_batch_dim", ( @@ -122,4 +125,4 @@ def test_batched_mcmc_sample_log_prob_with_different_x( samples.shape == (10, x_o_batch_dim, num_dim) if x_o_batch_dim > 0 else (10, num_dim) - ), "Sampel shape wrong" + ), "Sample shape wrong" From 7aac84c74fb2b376211715ddae8ed6669782c1ed Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 11 Jun 2024 12:26:57 +0200 Subject: [PATCH 55/71] fig bug --- sbi/samplers/rejection/rejection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index e29fe51db..8e7326346 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -315,7 +315,7 @@ def accept_reject_sample( max(int(1.5 * num_remaining / max(min_acceptance_rate, 1e-12)), 100), ) if ( - num_sampled_total > 1000 + num_sampled_total.min().item() > 1000 and min_acceptance_rate < warn_acceptance and not leakage_warning_raised ): From 7d4eb55e817fb2d3a830d6833e2b2f0e085ce6ed Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 11 Jun 2024 14:28:20 +0200 Subject: [PATCH 56/71] Ensemble sample_batched is working (with tests) --- tests/ensemble_test.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/ensemble_test.py b/tests/ensemble_test.py index 8ec819eb3..d4e8cc4e6 100644 --- a/tests/ensemble_test.py +++ b/tests/ensemble_test.py @@ -139,3 +139,16 @@ def simulator(theta): # test individual log_prob and map posterior.log_prob(samples, individually=True) + + # Test sample_batched + x_o_batch_dim = 2 + if isinstance(inferer, (SNLE_A, SNRE_A)): + samples = posterior.sample_batched( + (10,), + ones(x_o_batch_dim, num_dim), + method="slice_np_vectorized", + ) + else: + samples = posterior.sample_batched((10,), ones(x_o_batch_dim, num_dim)) + + assert samples.shape == (10, x_o_batch_dim, num_dim), "Sample shape wrong" From f53e1ecf06a30e03a768525197f83bd83e6c76df Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 11 Jun 2024 14:52:46 +0200 Subject: [PATCH 57/71] GPU compatibility --- sbi/samplers/rejection/rejection.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 8e7326346..ea19e7b97 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -258,6 +258,10 @@ def accept_reject_sample( proposal_sampling_kwargs = {} num_remaining = num_samples + + # NOTE: We might want to change this to a more general approach in the future. + # Currently limited to a single "batch_dim" for the condition. + # But this would require giving the method the condition_shape explicitly... if "condition" in proposal_sampling_kwargs: num_xos = proposal_sampling_kwargs["condition"].shape[0] else: @@ -295,7 +299,7 @@ def accept_reject_sample( # samples will be of shape(sampling_batch_size,*batch_shape, *event_shape) # and hence work in dim = 0. num_accepted = are_accepted.sum(dim=0) - num_sampled_total += num_accepted + num_sampled_total += num_accepted.to(num_sampled_total.device) num_samples_possible += sampling_batch_size min_num_accepted = num_accepted.min().item() num_remaining -= min_num_accepted @@ -358,4 +362,4 @@ def accept_reject_sample( samples.shape[0] == num_samples ), "Number of accepted samples must match required samples." - return samples, as_tensor(min_acceptance_rate) + return samples, acceptance_rate From 2dc6ebdbdc7e478d54d7443c81debda7fbd794f0 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 11 Jun 2024 15:01:05 +0200 Subject: [PATCH 58/71] restriction priopr requires float as output of accept_reject --- sbi/samplers/rejection/rejection.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index ea19e7b97..5c1d5ffc5 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -362,4 +362,6 @@ def accept_reject_sample( samples.shape[0] == num_samples ), "Number of accepted samples must match required samples." - return samples, acceptance_rate + # NOTE: Restriction prior does currently require a float as return for the + # acceptance rate, which is why we for now also return the minimum acceptance rate. + return samples, as_tensor(min_acceptance_rate) From 7dfda13cd587d3fa42c9f5bad1e359a98f0153ae Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 11 Jun 2024 15:18:30 +0200 Subject: [PATCH 59/71] Adding a few comments --- sbi/inference/posteriors/mcmc_posterior.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index b0cd6d598..6f322a767 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -354,7 +354,7 @@ def sample( samples = self.theta_transform.inv(transformed_samples) samples = samples.reshape((*sample_shape, -1)) # type: ignore - return samples.reshape((*sample_shape, -1)) # type: ignore + return samples def sample_batched( self, @@ -405,6 +405,7 @@ def sample_batched( ) self.potential_ = self._prepare_potential(method) # type: ignore + # For each observation in the batch, we have num_chains independent chains. num_chains_extended = batch_size * num_chains initial_params = self._get_initial_params( init_strategy, # type: ignore @@ -413,7 +414,7 @@ def sample_batched( show_progress_bars, **init_strategy_parameters, ) - + # We need num_samples from each posterior in the batch num_samples = torch.Size(sample_shape).numel() * batch_size assert ( @@ -433,7 +434,7 @@ def sample_batched( ) samples = self.theta_transform.inv(transformed_samples) - + # Samples are of shape (num_samples, num_chains_extended, *input_shape) return samples.reshape((*sample_shape, batch_size, -1)) # type: ignore def _build_mcmc_init_fn( From 89b6e8f00bfbd0965c9040915300b7b44cc6ae66 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 11 Jun 2024 15:54:09 +0200 Subject: [PATCH 60/71] 2d sample_shape tests --- tests/density_estimator_test.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/density_estimator_test.py b/tests/density_estimator_test.py index 35fb0d946..9a84e036b 100644 --- a/tests/density_estimator_test.py +++ b/tests/density_estimator_test.py @@ -299,14 +299,18 @@ def test_correctness_of_density_estimator_log_prob( build_zuko_nsf, build_zuko_sospf, build_zuko_unaf, - build_categoricalmassestimator, - build_mnle, + # build_categoricalmassestimator, NOTE: This does not support 2d sample_shape + # build_mnle, NOTE: This does not support 2d sample_shape ), ) @pytest.mark.parametrize("input_event_shape", ((1,), (4,))) @pytest.mark.parametrize("condition_event_shape", ((1,), (7,))) +@pytest.mark.parametrize("sample_shape", ((1000,), (500, 2))) def test_correctness_of_batched_vs_seperate_sample_and_log_prob( - density_estimator_build_fn, input_event_shape, condition_event_shape + density_estimator_build_fn, + input_event_shape, + condition_event_shape, + sample_shape, ): input_sample_dim = 2 batch_dim = 2 @@ -318,7 +322,9 @@ def test_correctness_of_batched_vs_seperate_sample_and_log_prob( input_sample_dim, ) # Batched vs separate sampling - samples = density_estimator.sample((1000,), condition=condition) + samples = density_estimator.sample(sample_shape, condition=condition) + samples = samples.reshape(-1, batch_dim, *input_event_shape) # Flat for comp. + samples_separate1 = density_estimator.sample( (1000,), condition=condition[0][None, ...] ) From 93ca374e75d2e4cc15cc451728956a56e5ff1057 Mon Sep 17 00:00:00 2001 From: manuelgloeckler <38903899+manuelgloeckler@users.noreply.github.com> Date: Fri, 14 Jun 2024 11:52:46 +0200 Subject: [PATCH 61/71] Apply suggestions from code review Co-authored-by: Jan --- sbi/inference/posteriors/importance_posterior.py | 1 - tests/density_estimator_test.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/sbi/inference/posteriors/importance_posterior.py b/sbi/inference/posteriors/importance_posterior.py index 0b659c84e..b6e9721d1 100644 --- a/sbi/inference/posteriors/importance_posterior.py +++ b/sbi/inference/posteriors/importance_posterior.py @@ -201,7 +201,6 @@ def sample_batched( max_sampling_batch_size: int = 10000, show_progress_bars: bool = True, ) -> Tensor: - # TODO Can be implemented in the future. raise NotImplementedError( "Batched sampling is not implemented for ImportanceSamplingPosterior. \ Alternatively you can use `sample` in a loop \ diff --git a/tests/density_estimator_test.py b/tests/density_estimator_test.py index 9a84e036b..4f562a4b1 100644 --- a/tests/density_estimator_test.py +++ b/tests/density_estimator_test.py @@ -299,8 +299,8 @@ def test_correctness_of_density_estimator_log_prob( build_zuko_nsf, build_zuko_sospf, build_zuko_unaf, - # build_categoricalmassestimator, NOTE: This does not support 2d sample_shape - # build_mnle, NOTE: This does not support 2d sample_shape + pytest.param(build_categoricalmassestimator, marks=pytest.mark.xfail(reason='see issue #1172')), + pytest.param(build_mnle, marks=pytest.mark.xfail(reason='see issue #1172')), ), ) @pytest.mark.parametrize("input_event_shape", ((1,), (4,))) From 86f3531d2ee584d45afb0f099ce311e82a79678a Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Fri, 14 Jun 2024 16:12:50 +0200 Subject: [PATCH 62/71] Adding comment about squeeze --- sbi/inference/posteriors/mcmc_posterior.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index 6f322a767..e439d7805 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -352,6 +352,9 @@ def sample( raise NameError(f"The sampling method {method} is not implemented!") samples = self.theta_transform.inv(transformed_samples) + # NOTE: Currently MCMCPosteriors will require a single dimension for the + # parameter dimension. With recent ConditionalDensity(Ratio) estimators, we + # can have multiple dimensions for the parameter dimension. samples = samples.reshape((*sample_shape, -1)) # type: ignore return samples @@ -591,11 +594,7 @@ def _slice_np_mcmc( def multi_obs_potential(params): # Params are of shape (num_chains * num_obs, event). - # We now reshape them to (num_chains, num_obs, event). - # params = np.reshape(params, (num_chains, num_obs, -1)) - - # `all_potentials` is of shape (num_chains, num_obs). - all_potentials = potential_function(params) + all_potentials = potential_function(params) # Shape: (num_chains, num_obs) return all_potentials.flatten() posterior_sampler = SliceSamplerMultiChain( From 2a5f357c4cc39b17f6b320dbb4a962ff3c13f93e Mon Sep 17 00:00:00 2001 From: manuelgloeckler <38903899+manuelgloeckler@users.noreply.github.com> Date: Tue, 18 Jun 2024 09:38:58 +0200 Subject: [PATCH 63/71] Update sbi/inference/posteriors/direct_posterior.py Co-authored-by: Jan --- sbi/inference/posteriors/direct_posterior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index 1c7ef691d..9f90cfb45 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -264,7 +264,7 @@ def log_prob_batched( ) -> Tensor: """Given a batch of observations [x_1, ..., x_B] and a batch of parameters \ [$\theta_1$,..., $\theta_B$] this function evalautes the log-probabilities \ - of the posterior $p(\theta_1|x_1)$, ..., $p(\theta_B|x_B)$ in a batched \ + of the posteriors $p(\theta_1|x_1)$, ..., $p(\theta_B|x_B)$ in a batched \ (i.e. vectorized) manner. Args: From 79273a2fd37b933f0dc624ab31e293d7392e6530 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 18 Jun 2024 09:40:32 +0200 Subject: [PATCH 64/71] fixing formating --- tests/density_estimator_test.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/density_estimator_test.py b/tests/density_estimator_test.py index 4f562a4b1..4fd0cd794 100644 --- a/tests/density_estimator_test.py +++ b/tests/density_estimator_test.py @@ -299,7 +299,10 @@ def test_correctness_of_density_estimator_log_prob( build_zuko_nsf, build_zuko_sospf, build_zuko_unaf, - pytest.param(build_categoricalmassestimator, marks=pytest.mark.xfail(reason='see issue #1172')), + pytest.param( + build_categoricalmassestimator, + marks=pytest.mark.xfail(reason='see issue #1172'), + ), pytest.param(build_mnle, marks=pytest.mark.xfail(reason='see issue #1172')), ), ) From 7b23d606f812b4a6ed24203deaec3a8d65935441 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 18 Jun 2024 09:47:06 +0200 Subject: [PATCH 65/71] reverting MCM posterior changes --- sbi/inference/posteriors/mcmc_posterior.py | 59 +++------------------- 1 file changed, 6 insertions(+), 53 deletions(-) diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index e439d7805..65e9a0c25 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -390,55 +390,13 @@ def sample_batched( Returns: Samples from the posteriors of shape (*sample_shape, B, *input_shape) """ - batch_size = x.shape[0] - self.potential_fn.set_x(x) - # Replace arguments that were not passed with their default. - method = self.method if method is None else method - thin = self.thin if thin is None else thin - warmup_steps = self.warmup_steps if warmup_steps is None else warmup_steps - num_chains = self.num_chains if num_chains is None else num_chains - init_strategy = self.init_strategy if init_strategy is None else init_strategy - num_workers = self.num_workers if num_workers is None else num_workers - mp_context = self.mp_context if mp_context is None else mp_context - init_strategy_parameters = ( - self.init_strategy_parameters - if init_strategy_parameters is None - else init_strategy_parameters - ) - self.potential_ = self._prepare_potential(method) # type: ignore - - # For each observation in the batch, we have num_chains independent chains. - num_chains_extended = batch_size * num_chains - initial_params = self._get_initial_params( - init_strategy, # type: ignore - num_chains_extended, # type: ignore - num_workers, - show_progress_bars, - **init_strategy_parameters, + # See #1176 for a discussion on the implementation of batched sampling. + raise NotImplementedError( + "Batched sampling is not implemented for MCMC posterior. \ + Alternatively you can use `sample` in a loop \ + [posterior.sample(theta, x_o) for x_o in x]." ) - # We need num_samples from each posterior in the batch - num_samples = torch.Size(sample_shape).numel() * batch_size - - assert ( - method == "slice_np_vectorized" - ), "Batched sampling only supported for vectorized samplers!" - - with torch.set_grad_enabled(False): - transformed_samples = self._slice_np_mcmc( - num_samples=num_samples, - potential_function=self.potential_, - initial_params=initial_params, - thin=thin, # type: ignore - warmup_steps=warmup_steps, # type: ignore - vectorized=(method == "slice_np_vectorized"), - num_workers=num_workers, - show_progress_bars=show_progress_bars, - ) - - samples = self.theta_transform.inv(transformed_samples) - # Samples are of shape (num_samples, num_chains_extended, *input_shape) - return samples.reshape((*sample_shape, batch_size, -1)) # type: ignore def _build_mcmc_init_fn( self, @@ -592,14 +550,9 @@ def _slice_np_mcmc( else: SliceSamplerMultiChain = SliceSamplerVectorized - def multi_obs_potential(params): - # Params are of shape (num_chains * num_obs, event). - all_potentials = potential_function(params) # Shape: (num_chains, num_obs) - return all_potentials.flatten() - posterior_sampler = SliceSamplerMultiChain( init_params=tensor2numpy(initial_params), - log_prob_fn=multi_obs_potential, + log_prob_fn=potential_function, num_chains=num_chains, thin=thin, verbose=show_progress_bars, From d4f9e46517e7c53732369ada3b2b0d22a8ee36ce Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 18 Jun 2024 10:03:43 +0200 Subject: [PATCH 66/71] xfail mcmc tests --- tests/ensemble_test.py | 7 ++----- tests/posterior_nn_test.py | 24 +++++++++--------------- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/tests/ensemble_test.py b/tests/ensemble_test.py index d4e8cc4e6..c9a96cd3d 100644 --- a/tests/ensemble_test.py +++ b/tests/ensemble_test.py @@ -143,11 +143,8 @@ def simulator(theta): # Test sample_batched x_o_batch_dim = 2 if isinstance(inferer, (SNLE_A, SNRE_A)): - samples = posterior.sample_batched( - (10,), - ones(x_o_batch_dim, num_dim), - method="slice_np_vectorized", - ) + # TODO: Implement batched sampling for MCMC methods + pass else: samples = posterior.sample_batched((10,), ones(x_o_batch_dim, num_dim)) diff --git a/tests/posterior_nn_test.py b/tests/posterior_nn_test.py index a86fcebce..f634c1fd2 100644 --- a/tests/posterior_nn_test.py +++ b/tests/posterior_nn_test.py @@ -56,14 +56,7 @@ def test_log_prob_with_different_x(snpe_method: type, x_o_batch_dim: bool): @pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) -@pytest.mark.parametrize( - "x_o_batch_dim", - ( - 0, - 1, - 2, - ), -) +@pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2)) def test_batched_sample_log_prob_with_different_x( snpe_method: type, x_o_batch_dim: bool ): @@ -92,15 +85,16 @@ def test_batched_sample_log_prob_with_different_x( @pytest.mark.mcmc -@pytest.mark.parametrize("snlre_method", [SNLE_A, SNRE_A, SNRE_B, SNRE_C]) @pytest.mark.parametrize( - "x_o_batch_dim", - ( - 0, - 1, - 2, - ), + "snlre_method", + [ + pytest.param(SNLE_A, marks=pytest.mark.xfail(raises=NotImplementedError)), + pytest.param(SNRE_A, marks=pytest.mark.xfail(raises=NotImplementedError)), + pytest.param(SNRE_B, marks=pytest.mark.xfail(raises=NotImplementedError)), + pytest.param(SNRE_C, marks=pytest.mark.xfail(raises=NotImplementedError)), + ], ) +@pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2)) def test_batched_mcmc_sample_log_prob_with_different_x( snlre_method: type, x_o_batch_dim: bool, mcmc_params_fast: dict ): From 6798c979fc386090bb6e817d230dca1282ff3526 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 18 Jun 2024 10:06:06 +0200 Subject: [PATCH 67/71] Exclude MCMC from ensamble batched_sample test --- tests/ensemble_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ensemble_test.py b/tests/ensemble_test.py index c9a96cd3d..bdd9fd5e0 100644 --- a/tests/ensemble_test.py +++ b/tests/ensemble_test.py @@ -144,7 +144,7 @@ def simulator(theta): x_o_batch_dim = 2 if isinstance(inferer, (SNLE_A, SNRE_A)): # TODO: Implement batched sampling for MCMC methods - pass + return else: samples = posterior.sample_batched((10,), ones(x_o_batch_dim, num_dim)) From b1724a54b48a8faf8075c553f0bf92fd52877c66 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 18 Jun 2024 15:58:43 +0200 Subject: [PATCH 68/71] SNPE_A Bug fix --- sbi/inference/snpe/snpe_a.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py index dd3774e6f..1295ff287 100644 --- a/sbi/inference/snpe/snpe_a.py +++ b/sbi/inference/snpe/snpe_a.py @@ -474,7 +474,8 @@ def sample(self, sample_shape: torch.Size, condition: Tensor, **kwargs) -> Tenso condition = condition.to(self._device) if not self._apply_correction: - return self._neural_net.sample(sample_shape, condition=condition) + samples = self._neural_net.sample(sample_shape, condition=condition) + return samples else: # When we want to sample from the approx. posterior, a proposal prior # \tilde{p} has already been observed. To analytically calculate the @@ -483,7 +484,12 @@ def sample(self, sample_shape: torch.Size, condition: Tensor, **kwargs) -> Tenso condition_ndim = len(self.condition_shape) batch_size = condition.shape[:-condition_ndim] batch_size = torch.Size(batch_size).numel() - return self._sample_approx_posterior_mog(num_samples, condition, batch_size) + samples = self._sample_approx_posterior_mog( + num_samples, condition, batch_size + ) + # NOTE: New batching convention: (batch_dim, sample_dim, *event_shape) + samples = samples.transpose(0, 1) + return samples def _sample_approx_posterior_mog( self, num_samples, x: Tensor, batch_size: int From a6f4845e079fd3ec917c87f8695d67be84758566 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 18 Jun 2024 16:06:51 +0200 Subject: [PATCH 69/71] typo fix --- sbi/inference/posteriors/ensemble_posterior.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sbi/inference/posteriors/ensemble_posterior.py b/sbi/inference/posteriors/ensemble_posterior.py index 0e55ca130..b7eef8f7d 100644 --- a/sbi/inference/posteriors/ensemble_posterior.py +++ b/sbi/inference/posteriors/ensemble_posterior.py @@ -186,12 +186,12 @@ def sample_batched( **kwargs, ) -> Tensor: num_samples = torch.Size(sample_shape).numel() - posterior_indizes = torch.multinomial( + posterior_indices = torch.multinomial( self._weights, num_samples, replacement=True ) samples = [] for posterior_index, sample_size in torch.vstack( - posterior_indizes.unique(return_counts=True) + posterior_indices.unique(return_counts=True) ).T: sample_shape_c = torch.Size((int(sample_size),)) samples.append( From 2aac7053d3af17cf81d456522e1668cbe458655f Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 18 Jun 2024 16:24:56 +0200 Subject: [PATCH 70/71] preamtive main fix --- tests/linearGaussian_snpe_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/linearGaussian_snpe_test.py b/tests/linearGaussian_snpe_test.py index cd3b2c2cf..ad23fe126 100644 --- a/tests/linearGaussian_snpe_test.py +++ b/tests/linearGaussian_snpe_test.py @@ -600,7 +600,7 @@ def simulator(theta): error = np.abs(sample_kde_grid - eval_grid.numpy()) max_err = np.max(error) - assert max_err < 0.0027 + assert max_err < 0.005 def test_mdn_conditional_density(num_dim: int = 3, cond_dim: int = 1): From 26444f767349671fc5d1a8e83a420b3b16c22fda Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 18 Jun 2024 16:39:31 +0200 Subject: [PATCH 71/71] Revert "preamtive main fix" This reverts commit 2aac7053d3af17cf81d456522e1668cbe458655f. --- tests/linearGaussian_snpe_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/linearGaussian_snpe_test.py b/tests/linearGaussian_snpe_test.py index ad23fe126..cd3b2c2cf 100644 --- a/tests/linearGaussian_snpe_test.py +++ b/tests/linearGaussian_snpe_test.py @@ -600,7 +600,7 @@ def simulator(theta): error = np.abs(sample_kde_grid - eval_grid.numpy()) max_err = np.max(error) - assert max_err < 0.005 + assert max_err < 0.0027 def test_mdn_conditional_density(num_dim: int = 3, cond_dim: int = 1):