From 6f616622cd6122b3cb6db89d3d0363b3d0040f65 Mon Sep 17 00:00:00 2001 From: manuelgloeckler <38903899+manuelgloeckler@users.noreply.github.com> Date: Fri, 21 Jun 2024 15:03:10 +0200 Subject: [PATCH] importance sampling posterior (#1183) * importance sampling posterior * Tests and improved importance_posterior * pyright fix as suggested by Guy --- .../posteriors/importance_posterior.py | 17 +++++++++--- sbi/inference/snle/snle_base.py | 13 ++++++++- sbi/inference/snpe/snpe_base.py | 17 +++++++++++- sbi/inference/snre/snre_base.py | 13 ++++++++- tests/posterior_nn_test.py | 27 +++++++++++++++++++ 5 files changed, 80 insertions(+), 7 deletions(-) diff --git a/sbi/inference/posteriors/importance_posterior.py b/sbi/inference/posteriors/importance_posterior.py index 1a14d6bd7..4e9e002fb 100644 --- a/sbi/inference/posteriors/importance_posterior.py +++ b/sbi/inference/posteriors/importance_posterior.py @@ -156,6 +156,7 @@ def sample( self, sample_shape: Shape = torch.Size(), x: Optional[Tensor] = None, + method: Optional[str] = None, oversampling_factor: int = 32, max_sampling_batch_size: int = 10_000, sample_with: Optional[str] = None, @@ -164,14 +165,22 @@ def sample( """Return samples from the approximate posterior distribution. Args: - sample_shape: _description_ - x: _description_ + sample_shape: Shape of samples that are drawn from posterior. + x: Observed data. + method: Either of [`sir`|`importance`]. This sets the behavior of the + `.sample()` method. With `sir`, approximate posterior samples are + generated with sampling importance resampling (SIR). With + `importance`, the `.sample()` method returns a tuple of samples and + corresponding importance weights. oversampling_factor: Number of proposed samples from which only one is selected based on its importance weight. max_sampling_batch_size: The batch size of samples being drawn from the proposal at every iteration. show_progress_bars: Whether to show a progressbar during sampling. """ + + method = self.method if method is None else method + if sample_with is not None: raise ValueError( f"You set `sample_with={sample_with}`. As of sbi v0.18.0, setting " @@ -181,14 +190,14 @@ def sample( self.potential_fn.set_x(self._x_else_default_x(x)) - if self.method == "sir": + if method == "sir": return self._sir_sample( sample_shape, oversampling_factor=oversampling_factor, max_sampling_batch_size=max_sampling_batch_size, show_progress_bars=show_progress_bars, ) - elif self.method == "importance": + elif method == "importance": return self._importance_sample(sample_shape) else: raise NameError diff --git a/sbi/inference/snle/snle_base.py b/sbi/inference/snle/snle_base.py index c51244c64..bb841911d 100644 --- a/sbi/inference/snle/snle_base.py +++ b/sbi/inference/snle/snle_base.py @@ -13,6 +13,7 @@ from sbi.inference.base import NeuralInference from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior +from sbi.inference.posteriors.importance_posterior import ImportanceSamplingPosterior from sbi.inference.potentials import likelihood_estimator_based_potential from sbi.neural_nets import ConditionalDensityEstimator, likelihood_nn from sbi.neural_nets.density_estimators.shape_handling import ( @@ -270,7 +271,10 @@ def build_posterior( mcmc_parameters: Optional[Dict[str, Any]] = None, vi_parameters: Optional[Dict[str, Any]] = None, rejection_sampling_parameters: Optional[Dict[str, Any]] = None, - ) -> Union[MCMCPosterior, RejectionPosterior, VIPosterior]: + importance_sampling_parameters: Optional[Dict[str, Any]] = None, + ) -> Union[ + MCMCPosterior, RejectionPosterior, VIPosterior, ImportanceSamplingPosterior + ]: r"""Build posterior from the neural density estimator. SNLE trains a neural network to approximate the likelihood $p(x|\theta)$. The @@ -350,6 +354,13 @@ def build_posterior( device=device, **vi_parameters or {}, ) + elif sample_with == "importance": + self._posterior = ImportanceSamplingPosterior( + potential_fn=potential_fn, + proposal=prior, + device=device, + **importance_sampling_parameters or {}, + ) else: raise NotImplementedError diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index 4132e8c75..1ca63b2fc 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -21,6 +21,7 @@ VIPosterior, ) from sbi.inference.posteriors.base_posterior import NeuralPosterior +from sbi.inference.posteriors.importance_posterior import ImportanceSamplingPosterior from sbi.inference.potentials import posterior_estimator_based_potential from sbi.neural_nets import ConditionalDensityEstimator, posterior_nn from sbi.neural_nets.density_estimators.shape_handling import ( @@ -441,7 +442,14 @@ def build_posterior( mcmc_parameters: Optional[Dict[str, Any]] = None, vi_parameters: Optional[Dict[str, Any]] = None, rejection_sampling_parameters: Optional[Dict[str, Any]] = None, - ) -> Union[MCMCPosterior, RejectionPosterior, VIPosterior, DirectPosterior]: + importance_sampling_parameters: Optional[Dict[str, Any]] = None, + ) -> Union[ + MCMCPosterior, + RejectionPosterior, + VIPosterior, + DirectPosterior, + ImportanceSamplingPosterior, + ]: r"""Build posterior from the neural density estimator. For SNPE, the posterior distribution that is returned here implements the @@ -540,6 +548,13 @@ def build_posterior( device=device, **vi_parameters or {}, ) + elif sample_with == "importance": + self._posterior = ImportanceSamplingPosterior( + potential_fn=potential_fn, + proposal=prior, + device=device, + **importance_sampling_parameters or {}, + ) else: raise NotImplementedError diff --git a/sbi/inference/snre/snre_base.py b/sbi/inference/snre/snre_base.py index 8ceb9aa7d..38f89d443 100644 --- a/sbi/inference/snre/snre_base.py +++ b/sbi/inference/snre/snre_base.py @@ -13,6 +13,7 @@ from sbi.inference.base import NeuralInference from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior +from sbi.inference.posteriors.importance_posterior import ImportanceSamplingPosterior from sbi.inference.potentials import ratio_estimator_based_potential from sbi.neural_nets import classifier_nn from sbi.utils import ( @@ -322,7 +323,10 @@ def build_posterior( mcmc_parameters: Optional[Dict[str, Any]] = None, vi_parameters: Optional[Dict[str, Any]] = None, rejection_sampling_parameters: Optional[Dict[str, Any]] = None, - ) -> Union[MCMCPosterior, RejectionPosterior, VIPosterior]: + importance_sampling_parameters: Optional[Dict[str, Any]] = None, + ) -> Union[ + MCMCPosterior, RejectionPosterior, VIPosterior, ImportanceSamplingPosterior + ]: r"""Build posterior from the neural density estimator. SNRE trains a neural network to approximate likelihood ratios. The @@ -405,6 +409,13 @@ def build_posterior( device=device, **vi_parameters or {}, ) + elif sample_with == "importance": + self._posterior = ImportanceSamplingPosterior( + potential_fn=potential_fn, + proposal=prior, + device=device, + **importance_sampling_parameters or {}, + ) else: raise NotImplementedError diff --git a/tests/posterior_nn_test.py b/tests/posterior_nn_test.py index f634c1fd2..4ba61d7d0 100644 --- a/tests/posterior_nn_test.py +++ b/tests/posterior_nn_test.py @@ -55,6 +55,33 @@ def test_log_prob_with_different_x(snpe_method: type, x_o_batch_dim: bool): _ = posterior.log_prob(samples) +@pytest.mark.parametrize( + "snplre_method", [SNPE_A, SNPE_C, SNLE_A, SNRE_A, SNRE_B, SNRE_C] +) +def test_importance_posterior_sample_log_prob(snplre_method: type): + num_dim = 2 + + prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim)) + simulator = diagonal_linear_gaussian + + inference = snplre_method(prior=prior) + theta, x = simulate_for_sbi(simulator, prior, 1000) + _ = inference.append_simulations(theta, x).train(max_num_epochs=3) + + posterior = inference.build_posterior(sample_with="importance") + + x_o = ones(num_dim) + samples = posterior.sample((10,), x=x_o) + samples2, weights = posterior.sample((10,), x=x_o, method="importance") + assert samples.shape == (10, num_dim), "Sample shape of sample is wrong" + assert samples2.shape == (10, num_dim), "Sample of sample_with_weights shape wrong" + assert weights.shape == (10,), "Weights shape wrong" + + log_prob = posterior.log_prob(samples, x=x_o) + + assert log_prob.shape == (10,), "logprob shape wrong" + + @pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) @pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2)) def test_batched_sample_log_prob_with_different_x(