From 5200ee04961858adcf8c6df2b9b4c92b827bfc2a Mon Sep 17 00:00:00 2001 From: tbmiller-astro Date: Fri, 18 Mar 2022 14:52:13 -0400 Subject: [PATCH 01/15] Initial testing of get_dataloaders --- sbi/inference/base.py | 75 ++++++++++++++++++++++++++++++++- sbi/inference/snpe/snpe_base.py | 16 +++---- 2 files changed, 79 insertions(+), 12 deletions(-) diff --git a/sbi/inference/base.py b/sbi/inference/base.py index 5a2fb7c57..23450a29e 100644 --- a/sbi/inference/base.py +++ b/sbi/inference/base.py @@ -161,6 +161,7 @@ def get_simulations( starting_round: int = 0, exclude_invalid_x: bool = True, warn_on_invalid: bool = True, + warn_if_zscoring: Optional[bool] = True, ) -> Tuple[Tensor, Tensor, Tensor]: r"""Returns all $\theta$, $x$, and prior_masks from rounds >= `starting_round`. @@ -190,7 +191,8 @@ def get_simulations( # Check for NaNs in simulations. is_valid_x, num_nans, num_infs = handle_invalid_x(x, exclude_invalid_x) # Check for problematic z-scoring - warn_if_zscoring_changes_data(x[is_valid_x]) + if warn_if_zscoring: + warn_if_zscoring_changes_data(x[is_valid_x]) if warn_on_invalid: warn_on_invalid_x(num_nans, num_infs, exclude_invalid_x) warn_on_invalid_x_for_snpec_leakage( @@ -216,6 +218,74 @@ def train( ) -> NeuralPosterior: raise NotImplementedError + def get_dataloaders_all( + self, + starting_round: int = 0, + exclude_invalid_x: bool = True, + warn_on_invalid: bool = True, + training_batch_size: int = 50, + validation_fraction: float = 0.1, + resume_training: bool = False, + dataloader_kwargs: Optional[dict] = None, + warn_if_zscoring: Optional[bool] = True, + ) -> Tuple[data.DataLoader, data.DataLoader]: + """Return dataloaders for training and validation. + + Args: + dataset: holding all theta and x, optionally masks. + training_batch_size: training arg of inference methods. + resume_training: Whether the current call is resuming training so that no + new training and validation indices into the dataset have to be created. 
+ dataloader_kwargs: Additional or updated kwargs to be passed to the training + and validation dataloaders (like, e.g., a collate_fn). + + Returns: + Tuple of dataloaders for training and validation. + + """ + + datset = data.TensorDataset( + *self.get_simulations( + starting_round,exclude_invalid_x, warn_on_invalid = warn_on_invalid, + warn_if_zscoring = warn_if_zscoring + ) + + ) + # Get total number of training examples. + num_examples = len(dataset) + + # Select random train and validation splits from (theta, x) pairs. + num_training_examples = int((1 - validation_fraction) * num_examples) + num_validation_examples = num_examples - num_training_examples + + if not resume_training: + permuted_indices = torch.randperm(num_examples) + self.train_indices, self.val_indices = ( + permuted_indices[:num_training_examples], + permuted_indices[num_training_examples:], + ) + + # Create training and validation loaders using a subset sampler. + # Intentionally use dicts to define the default dataloader args + # Then, use dataloader_kwargs to override (or add to) any of these defaults + # https://stackoverflow.com/questions/44784577/in-method-call-args-how-to-override-keyword-argument-of-unpacked-dict + train_loader_kwargs = { + "batch_size": min(training_batch_size, num_training_examples), + "drop_last": True, + "sampler": SubsetRandomSampler(torch.arange(len(self.train_indices).tolist()) ), + } + val_loader_kwargs = { + "batch_size": min(training_batch_size, num_validation_examples), + "shuffle": False, + "drop_last": True, + "sampler": SubsetRandomSampler(torch.arange(len(self.val_indices).tolist()) ), + } + if dataloader_kwargs is not None: + train_loader_kwargs = dict(train_loader_kwargs, **dataloader_kwargs) + val_loader_kwargs = dict(val_loader_kwargs, **dataloader_kwargs) + + return data.DataLoader(dataset[train_indices], **train_loader_kwargs), data.DataLoader(dataset[val_indices], **val_loader_kwargs) + def get_dataloaders( self, dataset: data.TensorDataset, @@ 
-358,7 +428,8 @@ def _report_convergence_at_end( ) def _summarize( - self, round_: int, x_o: Union[Tensor, None], theta_bank: Tensor, x_bank: Tensor + self, round_: int, x_o: Union[Tensor, None], theta_bank: Union[Tensor,None] + , x_bank: Union[Tensor,None] ) -> None: """Update the summary_writer with statistics for a given round. diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index da950d59a..459cd75fd 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -166,6 +166,7 @@ def train( discard_prior_samples: bool = False, retrain_from_scratch: bool = False, show_train_summary: bool = False, + warn_if_zscoring: Optional[bool] = True, dataloader_kwargs: Optional[dict] = None, ) -> nn.Module: r"""Return density estimator that approximates the distribution $p(\theta|x)$. @@ -217,25 +218,20 @@ def train( if self.use_non_atomic_loss or hasattr(self, "_ran_final_round"): start_idx = self._round - theta, x, prior_masks = self.get_simulations( - start_idx, exclude_invalid_x, warn_on_invalid=True - ) - - # Dataset is shared for training and validation loaders. - dataset = data.TensorDataset(theta, x, prior_masks) - # Set the proposal to the last proposal that was passed by the user. For # atomic SNPE, it does not matter what the proposal is. For non-atomic # SNPE, we only use the latest data that was passed, i.e. the one from the # last proposal. proposal = self._proposal_roundwise[-1] - train_loader, val_loader = self.get_dataloaders( - dataset, + train_loader, val_loader = self.get_dataloaders_all( + start_idx, + exclude_invalid_x, training_batch_size, validation_fraction, resume_training, dataloader_kwargs=dataloader_kwargs, + warn_if_zscoring=warn_if_zscoring, ) # First round or if retraining from scratch: @@ -341,7 +337,7 @@ def train( self._summary["best_validation_log_probs"].append(self._best_val_log_prob) # Update tensorboard and summary dict. 
- self._summarize(round_=self._round, x_o=None, theta_bank=theta, x_bank=x) + self._summarize(round_=self._round, x_o=None, theta_bank=None, x_bank=None) # Update description for progress bar. if show_train_summary: From 750fb6518326722d4baad561bcf29f5f5d9100e6 Mon Sep 17 00:00:00 2001 From: tbmiller-astro Date: Fri, 18 Mar 2022 15:02:24 -0400 Subject: [PATCH 02/15] Minor typo --- sbi/inference/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbi/inference/base.py b/sbi/inference/base.py index 23450a29e..4dfc4d61d 100644 --- a/sbi/inference/base.py +++ b/sbi/inference/base.py @@ -244,7 +244,7 @@ def get_dataloaders_all( """ - datset = data.TensorDataset( + dataset = data.TensorDataset( *self.get_simulations( starting_round,exclude_invalid_x, warn_on_invalid = warn_on_invalid, warn_if_zscoring = warn_if_zscoring From 178e3d1ade694cb3722ce9268e8536c0916ae760 Mon Sep 17 00:00:00 2001 From: tbmiller-astro Date: Wed, 23 Mar 2022 14:14:20 -0400 Subject: [PATCH 03/15] Changes to SNPE append_simulations and get_dataloaders. --- sbi/inference/base.py | 27 +++++++------- sbi/inference/snpe/snpe_base.py | 64 ++++++++++++++++++++++++++------- sbi/inference/snpe/snpe_c.py | 1 + 3 files changed, 65 insertions(+), 27 deletions(-) diff --git a/sbi/inference/base.py b/sbi/inference/base.py index 4dfc4d61d..86897fecc 100644 --- a/sbi/inference/base.py +++ b/sbi/inference/base.py @@ -129,11 +129,14 @@ def __init__( # Initialize roundwise (theta, x, prior_masks) for storage of parameters, # simulations and masks indicating if simulations came from prior. self._theta_roundwise, self._x_roundwise, self._prior_masks = [], [], [] + self._dataset = None + self._num_sims_per_round = [] self._model_bank = [] # Initialize list that indicates the round from which simulations were drawn. 
self._data_round_index = [] + self._round = 0 self._val_log_prob = float("-Inf") @@ -221,13 +224,10 @@ def train( def get_dataloaders_all( self, starting_round: int = 0, - exclude_invalid_x: bool = True, - warn_on_invalid: bool = True, training_batch_size: int = 50, validation_fraction: float = 0.1, resume_training: bool = False, dataloader_kwargs: Optional[dict] = None, - warn_if_zscoring: Optional[bool] = True, ) -> Tuple[data.DataLoader, data.DataLoader]: """Return dataloaders for training and validation. @@ -244,22 +244,19 @@ def get_dataloaders_all( """ - dataset = data.TensorDataset( - *self.get_simulations( - starting_round,exclude_invalid_x, warn_on_invalid = warn_on_invalid, - warn_if_zscoring = warn_if_zscoring - ) + if starting_round == 0: + indices = torch.arange(sum(self._num_sims_per_round)) + else: + indices = torch.arange(sum(self._num_sims_per_round[:starting_round]), sum(self._num_sims_per_round)) - ) # Get total number of training examples. - num_examples = len(dataset) - + num_examples = len(indices) # Select random train and validation splits from (theta, x) pairs. 
num_training_examples = int((1 - validation_fraction) * num_examples) num_validation_examples = num_examples - num_training_examples if not resume_training: - permuted_indices = torch.randperm(num_examples) + permuted_indices = indices[torch.randperm(num_examples)] self.train_indices, self.val_indices = ( permuted_indices[:num_training_examples], permuted_indices[num_training_examples:], @@ -272,19 +269,19 @@ def get_dataloaders_all( train_loader_kwargs = { "batch_size": min(training_batch_size, num_training_examples), "drop_last": True, - "sampler": SubsetRandomSampler(torch.arange(len(self.train_indices).tolist()) ), + "sampler": SubsetRandomSampler(torch.arange(len(self.train_indices)).tolist() ), } val_loader_kwargs = { "batch_size": min(training_batch_size, num_validation_examples), "shuffle": False, "drop_last": True, - "sampler": SubsetRandomSampler(torch.arange(len(self.val_indices).tolist()) ), + "sampler": SubsetRandomSampler(torch.arange(len(self.val_indices)).tolist() ), } if dataloader_kwargs is not None: train_loader_kwargs = dict(train_loader_kwargs, **dataloader_kwargs) val_loader_kwargs = dict(val_loader_kwargs, **dataloader_kwargs) - return data.DataLoader(dataset[train_indices], **train_loader_kwargs), data.DataLoader(dataset[val_indices], **val_loader_kwargs) + return data.DataLoader(self._dataset, **train_loader_kwargs), data.DataLoader(self._dataset, **val_loader_kwargs) def get_dataloaders( self, diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index 459cd75fd..55b919016 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -29,6 +29,10 @@ test_posterior_net_for_multi_d_x, validate_theta_and_x, x_shape_from_simulation, + handle_invalid_x, + warn_if_zscoring_changes_data, + warn_on_invalid_x, + warn_on_invalid_x_for_snpec_leakage, ) from sbi.utils.sbiutils import ImproperEmpirical, mask_sims_from_prior @@ -88,6 +92,11 @@ def append_simulations( theta: Tensor, x: Tensor, proposal: 
Optional[DirectPosterior] = None, + exclude_invalid_x: bool = True, + warn_on_invalid: bool = True, + warn_if_zscoring: bool = True, + return_self: bool = True, + data_device: str = None ) -> "PosteriorEstimator": r"""Store parameters and simulation outputs to use them for later training. @@ -108,7 +117,26 @@ def append_simulations( NeuralInference object (returned so that this function is chainable). """ - theta, x = validate_theta_and_x(theta, x, training_device=self._device) + # Add ability to specify device data is saved on + if data_device is None: data_device = self._device + theta, x = validate_theta_and_x(theta, x, training_device=data_device) + + + is_valid_x, num_nans, num_infs = handle_invalid_x(x, exclude_invalid_x) + + # Check for problematic z-scoring + if warn_if_zscoring: + warn_if_zscoring_changes_data(x[is_valid_x]) + if warn_on_invalid: + warn_on_invalid_x(num_nans, num_infs, exclude_invalid_x) + warn_on_invalid_x_for_snpec_leakage( + num_nans, num_infs, exclude_invalid_x, type(self).__name__, self._round + ) + + x = x[is_valid_x] + theta = theta[is_valid_x] + + self._check_proposal(proposal) if ( @@ -122,7 +150,7 @@ def append_simulations( # with MLE loss or with atomic loss (see, in `train()`: # self._round = max(self._data_round_index)) self._data_round_index.append(0) - self._prior_masks.append(mask_sims_from_prior(0, theta.size(0))) + prior_masks = mask_sims_from_prior(0, theta.size(0)) else: if not self._data_round_index: # This catches a pretty specific case: if, in the first round, one @@ -130,10 +158,16 @@ def append_simulations( self._data_round_index.append(1) else: self._data_round_index.append(max(self._data_round_index) + 1) - self._prior_masks.append(mask_sims_from_prior(1, theta.size(0))) + prior_masks = mask_sims_from_prior(1, theta.size(0)) - self._theta_roundwise.append(theta) - self._x_roundwise.append(x) + + if self._dataset is None: + #If first round, set up ConcatDataset + self._dataset = data.ConcatDataset( 
[data.TensorDataset(theta,x,prior_masks),] ) + else: + self._dataset.datasets.append( data.TensorDataset(theta,x,prior_masks) ) + + self._num_sims_per_round.append(theta.size(0)) self._proposal_roundwise.append(proposal) if self._prior is None or isinstance(self._prior, ImproperEmpirical): @@ -150,7 +184,11 @@ def append_simulations( theta_prior = self.get_simulations()[0] self._prior = ImproperEmpirical(theta_prior, ones(theta_prior.shape[0])) - return self + #Add ability to not return self + if return_self: + return self + else: + return 1 def train( self, @@ -226,22 +264,24 @@ def train( train_loader, val_loader = self.get_dataloaders_all( start_idx, - exclude_invalid_x, training_batch_size, validation_fraction, resume_training, dataloader_kwargs=dataloader_kwargs, - warn_if_zscoring=warn_if_zscoring, ) - # First round or if retraining from scratch: # Call the `self._build_neural_net` with the rounds' thetas and xs as # arguments, which will build the neural network. # This is passed into NeuralPosterior, to create a neural posterior which # can `sample()` and `log_prob()`. The network is accessible via `.net`. if self._neural_net is None or retrain_from_scratch: + + #Get test theta,x + test_theta = self._dataset.datasets[0].tensors[0][:100] + test_x = self._dataset.datasets[0].tensors[1][:100] + self._neural_net = self._build_neural_net( - theta[self.train_indices], x[self.train_indices] + test_theta, test_x ) # If data on training device already move net as well. if ( @@ -250,8 +290,8 @@ def train( ): self._neural_net.to(self._device) - test_posterior_net_for_multi_d_x(self._neural_net, theta, x) - self._x_shape = x_shape_from_simulation(x) + test_posterior_net_for_multi_d_x(self._neural_net, test_theta, test_x) + self._x_shape = x_shape_from_simulation(test_x) # Move entire net to device for training. 
self._neural_net.to(self._device) diff --git a/sbi/inference/snpe/snpe_c.py b/sbi/inference/snpe/snpe_c.py index 065a45f1f..6d3b0e8e5 100644 --- a/sbi/inference/snpe/snpe_c.py +++ b/sbi/inference/snpe/snpe_c.py @@ -100,6 +100,7 @@ def train( retrain_from_scratch: bool = False, show_train_summary: bool = False, dataloader_kwargs: Optional[Dict] = None, + warn_if_zscoring: Optional[bool] = True ) -> nn.Module: r"""Return density estimator that approximates the distribution $p(\theta|x)$. From e1f7fd35ec331ad348967977ad6616fe5eb2b1f7 Mon Sep 17 00:00:00 2001 From: tbmiller-astro Date: Thu, 24 Mar 2022 11:13:49 -0400 Subject: [PATCH 04/15] Minor bug fix --- sbi/inference/snpe/snpe_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index 55b919016..787844693 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -276,7 +276,7 @@ def train( # can `sample()` and `log_prob()`. The network is accessible via `.net`. if self._neural_net is None or retrain_from_scratch: - #Get test theta,x + #Get theta,x from dataset to initialize NN test_theta = self._dataset.datasets[0].tensors[0][:100] test_x = self._dataset.datasets[0].tensors[1][:100] @@ -286,7 +286,7 @@ def train( # If data on training device already move net as well. 
if ( not self._device == "cpu" - and f"{x.device.type}:{x.device.index}" == self._device + and f"{test_x.device.type}:{test_x.device.index}" == self._device ): self._neural_net.to(self._device) From d77bed7b64734d6c34cc60a5dcdcc821e813d007 Mon Sep 17 00:00:00 2001 From: tbmiller-astro Date: Thu, 24 Mar 2022 15:24:43 -0400 Subject: [PATCH 05/15] Minor bugfix when using multiple rounds --- sbi/inference/snpe/snpe_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index 787844693..186ae5a18 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -165,7 +165,8 @@ def append_simulations( #If first round, set up ConcatDataset self._dataset = data.ConcatDataset( [data.TensorDataset(theta,x,prior_masks),] ) else: - self._dataset.datasets.append( data.TensorDataset(theta,x,prior_masks) ) + #Otherwise append to Dataset + self._dataset = data.ConcatDataset( self._dataset.datasets + [data.TensorDataset(theta,x,prior_masks),] ) self._num_sims_per_round.append(theta.size(0)) self._proposal_roundwise.append(proposal) From 48ace32258a1ba19101bb2fca321b154abfe142b Mon Sep 17 00:00:00 2001 From: tbmiller-astro Date: Thu, 21 Apr 2022 14:39:03 -0400 Subject: [PATCH 06/15] Added to SNRE and SNLE --- README.md | 238 +- sbi/inference/base.py | 115 +- sbi/inference/snle/snle_base.py | 73 +- sbi/inference/snpe/snpe_a.py | 1721 +- sbi/inference/snpe/snpe_base.py | 1191 +- sbi/inference/snpe/snpe_c.py | 1254 +- sbi/inference/snre/snre_a.py | 3 - sbi/inference/snre/snre_b.py | 3 - sbi/inference/snre/snre_base.py | 71 +- sbi/utils/sbiutils.py | 23 + tests/linearGaussian_snpe_test.py | 1166 +- tutorials/07_conditional_distributions.ipynb | 33176 ++++++++--------- 12 files changed, 19522 insertions(+), 19512 deletions(-) diff --git a/README.md b/README.md index 96012b76c..9d9fc26e5 100644 --- a/README.md +++ b/README.md @@ -1,119 +1,119 @@ -[![PyPI 
version](https://badge.fury.io/py/sbi.svg)](https://badge.fury.io/py/sbi) -[![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/mackelab/sbi/blob/master/CONTRIBUTING.md) -[![Tests](https://github.com/mackelab/sbi/workflows/Tests/badge.svg?branch=main)](https://github.com/mackelab/sbi/actions) -[![codecov](https://codecov.io/gh/mackelab/sbi/branch/main/graph/badge.svg)](https://codecov.io/gh/mackelab/sbi) -[![GitHub license](https://img.shields.io/github/license/mackelab/sbi)](https://github.com/mackelab/sbi/blob/master/LICENSE.txt) -[![DOI](https://joss.theoj.org/papers/10.21105/joss.02505/status.svg)](https://doi.org/10.21105/joss.02505) - -## sbi: simulation-based inference -[Getting Started](https://www.mackelab.org/sbi/tutorial/00_getting_started/) | [Documentation](https://www.mackelab.org/sbi/) - -`sbi` is a PyTorch package for simulation-based inference. Simulation-based inference is -the process of finding parameters of a simulator from observations. - -`sbi` takes a Bayesian approach and returns a full posterior distribution -over the parameters, conditional on the observations. This posterior can be amortized (i.e. -useful for any observation) or focused (i.e. tailored to a particular observation), with different -computational trade-offs. - -`sbi` offers a simple interface for one-line posterior inference. - -```python -from sbi.inference import infer -# import your simulator, define your prior over the parameters -parameter_posterior = infer(simulator, prior, method='SNPE', num_simulations=100) -``` -See below for the available methods of inference, `SNPE`, `SNRE` and `SNLE`. - - -## Installation - -`sbi` requires Python 3.6 or higher. We recommend to use a [`conda`](https://docs.conda.io/en/latest/miniconda.html) virtual -environment ([Miniconda installation instructions](https://docs.conda.io/en/latest/miniconda.html])). 
If `conda` is installed on the system, an environment for -installing `sbi` can be created as follows: -```commandline -# Create an environment for sbi (indicate Python 3.6 or higher); activate it -$ conda create -n sbi_env python=3.7 && conda activate sbi_env -``` - -Independent of whether you are using `conda` or not, `sbi` can be installed using `pip`: -```commandline -$ pip install sbi -``` - -To test the installation, drop into a python prompt and run -```python -from sbi.examples.minimal import simple -posterior = simple() -print(posterior) -``` - -## Inference Algorithms - -The following algorithms are currently available: - -#### Sequential Neural Posterior Estimation (SNPE) - -* [`SNPE_A`](https://www.mackelab.org/sbi/reference/#sbi.inference.snpe.snpe_a.SNPE_A) from Papamakarios G and Murray I [_Fast ε-free Inference of Simulation Models with Bayesian Conditional Density Estimation_](https://proceedings.neurips.cc/paper/2016/hash/6aca97005c68f1206823815f66102863-Abstract.html) (NeurIPS 2016). - -* [`SNPE_C`](https://www.mackelab.org/sbi/reference/#sbi.inference.snpe.snpe_c.SNPE_C) or `APT` from Greenberg D, Nonnenmacher M, and Macke J [_Automatic - Posterior Transformation for likelihood-free - inference_](https://arxiv.org/abs/1905.07488) (ICML 2019). - - -#### Sequential Neural Likelihood Estimation (SNLE) -* [`SNLE_A`](https://www.mackelab.org/sbi/reference/#sbi.inference.snle.snle_a.SNLE_A) or just `SNL` from Papamakarios G, Sterrat DC and Murray I [_Sequential - Neural Likelihood_](https://arxiv.org/abs/1805.07226) (AISTATS 2019). - - -#### Sequential Neural Ratio Estimation (SNRE) - -* [`SNRE_A`](https://www.mackelab.org/sbi/reference/#sbi.inference.snre.snre_a.SNRE_A) or `AALR` from Hermans J, Begy V, and Louppe G. [_Likelihood-free Inference with Amortized Approximate Likelihood Ratios_](https://arxiv.org/abs/1903.04057) (ICML 2020). 
- -* [`SNRE_B`](https://www.mackelab.org/sbi/reference/#sbi.inference.snre.snre_b.SNRE_B) or `SRE` from Durkan C, Murray I, and Papamakarios G. [_On Contrastive Learning for Likelihood-free Inference_](https://arxiv.org/abs/2002.03712) (ICML 2020). - -#### Sequential Neural Variational Inference (SNVI) - -* [`SNVI`](https://www.mackelab.org/sbi/reference/#sbi.inference.posteriors.vi_posterior) from Glöckler M, Deistler M, Macke J, [_Variational methods for simulation-based inference_](https://openreview.net/forum?id=kZ0UYdhqkNY) (ICLR 2022). - -## Feedback and Contributions - -We would like to hear how `sbi` is working for your inference problems as well as receive bug reports, pull requests and other feedback (see -[contribute](http://www.mackelab.org/sbi/contribute/)). - - -## Acknowledgements - -`sbi` is the successor (using PyTorch) of the -[`delfi`](https://github.com/mackelab/delfi) package. It was started as a fork of Conor -M. Durkan's `lfi`. `sbi` runs as a community project; development is coordinated at the -[mackelab](https://uni-tuebingen.de/en/research/core-research/cluster-of-excellence-machine-learning/research/research/cluster-research-groups/professorships/machine-learning-in-science/). See also [credits](https://github.com/mackelab/sbi/blob/master/docs/docs/credits.md). - - -## Support - -`sbi` has been supported by the German Federal Ministry of Education and Research (BMBF) through the project ADIMEM, FKZ 01IS18052 A-D). [ADIMEM](https://fit.uni-tuebingen.de/Project/Details?id=9199) is a collaborative project between the groups of Jakob Macke (Uni Tübingen), Philipp Berens (Uni Tübingen), Philipp Hennig (Uni Tübingen) and Marcel Oberlaender (caesar Bonn) which aims to develop inference methods for mechanistic models. 
- - -## License - -[Affero General Public License v3 (AGPLv3)](https://www.gnu.org/licenses/) - - -## Citation -If you use `sbi` consider citing the [sbi software paper](https://doi.org/10.21105/joss.02505), in addition to the original research articles describing the specifc sbi-algorithm(s) you are using: - -``` -@article{tejero-cantero2020sbi, - doi = {10.21105/joss.02505}, - url = {https://doi.org/10.21105/joss.02505}, - year = {2020}, - publisher = {The Open Journal}, - volume = {5}, - number = {52}, - pages = {2505}, - author = {Alvaro Tejero-Cantero and Jan Boelts and Michael Deistler and Jan-Matthis Lueckmann and Conor Durkan and Pedro J. Gonçalves and David S. Greenberg and Jakob H. Macke}, - title = {sbi: A toolkit for simulation-based inference}, - journal = {Journal of Open Source Software} -} -``` +[![PyPI version](https://badge.fury.io/py/sbi.svg)](https://badge.fury.io/py/sbi) +[![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/mackelab/sbi/blob/master/CONTRIBUTING.md) +[![Tests](https://github.com/mackelab/sbi/workflows/Tests/badge.svg?branch=main)](https://github.com/mackelab/sbi/actions) +[![codecov](https://codecov.io/gh/mackelab/sbi/branch/main/graph/badge.svg)](https://codecov.io/gh/mackelab/sbi) +[![GitHub license](https://img.shields.io/github/license/mackelab/sbi)](https://github.com/mackelab/sbi/blob/master/LICENSE.txt) +[![DOI](https://joss.theoj.org/papers/10.21105/joss.02505/status.svg)](https://doi.org/10.21105/joss.02505) + +## sbi: simulation-based inference +[Getting Started](https://www.mackelab.org/sbi/tutorial/00_getting_started/) | [Documentation](https://www.mackelab.org/sbi/) + +`sbi` is a PyTorch package for simulation-based inference. Simulation-based inference is +the process of finding parameters of a simulator from observations. 
+ +`sbi` takes a Bayesian approach and returns a full posterior distribution +over the parameters, conditional on the observations. This posterior can be amortized (i.e. +useful for any observation) or focused (i.e. tailored to a particular observation), with different +computational trade-offs. + +`sbi` offers a simple interface for one-line posterior inference. + +```python +from sbi.inference import infer +# import your simulator, define your prior over the parameters +parameter_posterior = infer(simulator, prior, method='SNPE', num_simulations=100) +``` +See below for the available methods of inference, `SNPE`, `SNRE` and `SNLE`. + + +## Installation + +`sbi` requires Python 3.6 or higher. We recommend to use a [`conda`](https://docs.conda.io/en/latest/miniconda.html) virtual +environment ([Miniconda installation instructions](https://docs.conda.io/en/latest/miniconda.html])). If `conda` is installed on the system, an environment for +installing `sbi` can be created as follows: +```commandline +# Create an environment for sbi (indicate Python 3.6 or higher); activate it +$ conda create -n sbi_env python=3.7 && conda activate sbi_env +``` + +Independent of whether you are using `conda` or not, `sbi` can be installed using `pip`: +```commandline +$ pip install sbi +``` + +To test the installation, drop into a python prompt and run +```python +from sbi.examples.minimal import simple +posterior = simple() +print(posterior) +``` + +## Inference Algorithms + +The following algorithms are currently available: + +#### Sequential Neural Posterior Estimation (SNPE) + +* [`SNPE_A`](https://www.mackelab.org/sbi/reference/#sbi.inference.snpe.snpe_a.SNPE_A) from Papamakarios G and Murray I [_Fast ε-free Inference of Simulation Models with Bayesian Conditional Density Estimation_](https://proceedings.neurips.cc/paper/2016/hash/6aca97005c68f1206823815f66102863-Abstract.html) (NeurIPS 2016). 
+ +* [`SNPE_C`](https://www.mackelab.org/sbi/reference/#sbi.inference.snpe.snpe_c.SNPE_C) or `APT` from Greenberg D, Nonnenmacher M, and Macke J [_Automatic + Posterior Transformation for likelihood-free + inference_](https://arxiv.org/abs/1905.07488) (ICML 2019). + + +#### Sequential Neural Likelihood Estimation (SNLE) +* [`SNLE_A`](https://www.mackelab.org/sbi/reference/#sbi.inference.snle.snle_a.SNLE_A) or just `SNL` from Papamakarios G, Sterrat DC and Murray I [_Sequential + Neural Likelihood_](https://arxiv.org/abs/1805.07226) (AISTATS 2019). + + +#### Sequential Neural Ratio Estimation (SNRE) + +* [`SNRE_A`](https://www.mackelab.org/sbi/reference/#sbi.inference.snre.snre_a.SNRE_A) or `AALR` from Hermans J, Begy V, and Louppe G. [_Likelihood-free Inference with Amortized Approximate Likelihood Ratios_](https://arxiv.org/abs/1903.04057) (ICML 2020). + +* [`SNRE_B`](https://www.mackelab.org/sbi/reference/#sbi.inference.snre.snre_b.SNRE_B) or `SRE` from Durkan C, Murray I, and Papamakarios G. [_On Contrastive Learning for Likelihood-free Inference_](https://arxiv.org/abs/2002.03712) (ICML 2020). + +#### Sequential Neural Variational Inference (SNVI) + +* [`SNVI`](https://www.mackelab.org/sbi/reference/#sbi.inference.posteriors.vi_posterior) from Glöckler M, Deistler M, Macke J, [_Variational methods for simulation-based inference_](https://openreview.net/forum?id=kZ0UYdhqkNY) (ICLR 2022). + +## Feedback and Contributions + +We would like to hear how `sbi` is working for your inference problems as well as receive bug reports, pull requests and other feedback (see +[contribute](http://www.mackelab.org/sbi/contribute/)). + + +## Acknowledgements + +`sbi` is the successor (using PyTorch) of the +[`delfi`](https://github.com/mackelab/delfi) package. It was started as a fork of Conor +M. Durkan's `lfi`. 
`sbi` runs as a community project; development is coordinated at the +[mackelab](https://uni-tuebingen.de/en/research/core-research/cluster-of-excellence-machine-learning/research/research/cluster-research-groups/professorships/machine-learning-in-science/). See also [credits](https://github.com/mackelab/sbi/blob/master/docs/docs/credits.md). + + +## Support + +`sbi` has been supported by the German Federal Ministry of Education and Research (BMBF) through the project ADIMEM, FKZ 01IS18052 A-D). [ADIMEM](https://fit.uni-tuebingen.de/Project/Details?id=9199) is a collaborative project between the groups of Jakob Macke (Uni Tübingen), Philipp Berens (Uni Tübingen), Philipp Hennig (Uni Tübingen) and Marcel Oberlaender (caesar Bonn) which aims to develop inference methods for mechanistic models. + + +## License + +[Affero General Public License v3 (AGPLv3)](https://www.gnu.org/licenses/) + + +## Citation +If you use `sbi` consider citing the [sbi software paper](https://doi.org/10.21105/joss.02505), in addition to the original research articles describing the specifc sbi-algorithm(s) you are using: + +``` +@article{tejero-cantero2020sbi, + doi = {10.21105/joss.02505}, + url = {https://doi.org/10.21105/joss.02505}, + year = {2020}, + publisher = {The Open Journal}, + volume = {5}, + number = {52}, + pages = {2505}, + author = {Alvaro Tejero-Cantero and Jan Boelts and Michael Deistler and Jan-Matthis Lueckmann and Conor Durkan and Pedro J. Gonçalves and David S. Greenberg and Jakob H. 
Macke}, + title = {sbi: A toolkit for simulation-based inference}, + journal = {Journal of Open Source Software} +} +``` diff --git a/sbi/inference/base.py b/sbi/inference/base.py index 86897fecc..daae107f1 100644 --- a/sbi/inference/base.py +++ b/sbi/inference/base.py @@ -26,7 +26,7 @@ warn_on_invalid_x, warn_on_invalid_x_for_snpec_leakage, ) -from sbi.utils.sbiutils import get_simulations_since_round +from sbi.utils.sbiutils import get_simulations_indcies from sbi.utils.torchutils import check_if_prior_on_device, process_device from sbi.utils.user_input_checks import prepare_for_sbi @@ -128,7 +128,6 @@ def __init__( # Initialize roundwise (theta, x, prior_masks) for storage of parameters, # simulations and masks indicating if simulations came from prior. - self._theta_roundwise, self._x_roundwise, self._prior_masks = [], [], [] self._dataset = None self._num_sims_per_round = [] self._model_bank = [] @@ -162,9 +161,6 @@ def __init__( def get_simulations( self, starting_round: int = 0, - exclude_invalid_x: bool = True, - warn_on_invalid: bool = True, - warn_if_zscoring: Optional[bool] = True, ) -> Tuple[Tensor, Tensor, Tensor]: r"""Returns all $\theta$, $x$, and prior_masks from rounds >= `starting_round`. @@ -181,28 +177,22 @@ def get_simulations( Returns: Parameters, simulation outputs, prior masks. """ - theta = get_simulations_since_round( - self._theta_roundwise, self._data_round_index, starting_round - ) - x = get_simulations_since_round( - self._x_roundwise, self._data_round_index, starting_round - ) - prior_masks = get_simulations_since_round( - self._prior_masks, self._data_round_index, starting_round - ) + #This is a pretty clunky implementation but not sure this will be used much with + #new implementation of `get_dataloaders` + indicies = get_simulations_indcies(self._num_sims_per_round, self._data_round_index, starting_round) + theta,x,prior_masks = [],[],[] - # Check for NaNs in simulations. 
- is_valid_x, num_nans, num_infs = handle_invalid_x(x, exclude_invalid_x) - # Check for problematic z-scoring - if warn_if_zscoring: - warn_if_zscoring_changes_data(x[is_valid_x]) - if warn_on_invalid: - warn_on_invalid_x(num_nans, num_infs, exclude_invalid_x) - warn_on_invalid_x_for_snpec_leakage( - num_nans, num_infs, exclude_invalid_x, type(self).__name__, self._round - ) + for ind in indicies: + theta_cur,x_cur,prior_mask_cur = self._dataset[ind] + theta.append(theta_cur) + x.append(x_cur) + prior_masks.append(prior_mask_cur) + + theta = torch.stack(theta).squeeze() + x = torch.stack(x).squeeze() + prior_masks = torch.stack(prior_masks).squeeze() - return theta[is_valid_x], x[is_valid_x], prior_masks[is_valid_x] + return theta, x, prior_masks @abstractmethod def train( @@ -221,7 +211,7 @@ def train( ) -> NeuralPosterior: raise NotImplementedError - def get_dataloaders_all( + def get_dataloaders( self, starting_round: int = 0, training_batch_size: int = 50, @@ -243,11 +233,9 @@ def get_dataloaders_all( Tuple of dataloaders for training and validation. """ - - if starting_round == 0: - indices = torch.arange(sum(self._num_sims_per_round)) - else: - indices = torch.arange(sum(self._num_sims_per_round[:starting_round]), sum(self._num_sims_per_round)) + + #Generate indicies to use based on starting_round + indices = get_simulations_indcies(self._num_sims_per_round, self._data_round_index, starting_round) # Get total number of training examples. 
num_examples = len(indices) @@ -256,6 +244,7 @@ def get_dataloaders_all( num_validation_examples = num_examples - num_training_examples if not resume_training: + # Seperate indicies for training and validation permuted_indices = indices[torch.randperm(num_examples)] self.train_indices, self.val_indices = ( permuted_indices[:num_training_examples], @@ -281,68 +270,10 @@ def get_dataloaders_all( train_loader_kwargs = dict(train_loader_kwargs, **dataloader_kwargs) val_loader_kwargs = dict(val_loader_kwargs, **dataloader_kwargs) - return data.DataLoader(self._dataset, **train_loader_kwargs), data.DataLoader(self._dataset, **val_loader_kwargs) - - def get_dataloaders( - self, - dataset: data.TensorDataset, - training_batch_size: int = 50, - validation_fraction: float = 0.1, - resume_training: bool = False, - dataloader_kwargs: Optional[dict] = None, - ) -> Tuple[data.DataLoader, data.DataLoader]: - """Return dataloaders for training and validation. - - Args: - dataset: holding all theta and x, optionally masks. - training_batch_size: training arg of inference methods. - resume_training: Whether the current call is resuming training so that no - new training and validation indices into the dataset have to be created. - dataloader_kwargs: Additional or updated kwargs to be passed to the training - and validation dataloaders (like, e.g., a collate_fn). - - Returns: - Tuple of dataloaders for training and validation. - - """ - - # Get total number of training examples. - num_examples = len(dataset) - - # Select random train and validation splits from (theta, x) pairs. 
- num_training_examples = int((1 - validation_fraction) * num_examples) - num_validation_examples = num_examples - num_training_examples - - if not resume_training: - permuted_indices = torch.randperm(num_examples) - self.train_indices, self.val_indices = ( - permuted_indices[:num_training_examples], - permuted_indices[num_training_examples:], - ) - - # Create training and validation loaders using a subset sampler. - # Intentionally use dicts to define the default dataloader args - # Then, use dataloader_kwargs to override (or add to) any of these defaults - # https://stackoverflow.com/questions/44784577/in-method-call-args-how-to-override-keyword-argument-of-unpacked-dict - train_loader_kwargs = { - "batch_size": min(training_batch_size, num_training_examples), - "drop_last": True, - "sampler": SubsetRandomSampler(self.train_indices.tolist()), - } - val_loader_kwargs = { - "batch_size": min(training_batch_size, num_validation_examples), - "shuffle": False, - "drop_last": True, - "sampler": SubsetRandomSampler(self.val_indices.tolist()), - } - if dataloader_kwargs is not None: - train_loader_kwargs = dict(train_loader_kwargs, **dataloader_kwargs) - val_loader_kwargs = dict(val_loader_kwargs, **dataloader_kwargs) - - train_loader = data.DataLoader(dataset, **train_loader_kwargs) - val_loader = data.DataLoader(dataset, **val_loader_kwargs) + train_loader = data.DataLoader(self._dataset, **train_loader_kwargs) + val_loader = data.DataLoader(self._dataset, **val_loader_kwargs) - return train_loader, val_loader + return train_loader,val_loader def _converged(self, epoch: int, stop_after_epochs: int) -> bool: """Return whether the training converged yet and save best model state so far. 
diff --git a/sbi/inference/snle/snle_base.py b/sbi/inference/snle/snle_base.py index e9e7baace..33aefa8c4 100644 --- a/sbi/inference/snle/snle_base.py +++ b/sbi/inference/snle/snle_base.py @@ -22,7 +22,10 @@ check_estimator_arg, check_prior, mask_sims_from_prior, + handle_invalid_x, validate_theta_and_x, + warn_if_zscoring_changes_data, + warn_on_invalid_x, x_shape_from_simulation, ) @@ -83,6 +86,11 @@ def append_simulations( theta: Tensor, x: Tensor, from_round: int = 0, + exclude_invalid_x: bool = True, + warn_on_invalid: bool = True, + warn_if_zscoring: bool = True, + return_self: bool = True, + data_device: str = None, ) -> "LikelihoodEstimator": r"""Store parameters and simulation outputs to use them for later training. @@ -99,19 +107,46 @@ def append_simulations( With default settings, this is not used at all for `SNLE`. Only when the user later on requests `.train(discard_prior_samples=True)`, we use these indices to find which training data stemmed from the prior. - + exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞` + during training. Expect errors, silent or explicit, when `False`. + warn_on_invalid: Whether to warn if data is invalid + warn_if_zscoring: Whether to test if z-scoring causes duplicates + return_self: Whether to return a instance of the class, allows chaining + with `.train()`. Setting `False` decreases memory overhead. + data_device: Where to store the data, default is on the same device where + the training is happening. If training a large dataset on a GPU with not + much VRAM can set to 'cpu' to store data on system memory instead. Returns: NeuralInference object (returned so that this function is chainable). 
""" - theta, x = validate_theta_and_x(theta, x, training_device=self._device) + is_valid_x, num_nans, num_infs = handle_invalid_x(x, exclude_invalid_x) + + # Check for problematic z-scoring + if warn_if_zscoring: + warn_if_zscoring_changes_data(x[is_valid_x]) + if warn_on_invalid: + warn_on_invalid_x(num_nans, num_infs, exclude_invalid_x) + + x = x[is_valid_x] + theta = theta[is_valid_x] + + if data_device is None: data_device = self._device + theta, x = validate_theta_and_x(theta, x, training_device=data_device) + prior_masks = mask_sims_from_prior(int(from_round), theta.size(0)) - self._theta_roundwise.append(theta) - self._x_roundwise.append(x) - self._prior_masks.append(mask_sims_from_prior(int(from_round), theta.size(0))) - self._data_round_index.append(int(from_round)) + if self._dataset is None: + #If first round, set up ConcatDataset + self._dataset = data.ConcatDataset( [data.TensorDataset(theta,x,prior_masks),] ) + else: + #Otherwise append to Dataset + self._dataset = data.ConcatDataset( self._dataset.datasets + [data.TensorDataset(theta,x,prior_masks),] ) - return self + self._num_sims_per_round.append(theta.size(0)) + self._data_round_index.append(int(from_round) ) + + if return_self: + return self def train( self, @@ -121,7 +156,6 @@ def train( stop_after_epochs: int = 20, max_num_epochs: int = 2**31 - 1, clip_max_norm: Optional[float] = 5.0, - exclude_invalid_x: bool = True, resume_training: bool = False, discard_prior_samples: bool = False, retrain_from_scratch: bool = False, @@ -131,8 +165,6 @@ def train( r"""Train the density estimator to learn the distribution $p(x|\theta)$. Args: - exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞` - during training. Expect errors, silent or explicit, when `False`. resume_training: Can be used in case training time is limited, e.g. on a cluster. 
If `True`, the split between train and validation set, the optimizer, the number of epochs, and the best validation log-prob will @@ -155,15 +187,9 @@ def train( start_idx = int(discard_prior_samples and self._round > 0) # Load data from most recent round. self._round = max(self._data_round_index) - theta, x, _ = self.get_simulations( - start_idx, exclude_invalid_x, warn_on_invalid=True - ) - - # Dataset is shared for training and validation loaders. - dataset = data.TensorDataset(theta, x) train_loader, val_loader = self.get_dataloaders( - dataset, + start_idx, training_batch_size, validation_fraction, resume_training, @@ -176,10 +202,15 @@ def train( # This is passed into NeuralPosterior, to create a neural posterior which # can `sample()` and `log_prob()`. The network is accessible via `.net`. if self._neural_net is None or retrain_from_scratch: + + #Get theta,x from dataset to initialize NN + test_theta = self._dataset.datasets[0].tensors[0][:100] + test_x = self._dataset.datasets[0].tensors[1][:100] + self._neural_net = self._build_neural_net( - theta[self.train_indices], x[self.train_indices] + test_theta, test_x ) - self._x_shape = x_shape_from_simulation(x) + self._x_shape = x_shape_from_simulation(test_x) assert ( len(self._x_shape) < 3 ), "SNLE cannot handle multi-dimensional simulator output." @@ -257,8 +288,8 @@ def train( self._summarize( round_=self._round, x_o=None, - theta_bank=theta, - x_bank=x, + theta_bank=None, + x_bank=None, ) # Update description for progress bar. diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py index e0edefd12..664e62fff 100644 --- a/sbi/inference/snpe/snpe_a.py +++ b/sbi/inference/snpe/snpe_a.py @@ -1,862 +1,859 @@ -# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed -# under the Affero General Public License v3, see . 
- -import warnings -from copy import deepcopy -from functools import partial -from typing import Any, Callable, Dict, Optional, Union - -import torch -import torch.nn as nn -from pyknos.mdn.mdn import MultivariateGaussianMDN -from pyknos.nflows import flows -from pyknos.nflows.transforms import CompositeTransform -from torch import Tensor -from torch.distributions import Distribution, MultivariateNormal - -import sbi.utils as utils -from sbi.inference.posteriors.direct_posterior import DirectPosterior -from sbi.inference.snpe.snpe_base import PosteriorEstimator -from sbi.types import TensorboardSummaryWriter, TorchModule -from sbi.utils import torchutils - - -class SNPE_A(PosteriorEstimator): - def __init__( - self, - prior: Optional[Distribution] = None, - density_estimator: Union[str, Callable] = "mdn_snpe_a", - num_components: int = 10, - device: str = "cpu", - logging_level: Union[int, str] = "WARNING", - summary_writer: Optional[TensorboardSummaryWriter] = None, - show_progress_bars: bool = True, - ): - r"""SNPE-A [1]. - - [1] _Fast epsilon-free Inference of Simulation Models with Bayesian Conditional - Density Estimation_, Papamakarios et al., NeurIPS 2016, - https://arxiv.org/abs/1605.06376. - - This class implements SNPE-A. SNPE-A trains across multiple rounds with a - maximum-likelihood-loss. This will make training converge to the proposal - posterior instead of the true posterior. To correct for this, SNPE-A applies a - post-hoc correction after training. This correction has to be performed - analytically. Thus, SNPE-A is limited to Gaussian distributions for all but the - last round. In the last round, SNPE-A can use a Mixture of Gaussians. - - Args: - prior: A probability distribution that expresses prior knowledge about the - parameters, e.g. which ranges are meaningful for them. Any - object with `.log_prob()`and `.sample()` (for example, a PyTorch - distribution) can be used. 
- density_estimator: If it is a string (only "mdn_snpe_a" is valid), use a - pre-configured mixture of densities network. Alternatively, a function - that builds a custom neural network can be provided. The function will - be called with the first batch of simulations (theta, x), which can - thus be used for shape inference and potentially for z-scoring. It - needs to return a PyTorch `nn.Module` implementing the density - estimator. The density estimator needs to provide the methods - `.log_prob` and `.sample()`. Note that until the last round only a - single (multivariate) Gaussian component is used for training (see - Algorithm 1 in [1]). In the last round, this component is replicated - `num_components` times, its parameters are perturbed with a very small - noise, and then the last training round is done with the expanded - Gaussian mixture as estimator for the proposal posterior. - num_components: Number of components of the mixture of Gaussians in the - last round. This overrides the `num_components` value passed to - `posterior_nn()`. - device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}". - logging_level: Minimum severity of messages to log. One of the strings - INFO, WARNING, DEBUG, ERROR and CRITICAL. - summary_writer: A tensorboard `SummaryWriter` to control, among others, log - file location (default is `/logs`.) - show_progress_bars: Whether to show a progressbar during training. - """ - - # Catch invalid inputs. - if not ((density_estimator == "mdn_snpe_a") or callable(density_estimator)): - raise TypeError( - "The `density_estimator` passed to SNPE_A needs to be a " - "callable or the string 'mdn_snpe_a'!" - ) - - # `num_components` will be used to replicate the Gaussian in the last round. - self._num_components = num_components - self._ran_final_round = False - - # WARNING: sneaky trick ahead. We proxy the parent's `train` here, - # requiring the signature to have `num_atoms`, save it for use below, and - # continue. 
It's sneaky because we are using the object (self) as a namespace - # to pass arguments between functions, and that's implicit state management. - kwargs = utils.del_entries( - locals(), - entries=("self", "__class__", "num_components"), - ) - super().__init__(**kwargs) - - def train( - self, - final_round: bool = False, - training_batch_size: int = 50, - learning_rate: float = 5e-4, - validation_fraction: float = 0.1, - stop_after_epochs: int = 20, - max_num_epochs: int = 2**31 - 1, - clip_max_norm: Optional[float] = 5.0, - calibration_kernel: Optional[Callable] = None, - exclude_invalid_x: bool = True, - resume_training: bool = False, - force_first_round_loss: bool = False, - retrain_from_scratch: bool = False, - show_train_summary: bool = False, - dataloader_kwargs: Optional[Dict] = None, - component_perturbation: float = 5e-3, - ) -> nn.Module: - r"""Return density estimator that approximates the proposal posterior. - - [1] _Fast epsilon-free Inference of Simulation Models with Bayesian Conditional - Density Estimation_, Papamakarios et al., NeurIPS 2016, - https://arxiv.org/abs/1605.06376. - - Training is performed with maximum likelihood on samples from the latest round, - which leads the algorithm to converge to the proposal posterior. - - Args: - final_round: Whether we are in the last round of training or not. For all - but the last round, Algorithm 1 from [1] is executed. In last the - round, Algorithm 2 from [1] is executed once. - training_batch_size: Training batch size. - learning_rate: Learning rate for Adam optimizer. - validation_fraction: The fraction of data to use for validation. - stop_after_epochs: The number of epochs to wait for improvement on the - validation set before terminating training. - max_num_epochs: Maximum number of epochs to run. If reached, we stop - training even when the validation loss is still decreasing. Otherwise, - we train until validation loss increases (see also `stop_after_epochs`). 
- clip_max_norm: Value at which to clip the total gradient norm in order to - prevent exploding gradients. Use None for no clipping. - calibration_kernel: A function to calibrate the loss with respect to the - simulations `x`. See Lueckmann, Gonçalves et al., NeurIPS 2017. - exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞` - during training. Expect errors, silent or explicit, when `False`. - resume_training: Can be used in case training time is limited, e.g. on a - cluster. If `True`, the split between train and validation set, the - optimizer, the number of epochs, and the best validation log-prob will - be restored from the last time `.train()` was called. - force_first_round_loss: If `True`, train with maximum likelihood, - i.e., potentially ignoring the correction for using a proposal - distribution different from the prior. - force_first_round_loss: If `True`, train with maximum likelihood, - regardless of the proposal distribution. - retrain_from_scratch: Whether to retrain the conditional density - estimator for the posterior from scratch each round. Not supported for - SNPE-A. - show_train_summary: Whether to print the number of epochs and validation - loss and leakage after the training. - dataloader_kwargs: Additional or updated kwargs to be passed to the training - and validation dataloaders (like, e.g., a collate_fn) - component_perturbation: The standard deviation applied to all weights and - biases when, in the last round, the Mixture of Gaussians is build from - a single Gaussian. This value can be problem-specific and also depends - on the number of mixture components. - - Returns: - Density estimator that approximates the distribution $p(\theta|x)$. - """ - - assert not retrain_from_scratch, """Retraining from scratch is not supported in SNPE-A yet. The reason for - this is that, if we reininitialized the density estimator, the z-scoring would - change, which would break the posthoc correction. 
This is a pure implementation - issue.""" - - kwargs = utils.del_entries( - locals(), - entries=("self", "__class__", "final_round", "component_perturbation"), - ) - - # SNPE-A always discards the prior samples. - kwargs["discard_prior_samples"] = True - - self._round = max(self._data_round_index) - - if final_round: - # If there is (will be) only one round, train with Algorithm 2 from [1]. - if self._round == 0: - self._build_neural_net = partial( - self._build_neural_net, num_components=self._num_components - ) - # Run Algorithm 2 from [1]. - elif not self._ran_final_round: - # Now switch to the specified number of components. This method will - # only be used if `retrain_from_scratch=True`. Otherwise, - # the MDN will be built from replicating the single-component net for - # `num_component` times (via `_expand_mog()`). - self._build_neural_net = partial( - self._build_neural_net, num_components=self._num_components - ) - - # Extend the MDN to the originally desired number of components. - self._expand_mog(eps=component_perturbation) - else: - warnings.warn( - "You have already run SNPE-A with `final_round=True`. Running it" - "again with this setting will not allow computing the posthoc" - "correction applied in SNPE-A. Thus, you will get an error when " - "calling `.build_posterior()` after training.", - UserWarning, - ) - else: - # Run Algorithm 1 from [1]. - # Wrap the function that builds the MDN such that we can make - # sure that there is only one component when running. - self._build_neural_net = partial(self._build_neural_net, num_components=1) - - if final_round: - self._ran_final_round = True - - return super().train(**kwargs) - - def correct_for_proposal( - self, - density_estimator: Optional[TorchModule] = None, - ) -> "SNPE_A_MDN": - r"""Build mixture of Gaussians that approximates the posterior. - - Returns a `SNPE_A_MDN` object, which applies the posthoc-correction required in - SNPE-A. 
- - Args: - density_estimator: The density estimator that the posterior is based on. - If `None`, use the latest neural density estimator that was trained. - - Returns: - Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods. - """ - if density_estimator is None: - density_estimator = deepcopy( - self._neural_net - ) # PosteriorEstimator.train() also returns a deepcopy, mimic this here - # If internal net is used device is defined. - device = self._device - else: - # Otherwise, infer it from the device of the net parameters. - device = str(next(density_estimator.parameters()).device) - - # Set proposal of the density estimator. - # This also evokes the z-scoring correction if necessary. - if ( - self._proposal_roundwise[-1] is self._prior - or self._proposal_roundwise[-1] is None - ): - proposal = self._prior - assert isinstance( - proposal, (MultivariateNormal, utils.BoxUniform) - ), """Prior must be `torch.distributions.MultivariateNormal` or `sbi.utils. - BoxUniform`""" - else: - assert isinstance( - self._proposal_roundwise[-1], DirectPosterior - ), """The proposal you passed to `append_simulations` is neither the prior - nor a `DirectPosterior`. SNPE-A currently only supports these scenarios. - """ - proposal = self._proposal_roundwise[-1] - - # Create the SNPE_A_MDN - wrapped_density_estimator = SNPE_A_MDN( - flow=density_estimator, proposal=proposal, prior=self._prior, device=device - ) - return wrapped_density_estimator - - def build_posterior( - self, - density_estimator: Optional[TorchModule] = None, - prior: Optional[Distribution] = None, - ) -> "DirectPosterior": - r"""Build posterior from the neural density estimator. - - This method first corrects the estimated density with `correct_for_proposal` - and then returns a `DirectPosterior`. - - Args: - density_estimator: The density estimator that the posterior is based on. - If `None`, use the latest neural density estimator that was trained. - prior: Prior distribution. 
- - Returns: - Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods. - """ - if prior is None: - assert ( - self._prior is not None - ), """You did not pass a prior. You have to pass the prior either at - initialization `inference = SNPE_A(prior)` or to `.build_posterior - (prior=prior)`.""" - prior = self._prior - - wrapped_density_estimator = self.correct_for_proposal( - density_estimator=density_estimator - ) - self._posterior = DirectPosterior( - posterior_estimator=wrapped_density_estimator, - prior=prior, - ) - return deepcopy(self._posterior) - - def _log_prob_proposal_posterior( - self, - theta: Tensor, - x: Tensor, - masks: Tensor, - proposal: Optional[Any], - ) -> Tensor: - """Return the log-probability of the proposal posterior. - - For SNPE-A this is the same as `self._neural_net.log_prob(theta, x)` in - `_loss()` to be found in `snpe_base.py`. - - Args: - theta: Batch of parameters θ. - x: Batch of data. - masks: Mask that is True for prior samples in the batch in order to train - them with prior loss. - proposal: Proposal distribution. - - Returns: Log-probability of the proposal posterior. - """ - return self._neural_net.log_prob(theta, x) - - def _expand_mog(self, eps: float = 1e-5): - """ - Replicate a singe Gaussian trained with Algorithm 1 before continuing - with Algorithm 2. The weights and biases of the associated MDN layers - are repeated `num_components` times, slightly perturbed to break the - symmetry such that the gradients in the subsequent training are not - all identical. - - Args: - eps: Standard deviation for the random perturbation. - """ - assert isinstance(self._neural_net._distribution, MultivariateGaussianMDN) - - # Increase the number of components - self._neural_net._distribution._num_components = self._num_components - - # Expand the 1-dim Gaussian. 
- for name, param in self._neural_net.named_parameters(): - if any( - key in name for key in ["logits", "means", "unconstrained", "upper"] - ): - if "bias" in name: - param.data = param.data.repeat(self._num_components) - param.data.add_(torch.randn_like(param.data) * eps) - param.grad = None # let autograd construct a new gradient - elif "weight" in name: - param.data = param.data.repeat(self._num_components, 1) - param.data.add_(torch.randn_like(param.data) * eps) - param.grad = None # let autograd construct a new gradient - - -class SNPE_A_MDN(nn.Module): - """Generates a posthoc-corrected MDN which approximates the posterior. - - This class takes as input the density estimator (abbreviated with `_d` suffix, aka - the proposal posterior) and the proposal prior (abbreviated with `_pp` suffix) from - which the simulations were drawn. It uses the algorithm presented in SNPE-A [1] to - compute the approximate posterior (abbreviated with `_p` suffix) from the two. The - approximate posterior is a MoG. This class also implements log-prob calculation - sampling from the approximate posterior. It inherits from `nn.Module` since the - constructor of `DirectPosterior` expects the argument `neural_net` to be a - `nn.Module`. - - [1] _Fast epsilon-free Inference of Simulation Models with Bayesian Conditional - Density Estimation_, Papamakarios et al., NeurIPS 2016, - https://arxiv.org/abs/1605.06376. - """ - - def __init__( - self, - flow: flows.Flow, - proposal: Union["utils.BoxUniform", "MultivariateNormal", "DirectPosterior"], - prior: Distribution, - device: str, - ): - """Constructor. - - Args: - flow: The trained normalizing flow, passed when building the posterior. - proposal: The proposal distribution. - prior: The prior distribution. - """ - # Call nn.Module's constructor. - super().__init__() - - self._neural_net = flow - self._prior = prior - self._device = device - - # Set the proposal using the `default_x`. 
- if isinstance(proposal, (utils.BoxUniform, MultivariateNormal)): - self._apply_correction = False - else: - self._apply_correction = True - logits_pp, m_pp, prec_pp = proposal.posterior_estimator._posthoc_correction( - proposal.default_x - ) - self._logits_pp, self._m_pp, self._prec_pp = ( - logits_pp.detach(), - m_pp.detach(), - prec_pp.detach(), - ) - - # Take care of z-scoring, pre-compute and store prior terms. - self._set_state_for_mog_proposal() - - def log_prob(self, inputs: Tensor, context: Tensor) -> Tensor: - inputs, context = inputs.to(self._device), context.to(self._device) - - if not self._apply_correction: - return self._neural_net.log_prob(inputs, context) - else: - # When we want to compute the approx. posterior, a proposal prior \tilde{p} - # has already been observed. To analytically calculate the log-prob of the - # Gaussian, we first need to compute the mixture components. - - # Compute the mixture components of the proposal posterior. - logits_pp, m_pp, prec_pp = self._posthoc_correction(context) - - # z-score theta if it z-scoring had been requested. - theta = self._maybe_z_score_theta(inputs) - - # Compute the log_prob of theta under the product. - log_prob_proposal_posterior = utils.mog_log_prob( - theta, logits_pp, m_pp, prec_pp - ) - utils.assert_all_finite( - log_prob_proposal_posterior, "proposal posterior eval" - ) - return log_prob_proposal_posterior # \hat{p} from eq (3) in [1] - - def sample(self, num_samples: int, context: Tensor, batch_size: int = 1) -> Tensor: - context = context.to(self._device) - - if not self._apply_correction: - return self._neural_net.sample(num_samples, context, batch_size) - else: - # When we want to sample from the approx. posterior, a proposal prior - # \tilde{p} has already been observed. To analytically calculate the - # log-prob of the Gaussian, we first need to compute the mixture components. 
- return self._sample_approx_posterior_mog(num_samples, context, batch_size) - - def _sample_approx_posterior_mog( - self, num_samples, x: Tensor, batch_size: int - ) -> Tensor: - r"""Sample from the approximate posterior. - - Args: - num_samples: Desired number of samples. - x: Conditioning context for posterior $p(\theta|x)$. - batch_size: Batch size for sampling. - - Returns: - Samples from the approximate mixture of Gaussians posterior. - """ - - # Compute the mixture components of the posterior. - logits_p, m_p, prec_p = self._posthoc_correction(x) - - # Compute the precision factors which represent the upper triangular matrix - # of the cholesky decomposition of the prec_p. - prec_factors_p = torch.linalg.cholesky(prec_p, upper=True) - - assert logits_p.ndim == 2 - assert m_p.ndim == 3 - assert prec_p.ndim == 4 - assert prec_factors_p.ndim == 4 - - # Replicate to use batched sampling from pyknos. - if batch_size is not None and batch_size > 1: - logits_p = logits_p.repeat(batch_size, 1) - m_p = m_p.repeat(batch_size, 1, 1) - prec_factors_p = prec_factors_p.repeat(batch_size, 1, 1, 1) - - # Get (optionally z-scored) MoG samples. - theta = MultivariateGaussianMDN.sample_mog( - num_samples, logits_p, m_p, prec_factors_p - ) - - embedded_context = self._neural_net._embedding_net(x) - if embedded_context is not None: - # Merge the context dimension with sample dimension in order to - # apply the transform. - theta = torchutils.merge_leading_dims(theta, num_dims=2) - embedded_context = torchutils.repeat_rows( - embedded_context, num_reps=num_samples - ) - - theta, _ = self._neural_net._transform.inverse(theta, context=embedded_context) - - if embedded_context is not None: - # Split the context dimension from sample dimension. 
- theta = torchutils.split_leading_dim(theta, shape=[-1, num_samples]) - - return theta - - def _posthoc_correction(self, x: Tensor): - """ - Compute the mixture components of the posterior given the current density - estimator and the proposal. - - Args: - x: Conditioning context for posterior. - - Returns: - Mixture components of the posterior. - """ - - # Evaluate the density estimator. - encoded_x = self._neural_net._embedding_net(x) - dist = self._neural_net._distribution # defined to avoid black formatting. - logits_d, m_d, prec_d, _, _ = dist.get_mixture_components(encoded_x) - norm_logits_d = logits_d - torch.logsumexp(logits_d, dim=-1, keepdim=True) - - # The following if case is needed because, in the constructor, we call - # `_posthoc_correction` regardless of whether the `proposal` itself had a - # `proposal` or not. - if not self._apply_correction: - return norm_logits_d, m_d, prec_d - else: - logits_pp, m_pp, prec_pp = self._logits_pp, self._m_pp, self._prec_pp - - # Compute the MoG parameters of the posterior. - logits_p, m_p, prec_p, cov_p = self._proposal_posterior_transformation( - logits_pp, m_pp, prec_pp, norm_logits_d, m_d, prec_d - ) - return logits_p, m_p, prec_p - - def _proposal_posterior_transformation( - self, - logits_pp: Tensor, - means_pp: Tensor, - precisions_pp: Tensor, - logits_d: Tensor, - means_d: Tensor, - precisions_d: Tensor, - ): - r"""Transforms the proposal posterior (the MDN) into the posterior. - - The approximate posterior is: - $p(\theta|x) = 1/Z * q(\theta|x) * p(\theta) / prop(\theta)$ - In words: posterior = proposal posterior estimate * prior / proposal. - - Since the proposal posterior estimate and the proposal are MoG, and the - prior is either Gaussian or uniform, we can solve this in closed-form. - - This function implements Appendix C from [1], and is highly similar to - `SNPE_C._automatic_posterior_transformation()`. - - Args: - logits_pp: Component weight of each Gaussian of the proposal prior. 
- means_pp: Mean of each Gaussian of the proposal prior. - precisions_pp: Precision matrix of each Gaussian of the proposal prior. - logits_d: Component weight for each Gaussian of the density estimator. - means_d: Mean of each Gaussian of the density estimator. - precisions_d: Precision matrix of each Gaussian of the density estimator. - - Returns: (Component weight, mean, precision matrix, covariance matrix) of each - Gaussian of the approximate posterior. - """ - - precisions_post, covariances_post = self._precisions_posterior( - precisions_pp, precisions_d - ) - - means_post = self._means_posterior( - covariances_post, means_pp, precisions_pp, means_d, precisions_d - ) - - logits_post = SNPE_A_MDN._logits_posterior( - means_post, - precisions_post, - covariances_post, - logits_pp, - means_pp, - precisions_pp, - logits_d, - means_d, - precisions_d, - ) - - return logits_post, means_post, precisions_post, covariances_post - - def _set_state_for_mog_proposal(self) -> None: - """ - Set state variables of the SNPE_A_MDN instance every time `set_proposal()` - is called, i.e. every time a posterior is build using - `SNPE_A.build_posterior()`. - - This function is almost identical to `SNPE_C._set_state_for_mog_proposal()`. - - Three things are computed: - 1) Check if z-scoring was requested. To do so, we check if the `_transform` - argument of the net had been a `CompositeTransform`. See pyknos mdn.py. - 2) Define a (potentially standardized) prior. It's standardized if z-scoring - had been requested. - 3) Compute (Precision * mean) for the prior. This quantity is used at every - training step if the prior is Gaussian. 
- """ - - self.z_score_theta = isinstance(self._neural_net._transform, CompositeTransform) - - self._set_maybe_z_scored_prior() - - if isinstance(self._maybe_z_scored_prior, MultivariateNormal): - self.prec_m_prod_prior = torch.mv( - self._maybe_z_scored_prior.precision_matrix, # type: ignore - self._maybe_z_scored_prior.loc, # type: ignore - ) - - def _set_maybe_z_scored_prior(self) -> None: - r""" - Compute and store potentially standardized prior (if z-scoring was requested). - - This function is highly similar to `SNPE_C._set_maybe_z_scored_prior()`. - - The proposal posterior is: - $p(\theta|x) = 1/Z * q(\theta|x) * prop(\theta) / p(\theta)$ - - Let's denote z-scored theta by `a`: a = (theta - mean) / std - Then $p'(a|x) = 1/Z_2 * q'(a|x) * prop'(a) / p'(a)$ - - The ' indicates that the evaluation occurs in standardized space. The constant - scaling factor has been absorbed into $Z_2$. - From the above equation, we see that we need to evaluate the prior **in - standardized space**. We build the standardized prior in this function. - - The standardize transform that is applied to the samples theta does not use - the exact prior mean and std (due to implementation issues). Hence, the z-scored - prior will not be exactly have mean=0 and std=1. - """ - if self.z_score_theta: - scale = self._neural_net._transform._transforms[0]._scale - shift = self._neural_net._transform._transforms[0]._shift - - # Following the definition of the linear transform in - # `standardizing_transform` in `sbiutils.py`: - # shift=-mean / std - # scale=1 / std - # Solving these equations for mean and std: - estim_prior_std = 1 / scale - estim_prior_mean = -shift * estim_prior_std - - # Compute the discrepancy of the true prior mean and std and the mean and - # std that was empirically estimated from samples. - # N(theta|m,s) = N((theta-m_e)/s_e|(m-m_e)/s_e, s/s_e) - # Above: m,s are true prior mean and std. 
m_e,s_e are estimated prior mean - # and std (estimated from samples and used to build standardize transform). - almost_zero_mean = (self._prior.mean - estim_prior_mean) / estim_prior_std - almost_one_std = torch.sqrt(self._prior.variance) / estim_prior_std - - if isinstance(self._prior, MultivariateNormal): - self._maybe_z_scored_prior = MultivariateNormal( - almost_zero_mean, torch.diag(almost_one_std) - ) - else: - range_ = torch.sqrt(almost_one_std * 3.0) - self._maybe_z_scored_prior = utils.BoxUniform( - almost_zero_mean - range_, almost_zero_mean + range_ - ) - else: - self._maybe_z_scored_prior = self._prior - - def _maybe_z_score_theta(self, theta: Tensor) -> Tensor: - """Return potentially standardized theta if z-scoring was requested.""" - - if self.z_score_theta: - theta, _ = self._neural_net._transform(theta) - - return theta - - def _precisions_posterior(self, precisions_pp: Tensor, precisions_d: Tensor): - r"""Return the precisions and covariances of the MoG posterior. - - As described at the end of Appendix C in [1], it can happen that the - proposal's precision matrix is not positive definite. - - $S_k^\prime = ( S_k^{-1} - S_0^{-1} )^{-1}$ - (see eq (23) in Appendix C of [1]) - - Args: - precisions_pp: Precision matrices of the proposal prior. - precisions_d: Precision matrices of the density estimator. - - Returns: (Precisions, Covariances) of the MoG posterior. - """ - - num_comps_p = precisions_pp.shape[1] - num_comps_d = precisions_d.shape[1] - - # Check if precision matrices are positive definite. - for batches in precisions_pp: - for pprior in batches: - eig_pprior = torch.linalg.eigvalsh(pprior, UPLO="U") - if not (eig_pprior > 0).all(): - raise AssertionError( - "The precision matrix of the proposal is not positive definite!" 
- ) - for batches in precisions_d: - for d in batches: - eig_d = torch.linalg.eigvalsh(d, UPLO="U") - if not (eig_d > 0).all(): - raise AssertionError( - "The precision matrix of the density estimator is not " - "positive definite!" - ) - - precisions_pp_rep = precisions_pp.repeat_interleave(num_comps_d, dim=1) - precisions_d_rep = precisions_d.repeat(1, num_comps_p, 1, 1) - - precisions_p = precisions_d_rep - precisions_pp_rep - if isinstance(self._maybe_z_scored_prior, MultivariateNormal): - precisions_p += self._maybe_z_scored_prior.precision_matrix - - # Check if precision matrix is positive definite. - for idx_batch, batches in enumerate(precisions_p): - for idx_comp, pp in enumerate(batches): - eig_pp = torch.symeig(pp, eigenvectors=False).eigenvalues - if not (eig_pp > 0).all(): - raise AssertionError( - "The precision matrix of a posterior is not positive " - "definite! This is a known issue for SNPE-A. Either try a " - "different parameter setting, e.g. a different number of " - "mixture components (when contracting SNPE-A), or a different " - "value for the parameter perturbation (when building the " - "posterior)." - ) - - covariances_p = torch.inverse(precisions_p) - return precisions_p, covariances_p - - def _means_posterior( - self, - covariances_p: Tensor, - means_pp: Tensor, - precisions_pp: Tensor, - means_d: Tensor, - precisions_d: Tensor, - ): - r"""Return the means of the MoG posterior. - - $m_k^\prime = S_k^\prime ( S_k^{-1} m_k - S_0^{-1} m_0 )$ - (see eq (24) in Appendix C of [1]) - - Args: - covariances_post: Covariance matrices of the MoG posterior. - means_pp: Means of the proposal prior. - precisions_pp: Precision matrices of the proposal prior. - means_d: Means of the density estimator. - precisions_d: Precision matrices of the density estimator. - - Returns: Means of the MoG posterior. - """ - - num_comps_pp = precisions_pp.shape[1] - num_comps_d = precisions_d.shape[1] - - # Compute the products P_k * m_k and P_0 * m_0. 
- prec_m_prod_pp = utils.batched_mixture_mv(precisions_pp, means_pp) - prec_m_prod_d = utils.batched_mixture_mv(precisions_d, means_d) - - # Repeat them to allow for matrix operations: same trick as for the precisions. - prec_m_prod_pp_rep = prec_m_prod_pp.repeat_interleave(num_comps_d, dim=1) - prec_m_prod_d_rep = prec_m_prod_d.repeat(1, num_comps_pp, 1) - - # Compute the means P_k^prime * (P_k * m_k - P_0 * m_0). - summed_cov_m_prod_rep = prec_m_prod_d_rep - prec_m_prod_pp_rep - if isinstance(self._maybe_z_scored_prior, MultivariateNormal): - summed_cov_m_prod_rep += self.prec_m_prod_prior - - means_p = utils.batched_mixture_mv(covariances_p, summed_cov_m_prod_rep) - return means_p - - @staticmethod - def _logits_posterior( - means_post: Tensor, - precisions_post: Tensor, - covariances_post: Tensor, - logits_pp: Tensor, - means_pp: Tensor, - precisions_pp: Tensor, - logits_d: Tensor, - means_d: Tensor, - precisions_d: Tensor, - ): - r"""Return the component weights (i.e. logits) of the MoG posterior. - - $\alpha_k^\prime = \frac{ \alpha_k exp(-0.5 c_k) }{ \sum{j} \alpha_j exp(-0.5 - c_j) } $ - with - $c_k = logdet(S_k) - logdet(S_0) - logdet(S_k^\prime) + - + m_k^T P_k m_k - m_0^T P_0 m_0 - m_k^\prime^T P_k^\prime m_k^\prime$ - (see eqs. (25, 26) in Appendix C of [1]) - - Args: - means_post: Means of the posterior. - precisions_post: Precision matrices of the posterior. - covariances_post: Covariance matrices of the posterior. - logits_pp: Component weights (i.e. logits) of the proposal prior. - means_pp: Means of the proposal prior. - precisions_pp: Precision matrices of the proposal prior. - logits_d: Component weights (i.e. logits) of the density estimator. - means_d: Means of the density estimator. - precisions_d: Precision matrices of the density estimator. - - Returns: Component weights of the proposal posterior. 
- """ - - num_comps_pp = precisions_pp.shape[1] - num_comps_d = precisions_d.shape[1] - - # Compute the ratio of the logits similar to eq (10) in Appendix A.1 of [2] - logits_pp_rep = logits_pp.repeat_interleave(num_comps_d, dim=1) - logits_d_rep = logits_d.repeat(1, num_comps_pp) - logit_factors = logits_d_rep - logits_pp_rep - - # Compute the log-determinants - logdet_covariances_post = torch.logdet(covariances_post) - logdet_covariances_pp = -torch.logdet(precisions_pp) - logdet_covariances_d = -torch.logdet(precisions_d) - - # Repeat the proposal and density estimator terms such that there are LK terms. - # Same trick as has been used above. - logdet_covariances_pp_rep = logdet_covariances_pp.repeat_interleave( - num_comps_d, dim=1 - ) - logdet_covariances_d_rep = logdet_covariances_d.repeat(1, num_comps_pp) - - log_sqrt_det_ratio = 0.5 * ( # similar to eq (14) in Appendix A.1 of [2] - logdet_covariances_post - + logdet_covariances_pp_rep - - logdet_covariances_d_rep - ) - - # Compute for proposal, density estimator, and proposal posterior: - exponent_pp = utils.batched_mixture_vmv( - precisions_pp, means_pp # m_0 in eq (26) in Appendix C of [1] - ) - exponent_d = utils.batched_mixture_vmv( - precisions_d, means_d # m_k in eq (26) in Appendix C of [1] - ) - exponent_post = utils.batched_mixture_vmv( - precisions_post, means_post # m_k^\prime in eq (26) in Appendix C of [1] - ) - - # Extend proposal and density estimator exponents to get LK terms. - exponent_pp_rep = exponent_pp.repeat_interleave(num_comps_d, dim=1) - exponent_d_rep = exponent_d.repeat(1, num_comps_pp) - exponent = -0.5 * ( - exponent_d_rep - exponent_pp_rep - exponent_post # eq (26) in [1] - ) - - logits_post = logit_factors + log_sqrt_det_ratio + exponent - return logits_post +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Affero General Public License v3, see . 
+ +import warnings +from copy import deepcopy +from functools import partial +from typing import Any, Callable, Dict, Optional, Union + +import torch +import torch.nn as nn +from pyknos.mdn.mdn import MultivariateGaussianMDN +from pyknos.nflows import flows +from pyknos.nflows.transforms import CompositeTransform +from torch import Tensor +from torch.distributions import Distribution, MultivariateNormal + +import sbi.utils as utils +from sbi.inference.posteriors.direct_posterior import DirectPosterior +from sbi.inference.snpe.snpe_base import PosteriorEstimator +from sbi.types import TensorboardSummaryWriter, TorchModule +from sbi.utils import torchutils + + +class SNPE_A(PosteriorEstimator): + def __init__( + self, + prior: Optional[Distribution] = None, + density_estimator: Union[str, Callable] = "mdn_snpe_a", + num_components: int = 10, + device: str = "cpu", + logging_level: Union[int, str] = "WARNING", + summary_writer: Optional[TensorboardSummaryWriter] = None, + show_progress_bars: bool = True, + ): + r"""SNPE-A [1]. + + [1] _Fast epsilon-free Inference of Simulation Models with Bayesian Conditional + Density Estimation_, Papamakarios et al., NeurIPS 2016, + https://arxiv.org/abs/1605.06376. + + This class implements SNPE-A. SNPE-A trains across multiple rounds with a + maximum-likelihood-loss. This will make training converge to the proposal + posterior instead of the true posterior. To correct for this, SNPE-A applies a + post-hoc correction after training. This correction has to be performed + analytically. Thus, SNPE-A is limited to Gaussian distributions for all but the + last round. In the last round, SNPE-A can use a Mixture of Gaussians. + + Args: + prior: A probability distribution that expresses prior knowledge about the + parameters, e.g. which ranges are meaningful for them. Any + object with `.log_prob()`and `.sample()` (for example, a PyTorch + distribution) can be used. 
+ density_estimator: If it is a string (only "mdn_snpe_a" is valid), use a + pre-configured mixture of densities network. Alternatively, a function + that builds a custom neural network can be provided. The function will + be called with the first batch of simulations (theta, x), which can + thus be used for shape inference and potentially for z-scoring. It + needs to return a PyTorch `nn.Module` implementing the density + estimator. The density estimator needs to provide the methods + `.log_prob` and `.sample()`. Note that until the last round only a + single (multivariate) Gaussian component is used for training (see + Algorithm 1 in [1]). In the last round, this component is replicated + `num_components` times, its parameters are perturbed with a very small + noise, and then the last training round is done with the expanded + Gaussian mixture as estimator for the proposal posterior. + num_components: Number of components of the mixture of Gaussians in the + last round. This overrides the `num_components` value passed to + `posterior_nn()`. + device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}". + logging_level: Minimum severity of messages to log. One of the strings + INFO, WARNING, DEBUG, ERROR and CRITICAL. + summary_writer: A tensorboard `SummaryWriter` to control, among others, log + file location (default is `/logs`.) + show_progress_bars: Whether to show a progressbar during training. + """ + + # Catch invalid inputs. + if not ((density_estimator == "mdn_snpe_a") or callable(density_estimator)): + raise TypeError( + "The `density_estimator` passed to SNPE_A needs to be a " + "callable or the string 'mdn_snpe_a'!" + ) + + # `num_components` will be used to replicate the Gaussian in the last round. + self._num_components = num_components + self._ran_final_round = False + + # WARNING: sneaky trick ahead. We proxy the parent's `train` here, + # requiring the signature to have `num_atoms`, save it for use below, and + # continue. 
It's sneaky because we are using the object (self) as a namespace + # to pass arguments between functions, and that's implicit state management. + kwargs = utils.del_entries( + locals(), + entries=("self", "__class__", "num_components"), + ) + super().__init__(**kwargs) + + def train( + self, + final_round: bool = False, + training_batch_size: int = 50, + learning_rate: float = 5e-4, + validation_fraction: float = 0.1, + stop_after_epochs: int = 20, + max_num_epochs: int = 2**31 - 1, + clip_max_norm: Optional[float] = 5.0, + calibration_kernel: Optional[Callable] = None, + resume_training: bool = False, + force_first_round_loss: bool = False, + retrain_from_scratch: bool = False, + show_train_summary: bool = False, + dataloader_kwargs: Optional[Dict] = None, + component_perturbation: float = 5e-3, + ) -> nn.Module: + r"""Return density estimator that approximates the proposal posterior. + + [1] _Fast epsilon-free Inference of Simulation Models with Bayesian Conditional + Density Estimation_, Papamakarios et al., NeurIPS 2016, + https://arxiv.org/abs/1605.06376. + + Training is performed with maximum likelihood on samples from the latest round, + which leads the algorithm to converge to the proposal posterior. + + Args: + final_round: Whether we are in the last round of training or not. For all + but the last round, Algorithm 1 from [1] is executed. In last the + round, Algorithm 2 from [1] is executed once. + training_batch_size: Training batch size. + learning_rate: Learning rate for Adam optimizer. + validation_fraction: The fraction of data to use for validation. + stop_after_epochs: The number of epochs to wait for improvement on the + validation set before terminating training. + max_num_epochs: Maximum number of epochs to run. If reached, we stop + training even when the validation loss is still decreasing. Otherwise, + we train until validation loss increases (see also `stop_after_epochs`). 
+ clip_max_norm: Value at which to clip the total gradient norm in order to + prevent exploding gradients. Use None for no clipping. + calibration_kernel: A function to calibrate the loss with respect to the + simulations `x`. See Lueckmann, Gonçalves et al., NeurIPS 2017. + resume_training: Can be used in case training time is limited, e.g. on a + cluster. If `True`, the split between train and validation set, the + optimizer, the number of epochs, and the best validation log-prob will + be restored from the last time `.train()` was called. + force_first_round_loss: If `True`, train with maximum likelihood, + i.e., potentially ignoring the correction for using a proposal + distribution different from the prior. + force_first_round_loss: If `True`, train with maximum likelihood, + regardless of the proposal distribution. + retrain_from_scratch: Whether to retrain the conditional density + estimator for the posterior from scratch each round. Not supported for + SNPE-A. + show_train_summary: Whether to print the number of epochs and validation + loss and leakage after the training. + dataloader_kwargs: Additional or updated kwargs to be passed to the training + and validation dataloaders (like, e.g., a collate_fn) + component_perturbation: The standard deviation applied to all weights and + biases when, in the last round, the Mixture of Gaussians is build from + a single Gaussian. This value can be problem-specific and also depends + on the number of mixture components. + + Returns: + Density estimator that approximates the distribution $p(\theta|x)$. + """ + + assert not retrain_from_scratch, """Retraining from scratch is not supported in SNPE-A yet. The reason for + this is that, if we reininitialized the density estimator, the z-scoring would + change, which would break the posthoc correction. 
This is a pure implementation + issue.""" + + kwargs = utils.del_entries( + locals(), + entries=("self", "__class__", "final_round", "component_perturbation"), + ) + + # SNPE-A always discards the prior samples. + kwargs["discard_prior_samples"] = True + + self._round = max(self._data_round_index) + + if final_round: + # If there is (will be) only one round, train with Algorithm 2 from [1]. + if self._round == 0: + self._build_neural_net = partial( + self._build_neural_net, num_components=self._num_components + ) + # Run Algorithm 2 from [1]. + elif not self._ran_final_round: + # Now switch to the specified number of components. This method will + # only be used if `retrain_from_scratch=True`. Otherwise, + # the MDN will be built from replicating the single-component net for + # `num_component` times (via `_expand_mog()`). + self._build_neural_net = partial( + self._build_neural_net, num_components=self._num_components + ) + + # Extend the MDN to the originally desired number of components. + self._expand_mog(eps=component_perturbation) + else: + warnings.warn( + "You have already run SNPE-A with `final_round=True`. Running it" + "again with this setting will not allow computing the posthoc" + "correction applied in SNPE-A. Thus, you will get an error when " + "calling `.build_posterior()` after training.", + UserWarning, + ) + else: + # Run Algorithm 1 from [1]. + # Wrap the function that builds the MDN such that we can make + # sure that there is only one component when running. + self._build_neural_net = partial(self._build_neural_net, num_components=1) + + if final_round: + self._ran_final_round = True + + return super().train(**kwargs) + + def correct_for_proposal( + self, + density_estimator: Optional[TorchModule] = None, + ) -> "SNPE_A_MDN": + r"""Build mixture of Gaussians that approximates the posterior. + + Returns a `SNPE_A_MDN` object, which applies the posthoc-correction required in + SNPE-A. 
+ + Args: + density_estimator: The density estimator that the posterior is based on. + If `None`, use the latest neural density estimator that was trained. + + Returns: + Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods. + """ + if density_estimator is None: + density_estimator = deepcopy( + self._neural_net + ) # PosteriorEstimator.train() also returns a deepcopy, mimic this here + # If internal net is used device is defined. + device = self._device + else: + # Otherwise, infer it from the device of the net parameters. + device = str(next(density_estimator.parameters()).device) + + # Set proposal of the density estimator. + # This also evokes the z-scoring correction if necessary. + if ( + self._proposal_roundwise[-1] is self._prior + or self._proposal_roundwise[-1] is None + ): + proposal = self._prior + assert isinstance( + proposal, (MultivariateNormal, utils.BoxUniform) + ), """Prior must be `torch.distributions.MultivariateNormal` or `sbi.utils. + BoxUniform`""" + else: + assert isinstance( + self._proposal_roundwise[-1], DirectPosterior + ), """The proposal you passed to `append_simulations` is neither the prior + nor a `DirectPosterior`. SNPE-A currently only supports these scenarios. + """ + proposal = self._proposal_roundwise[-1] + + # Create the SNPE_A_MDN + wrapped_density_estimator = SNPE_A_MDN( + flow=density_estimator, proposal=proposal, prior=self._prior, device=device + ) + return wrapped_density_estimator + + def build_posterior( + self, + density_estimator: Optional[TorchModule] = None, + prior: Optional[Distribution] = None, + ) -> "DirectPosterior": + r"""Build posterior from the neural density estimator. + + This method first corrects the estimated density with `correct_for_proposal` + and then returns a `DirectPosterior`. + + Args: + density_estimator: The density estimator that the posterior is based on. + If `None`, use the latest neural density estimator that was trained. + prior: Prior distribution. 
+ + Returns: + Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods. + """ + if prior is None: + assert ( + self._prior is not None + ), """You did not pass a prior. You have to pass the prior either at + initialization `inference = SNPE_A(prior)` or to `.build_posterior + (prior=prior)`.""" + prior = self._prior + + wrapped_density_estimator = self.correct_for_proposal( + density_estimator=density_estimator + ) + self._posterior = DirectPosterior( + posterior_estimator=wrapped_density_estimator, + prior=prior, + ) + return deepcopy(self._posterior) + + def _log_prob_proposal_posterior( + self, + theta: Tensor, + x: Tensor, + masks: Tensor, + proposal: Optional[Any], + ) -> Tensor: + """Return the log-probability of the proposal posterior. + + For SNPE-A this is the same as `self._neural_net.log_prob(theta, x)` in + `_loss()` to be found in `snpe_base.py`. + + Args: + theta: Batch of parameters θ. + x: Batch of data. + masks: Mask that is True for prior samples in the batch in order to train + them with prior loss. + proposal: Proposal distribution. + + Returns: Log-probability of the proposal posterior. + """ + return self._neural_net.log_prob(theta, x) + + def _expand_mog(self, eps: float = 1e-5): + """ + Replicate a singe Gaussian trained with Algorithm 1 before continuing + with Algorithm 2. The weights and biases of the associated MDN layers + are repeated `num_components` times, slightly perturbed to break the + symmetry such that the gradients in the subsequent training are not + all identical. + + Args: + eps: Standard deviation for the random perturbation. + """ + assert isinstance(self._neural_net._distribution, MultivariateGaussianMDN) + + # Increase the number of components + self._neural_net._distribution._num_components = self._num_components + + # Expand the 1-dim Gaussian. 
+ for name, param in self._neural_net.named_parameters(): + if any( + key in name for key in ["logits", "means", "unconstrained", "upper"] + ): + if "bias" in name: + param.data = param.data.repeat(self._num_components) + param.data.add_(torch.randn_like(param.data) * eps) + param.grad = None # let autograd construct a new gradient + elif "weight" in name: + param.data = param.data.repeat(self._num_components, 1) + param.data.add_(torch.randn_like(param.data) * eps) + param.grad = None # let autograd construct a new gradient + + +class SNPE_A_MDN(nn.Module): + """Generates a posthoc-corrected MDN which approximates the posterior. + + This class takes as input the density estimator (abbreviated with `_d` suffix, aka + the proposal posterior) and the proposal prior (abbreviated with `_pp` suffix) from + which the simulations were drawn. It uses the algorithm presented in SNPE-A [1] to + compute the approximate posterior (abbreviated with `_p` suffix) from the two. The + approximate posterior is a MoG. This class also implements log-prob calculation + sampling from the approximate posterior. It inherits from `nn.Module` since the + constructor of `DirectPosterior` expects the argument `neural_net` to be a + `nn.Module`. + + [1] _Fast epsilon-free Inference of Simulation Models with Bayesian Conditional + Density Estimation_, Papamakarios et al., NeurIPS 2016, + https://arxiv.org/abs/1605.06376. + """ + + def __init__( + self, + flow: flows.Flow, + proposal: Union["utils.BoxUniform", "MultivariateNormal", "DirectPosterior"], + prior: Distribution, + device: str, + ): + """Constructor. + + Args: + flow: The trained normalizing flow, passed when building the posterior. + proposal: The proposal distribution. + prior: The prior distribution. + """ + # Call nn.Module's constructor. + super().__init__() + + self._neural_net = flow + self._prior = prior + self._device = device + + # Set the proposal using the `default_x`. 
+ if isinstance(proposal, (utils.BoxUniform, MultivariateNormal)): + self._apply_correction = False + else: + self._apply_correction = True + logits_pp, m_pp, prec_pp = proposal.posterior_estimator._posthoc_correction( + proposal.default_x + ) + self._logits_pp, self._m_pp, self._prec_pp = ( + logits_pp.detach(), + m_pp.detach(), + prec_pp.detach(), + ) + + # Take care of z-scoring, pre-compute and store prior terms. + self._set_state_for_mog_proposal() + + def log_prob(self, inputs: Tensor, context: Tensor) -> Tensor: + inputs, context = inputs.to(self._device), context.to(self._device) + + if not self._apply_correction: + return self._neural_net.log_prob(inputs, context) + else: + # When we want to compute the approx. posterior, a proposal prior \tilde{p} + # has already been observed. To analytically calculate the log-prob of the + # Gaussian, we first need to compute the mixture components. + + # Compute the mixture components of the proposal posterior. + logits_pp, m_pp, prec_pp = self._posthoc_correction(context) + + # z-score theta if it z-scoring had been requested. + theta = self._maybe_z_score_theta(inputs) + + # Compute the log_prob of theta under the product. + log_prob_proposal_posterior = utils.mog_log_prob( + theta, logits_pp, m_pp, prec_pp + ) + utils.assert_all_finite( + log_prob_proposal_posterior, "proposal posterior eval" + ) + return log_prob_proposal_posterior # \hat{p} from eq (3) in [1] + + def sample(self, num_samples: int, context: Tensor, batch_size: int = 1) -> Tensor: + context = context.to(self._device) + + if not self._apply_correction: + return self._neural_net.sample(num_samples, context, batch_size) + else: + # When we want to sample from the approx. posterior, a proposal prior + # \tilde{p} has already been observed. To analytically calculate the + # log-prob of the Gaussian, we first need to compute the mixture components. 
+ return self._sample_approx_posterior_mog(num_samples, context, batch_size) + + def _sample_approx_posterior_mog( + self, num_samples, x: Tensor, batch_size: int + ) -> Tensor: + r"""Sample from the approximate posterior. + + Args: + num_samples: Desired number of samples. + x: Conditioning context for posterior $p(\theta|x)$. + batch_size: Batch size for sampling. + + Returns: + Samples from the approximate mixture of Gaussians posterior. + """ + + # Compute the mixture components of the posterior. + logits_p, m_p, prec_p = self._posthoc_correction(x) + + # Compute the precision factors which represent the upper triangular matrix + # of the cholesky decomposition of the prec_p. + prec_factors_p = torch.linalg.cholesky(prec_p, upper=True) + + assert logits_p.ndim == 2 + assert m_p.ndim == 3 + assert prec_p.ndim == 4 + assert prec_factors_p.ndim == 4 + + # Replicate to use batched sampling from pyknos. + if batch_size is not None and batch_size > 1: + logits_p = logits_p.repeat(batch_size, 1) + m_p = m_p.repeat(batch_size, 1, 1) + prec_factors_p = prec_factors_p.repeat(batch_size, 1, 1, 1) + + # Get (optionally z-scored) MoG samples. + theta = MultivariateGaussianMDN.sample_mog( + num_samples, logits_p, m_p, prec_factors_p + ) + + embedded_context = self._neural_net._embedding_net(x) + if embedded_context is not None: + # Merge the context dimension with sample dimension in order to + # apply the transform. + theta = torchutils.merge_leading_dims(theta, num_dims=2) + embedded_context = torchutils.repeat_rows( + embedded_context, num_reps=num_samples + ) + + theta, _ = self._neural_net._transform.inverse(theta, context=embedded_context) + + if embedded_context is not None: + # Split the context dimension from sample dimension. 
+ theta = torchutils.split_leading_dim(theta, shape=[-1, num_samples]) + + return theta + + def _posthoc_correction(self, x: Tensor): + """ + Compute the mixture components of the posterior given the current density + estimator and the proposal. + + Args: + x: Conditioning context for posterior. + + Returns: + Mixture components of the posterior. + """ + + # Evaluate the density estimator. + encoded_x = self._neural_net._embedding_net(x) + dist = self._neural_net._distribution # defined to avoid black formatting. + logits_d, m_d, prec_d, _, _ = dist.get_mixture_components(encoded_x) + norm_logits_d = logits_d - torch.logsumexp(logits_d, dim=-1, keepdim=True) + + # The following if case is needed because, in the constructor, we call + # `_posthoc_correction` regardless of whether the `proposal` itself had a + # `proposal` or not. + if not self._apply_correction: + return norm_logits_d, m_d, prec_d + else: + logits_pp, m_pp, prec_pp = self._logits_pp, self._m_pp, self._prec_pp + + # Compute the MoG parameters of the posterior. + logits_p, m_p, prec_p, cov_p = self._proposal_posterior_transformation( + logits_pp, m_pp, prec_pp, norm_logits_d, m_d, prec_d + ) + return logits_p, m_p, prec_p + + def _proposal_posterior_transformation( + self, + logits_pp: Tensor, + means_pp: Tensor, + precisions_pp: Tensor, + logits_d: Tensor, + means_d: Tensor, + precisions_d: Tensor, + ): + r"""Transforms the proposal posterior (the MDN) into the posterior. + + The approximate posterior is: + $p(\theta|x) = 1/Z * q(\theta|x) * p(\theta) / prop(\theta)$ + In words: posterior = proposal posterior estimate * prior / proposal. + + Since the proposal posterior estimate and the proposal are MoG, and the + prior is either Gaussian or uniform, we can solve this in closed-form. + + This function implements Appendix C from [1], and is highly similar to + `SNPE_C._automatic_posterior_transformation()`. + + Args: + logits_pp: Component weight of each Gaussian of the proposal prior. 
+ means_pp: Mean of each Gaussian of the proposal prior. + precisions_pp: Precision matrix of each Gaussian of the proposal prior. + logits_d: Component weight for each Gaussian of the density estimator. + means_d: Mean of each Gaussian of the density estimator. + precisions_d: Precision matrix of each Gaussian of the density estimator. + + Returns: (Component weight, mean, precision matrix, covariance matrix) of each + Gaussian of the approximate posterior. + """ + + precisions_post, covariances_post = self._precisions_posterior( + precisions_pp, precisions_d + ) + + means_post = self._means_posterior( + covariances_post, means_pp, precisions_pp, means_d, precisions_d + ) + + logits_post = SNPE_A_MDN._logits_posterior( + means_post, + precisions_post, + covariances_post, + logits_pp, + means_pp, + precisions_pp, + logits_d, + means_d, + precisions_d, + ) + + return logits_post, means_post, precisions_post, covariances_post + + def _set_state_for_mog_proposal(self) -> None: + """ + Set state variables of the SNPE_A_MDN instance every time `set_proposal()` + is called, i.e. every time a posterior is build using + `SNPE_A.build_posterior()`. + + This function is almost identical to `SNPE_C._set_state_for_mog_proposal()`. + + Three things are computed: + 1) Check if z-scoring was requested. To do so, we check if the `_transform` + argument of the net had been a `CompositeTransform`. See pyknos mdn.py. + 2) Define a (potentially standardized) prior. It's standardized if z-scoring + had been requested. + 3) Compute (Precision * mean) for the prior. This quantity is used at every + training step if the prior is Gaussian. 
+ """ + + self.z_score_theta = isinstance(self._neural_net._transform, CompositeTransform) + + self._set_maybe_z_scored_prior() + + if isinstance(self._maybe_z_scored_prior, MultivariateNormal): + self.prec_m_prod_prior = torch.mv( + self._maybe_z_scored_prior.precision_matrix, # type: ignore + self._maybe_z_scored_prior.loc, # type: ignore + ) + + def _set_maybe_z_scored_prior(self) -> None: + r""" + Compute and store potentially standardized prior (if z-scoring was requested). + + This function is highly similar to `SNPE_C._set_maybe_z_scored_prior()`. + + The proposal posterior is: + $p(\theta|x) = 1/Z * q(\theta|x) * prop(\theta) / p(\theta)$ + + Let's denote z-scored theta by `a`: a = (theta - mean) / std + Then $p'(a|x) = 1/Z_2 * q'(a|x) * prop'(a) / p'(a)$ + + The ' indicates that the evaluation occurs in standardized space. The constant + scaling factor has been absorbed into $Z_2$. + From the above equation, we see that we need to evaluate the prior **in + standardized space**. We build the standardized prior in this function. + + The standardize transform that is applied to the samples theta does not use + the exact prior mean and std (due to implementation issues). Hence, the z-scored + prior will not be exactly have mean=0 and std=1. + """ + if self.z_score_theta: + scale = self._neural_net._transform._transforms[0]._scale + shift = self._neural_net._transform._transforms[0]._shift + + # Following the definition of the linear transform in + # `standardizing_transform` in `sbiutils.py`: + # shift=-mean / std + # scale=1 / std + # Solving these equations for mean and std: + estim_prior_std = 1 / scale + estim_prior_mean = -shift * estim_prior_std + + # Compute the discrepancy of the true prior mean and std and the mean and + # std that was empirically estimated from samples. + # N(theta|m,s) = N((theta-m_e)/s_e|(m-m_e)/s_e, s/s_e) + # Above: m,s are true prior mean and std. 
m_e,s_e are estimated prior mean + # and std (estimated from samples and used to build standardize transform). + almost_zero_mean = (self._prior.mean - estim_prior_mean) / estim_prior_std + almost_one_std = torch.sqrt(self._prior.variance) / estim_prior_std + + if isinstance(self._prior, MultivariateNormal): + self._maybe_z_scored_prior = MultivariateNormal( + almost_zero_mean, torch.diag(almost_one_std) + ) + else: + range_ = torch.sqrt(almost_one_std * 3.0) + self._maybe_z_scored_prior = utils.BoxUniform( + almost_zero_mean - range_, almost_zero_mean + range_ + ) + else: + self._maybe_z_scored_prior = self._prior + + def _maybe_z_score_theta(self, theta: Tensor) -> Tensor: + """Return potentially standardized theta if z-scoring was requested.""" + + if self.z_score_theta: + theta, _ = self._neural_net._transform(theta) + + return theta + + def _precisions_posterior(self, precisions_pp: Tensor, precisions_d: Tensor): + r"""Return the precisions and covariances of the MoG posterior. + + As described at the end of Appendix C in [1], it can happen that the + proposal's precision matrix is not positive definite. + + $S_k^\prime = ( S_k^{-1} - S_0^{-1} )^{-1}$ + (see eq (23) in Appendix C of [1]) + + Args: + precisions_pp: Precision matrices of the proposal prior. + precisions_d: Precision matrices of the density estimator. + + Returns: (Precisions, Covariances) of the MoG posterior. + """ + + num_comps_p = precisions_pp.shape[1] + num_comps_d = precisions_d.shape[1] + + # Check if precision matrices are positive definite. + for batches in precisions_pp: + for pprior in batches: + eig_pprior = torch.linalg.eigvalsh(pprior, UPLO="U") + if not (eig_pprior > 0).all(): + raise AssertionError( + "The precision matrix of the proposal is not positive definite!" 
+ ) + for batches in precisions_d: + for d in batches: + eig_d = torch.linalg.eigvalsh(d, UPLO="U") + if not (eig_d > 0).all(): + raise AssertionError( + "The precision matrix of the density estimator is not " + "positive definite!" + ) + + precisions_pp_rep = precisions_pp.repeat_interleave(num_comps_d, dim=1) + precisions_d_rep = precisions_d.repeat(1, num_comps_p, 1, 1) + + precisions_p = precisions_d_rep - precisions_pp_rep + if isinstance(self._maybe_z_scored_prior, MultivariateNormal): + precisions_p += self._maybe_z_scored_prior.precision_matrix + + # Check if precision matrix is positive definite. + for idx_batch, batches in enumerate(precisions_p): + for idx_comp, pp in enumerate(batches): + eig_pp = torch.symeig(pp, eigenvectors=False).eigenvalues + if not (eig_pp > 0).all(): + raise AssertionError( + "The precision matrix of a posterior is not positive " + "definite! This is a known issue for SNPE-A. Either try a " + "different parameter setting, e.g. a different number of " + "mixture components (when contracting SNPE-A), or a different " + "value for the parameter perturbation (when building the " + "posterior)." + ) + + covariances_p = torch.inverse(precisions_p) + return precisions_p, covariances_p + + def _means_posterior( + self, + covariances_p: Tensor, + means_pp: Tensor, + precisions_pp: Tensor, + means_d: Tensor, + precisions_d: Tensor, + ): + r"""Return the means of the MoG posterior. + + $m_k^\prime = S_k^\prime ( S_k^{-1} m_k - S_0^{-1} m_0 )$ + (see eq (24) in Appendix C of [1]) + + Args: + covariances_post: Covariance matrices of the MoG posterior. + means_pp: Means of the proposal prior. + precisions_pp: Precision matrices of the proposal prior. + means_d: Means of the density estimator. + precisions_d: Precision matrices of the density estimator. + + Returns: Means of the MoG posterior. + """ + + num_comps_pp = precisions_pp.shape[1] + num_comps_d = precisions_d.shape[1] + + # Compute the products P_k * m_k and P_0 * m_0. 
+ prec_m_prod_pp = utils.batched_mixture_mv(precisions_pp, means_pp) + prec_m_prod_d = utils.batched_mixture_mv(precisions_d, means_d) + + # Repeat them to allow for matrix operations: same trick as for the precisions. + prec_m_prod_pp_rep = prec_m_prod_pp.repeat_interleave(num_comps_d, dim=1) + prec_m_prod_d_rep = prec_m_prod_d.repeat(1, num_comps_pp, 1) + + # Compute the means P_k^prime * (P_k * m_k - P_0 * m_0). + summed_cov_m_prod_rep = prec_m_prod_d_rep - prec_m_prod_pp_rep + if isinstance(self._maybe_z_scored_prior, MultivariateNormal): + summed_cov_m_prod_rep += self.prec_m_prod_prior + + means_p = utils.batched_mixture_mv(covariances_p, summed_cov_m_prod_rep) + return means_p + + @staticmethod + def _logits_posterior( + means_post: Tensor, + precisions_post: Tensor, + covariances_post: Tensor, + logits_pp: Tensor, + means_pp: Tensor, + precisions_pp: Tensor, + logits_d: Tensor, + means_d: Tensor, + precisions_d: Tensor, + ): + r"""Return the component weights (i.e. logits) of the MoG posterior. + + $\alpha_k^\prime = \frac{ \alpha_k exp(-0.5 c_k) }{ \sum{j} \alpha_j exp(-0.5 + c_j) } $ + with + $c_k = logdet(S_k) - logdet(S_0) - logdet(S_k^\prime) + + + m_k^T P_k m_k - m_0^T P_0 m_0 - m_k^\prime^T P_k^\prime m_k^\prime$ + (see eqs. (25, 26) in Appendix C of [1]) + + Args: + means_post: Means of the posterior. + precisions_post: Precision matrices of the posterior. + covariances_post: Covariance matrices of the posterior. + logits_pp: Component weights (i.e. logits) of the proposal prior. + means_pp: Means of the proposal prior. + precisions_pp: Precision matrices of the proposal prior. + logits_d: Component weights (i.e. logits) of the density estimator. + means_d: Means of the density estimator. + precisions_d: Precision matrices of the density estimator. + + Returns: Component weights of the proposal posterior. 
+ """ + + num_comps_pp = precisions_pp.shape[1] + num_comps_d = precisions_d.shape[1] + + # Compute the ratio of the logits similar to eq (10) in Appendix A.1 of [2] + logits_pp_rep = logits_pp.repeat_interleave(num_comps_d, dim=1) + logits_d_rep = logits_d.repeat(1, num_comps_pp) + logit_factors = logits_d_rep - logits_pp_rep + + # Compute the log-determinants + logdet_covariances_post = torch.logdet(covariances_post) + logdet_covariances_pp = -torch.logdet(precisions_pp) + logdet_covariances_d = -torch.logdet(precisions_d) + + # Repeat the proposal and density estimator terms such that there are LK terms. + # Same trick as has been used above. + logdet_covariances_pp_rep = logdet_covariances_pp.repeat_interleave( + num_comps_d, dim=1 + ) + logdet_covariances_d_rep = logdet_covariances_d.repeat(1, num_comps_pp) + + log_sqrt_det_ratio = 0.5 * ( # similar to eq (14) in Appendix A.1 of [2] + logdet_covariances_post + + logdet_covariances_pp_rep + - logdet_covariances_d_rep + ) + + # Compute for proposal, density estimator, and proposal posterior: + exponent_pp = utils.batched_mixture_vmv( + precisions_pp, means_pp # m_0 in eq (26) in Appendix C of [1] + ) + exponent_d = utils.batched_mixture_vmv( + precisions_d, means_d # m_k in eq (26) in Appendix C of [1] + ) + exponent_post = utils.batched_mixture_vmv( + precisions_post, means_post # m_k^\prime in eq (26) in Appendix C of [1] + ) + + # Extend proposal and density estimator exponents to get LK terms. 
+ exponent_pp_rep = exponent_pp.repeat_interleave(num_comps_d, dim=1) + exponent_d_rep = exponent_d.repeat(1, num_comps_pp) + exponent = -0.5 * ( + exponent_d_rep - exponent_pp_rep - exponent_post # eq (26) in [1] + ) + + logits_post = logit_factors + log_sqrt_det_ratio + exponent + return logits_post diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index 26c862282..1a56f9c75 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -1,594 +1,597 @@ -# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed -# under the Affero General Public License v3, see . -import time -from abc import ABC, abstractmethod -from copy import deepcopy -from typing import Any, Callable, Dict, Optional, Union -from warnings import warn - -import torch -from torch import Tensor, nn, ones, optim -from torch.distributions import Distribution -from torch.nn.utils.clip_grad import clip_grad_norm_ -from torch.utils import data -from torch.utils.tensorboard.writer import SummaryWriter - -from sbi import utils as utils -from sbi.inference import NeuralInference, check_if_proposal_has_default_x -from sbi.inference.posteriors import ( - DirectPosterior, - MCMCPosterior, - RejectionPosterior, - VIPosterior, -) -from sbi.inference.posteriors.base_posterior import NeuralPosterior -from sbi.inference.potentials import posterior_estimator_based_potential -from sbi.utils import ( - RestrictedPrior, - check_estimator_arg, - test_posterior_net_for_multi_d_x, - validate_theta_and_x, - x_shape_from_simulation, - handle_invalid_x, - warn_if_zscoring_changes_data, - warn_on_invalid_x, - warn_on_invalid_x_for_snpec_leakage, -) -from sbi.utils.sbiutils import ImproperEmpirical, mask_sims_from_prior - - -class PosteriorEstimator(NeuralInference, ABC): - def __init__( - self, - prior: Optional[Distribution] = None, - density_estimator: Union[str, Callable] = "maf", - device: str = "cpu", - logging_level: Union[int, str] = 
"WARNING", - summary_writer: Optional[SummaryWriter] = None, - show_progress_bars: bool = True, - ): - """Base class for Sequential Neural Posterior Estimation methods. - - Args: - density_estimator: If it is a string, use a pre-configured network of the - provided type (one of nsf, maf, mdn, made). Alternatively, a function - that builds a custom neural network can be provided. The function will - be called with the first batch of simulations (theta, x), which can - thus be used for shape inference and potentially for z-scoring. It - needs to return a PyTorch `nn.Module` implementing the density - estimator. The density estimator needs to provide the methods - `.log_prob` and `.sample()`. - - See docstring of `NeuralInference` class for all other arguments. - """ - - super().__init__( - prior=prior, - device=device, - logging_level=logging_level, - summary_writer=summary_writer, - show_progress_bars=show_progress_bars, - ) - - # As detailed in the docstring, `density_estimator` is either a string or - # a callable. The function creating the neural network is attached to - # `_build_neural_net`. It will be called in the first round and receive - # thetas and xs as inputs, so that they can be used for shape inference and - # potentially for z-scoring. - check_estimator_arg(density_estimator) - if isinstance(density_estimator, str): - self._build_neural_net = utils.posterior_nn(model=density_estimator) - else: - self._build_neural_net = density_estimator - - self._proposal_roundwise = [] - self.use_non_atomic_loss = False - - # Extra SNPE-specific fields summary_writer. 
- self._summary.update({"rejection_sampling_acceptance_rates": []}) # type:ignore - - def append_simulations( - self, - theta: Tensor, - x: Tensor, - proposal: Optional[DirectPosterior] = None, - exclude_invalid_x: bool = True, - warn_on_invalid: bool = True, - warn_if_zscoring: bool = True, - return_self: bool = True, - data_device: str = None - ) -> "PosteriorEstimator": - r"""Store parameters and simulation outputs to use them for later training. - - Data are stored as entries in lists for each type of variable (parameter/data). - - Stores $\theta$, $x$, prior_masks (indicating if simulations are coming from the - prior or not) and an index indicating which round the batch of simulations came - from. - - Args: - theta: Parameter sets. - x: Simulation outputs. - proposal: The distribution that the parameters $\theta$ were sampled from. - Pass `None` if the parameters were sampled from the prior. If not - `None`, it will trigger a different loss-function. - - Returns: - NeuralInference object (returned so that this function is chainable). 
- """ - - # Add ability to specify device data is saved on - if data_device is None: data_device = self._device - theta, x = validate_theta_and_x(theta, x, training_device=data_device) - - - is_valid_x, num_nans, num_infs = handle_invalid_x(x, exclude_invalid_x) - - # Check for problematic z-scoring - if warn_if_zscoring: - warn_if_zscoring_changes_data(x[is_valid_x]) - if warn_on_invalid: - warn_on_invalid_x(num_nans, num_infs, exclude_invalid_x) - warn_on_invalid_x_for_snpec_leakage( - num_nans, num_infs, exclude_invalid_x, type(self).__name__, self._round - ) - - x = x[is_valid_x] - theta = theta[is_valid_x] - - - self._check_proposal(proposal) - - if ( - proposal is None - or proposal is self._prior - or ( - isinstance(proposal, RestrictedPrior) and proposal._prior is self._prior - ) - ): - # The `_data_round_index` will later be used to infer if one should train - # with MLE loss or with atomic loss (see, in `train()`: - # self._round = max(self._data_round_index)) - self._data_round_index.append(0) - prior_masks = mask_sims_from_prior(0, theta.size(0)) - else: - if not self._data_round_index: - # This catches a pretty specific case: if, in the first round, one - # passes data that does not come from the prior. 
- self._data_round_index.append(1) - else: - self._data_round_index.append(max(self._data_round_index) + 1) - prior_masks = mask_sims_from_prior(1, theta.size(0)) - - - if self._dataset is None: - #If first round, set up ConcatDataset - self._dataset = data.ConcatDataset( [data.TensorDataset(theta,x,prior_masks),] ) - else: - #Otherwise append to Dataset - self._dataset = data.ConcatDataset( self._dataset.datasets + [data.TensorDataset(theta,x,prior_masks),] ) - - self._num_sims_per_round.append(theta.size(0)) - self._proposal_roundwise.append(proposal) - - if self._prior is None or isinstance(self._prior, ImproperEmpirical): - if proposal is not None: - raise ValueError( - "You had not passed a prior at initialization, but now you " - "passed a proposal. If you want to run multi-round SNPE, you have " - "to specify a prior (set the `.prior` argument or re-initialize " - "the object with a prior distribution). If the samples you passed " - "to `append_simulations()` were sampled from the prior, you can " - "run single-round inference with " - "`append_simulations(..., proposal=None)`." - ) - theta_prior = self.get_simulations()[0] - self._prior = ImproperEmpirical(theta_prior, ones(theta_prior.shape[0])) - - #Add ability to not return self - if return_self: - return self - else: - return 1 - - def train( - self, - training_batch_size: int = 50, - learning_rate: float = 5e-4, - validation_fraction: float = 0.1, - stop_after_epochs: int = 20, - max_num_epochs: int = 2**31 - 1, - clip_max_norm: Optional[float] = 5.0, - calibration_kernel: Optional[Callable] = None, - exclude_invalid_x: bool = True, - resume_training: bool = False, - force_first_round_loss: bool = False, - discard_prior_samples: bool = False, - retrain_from_scratch: bool = False, - show_train_summary: bool = False, - warn_if_zscoring: Optional[bool] = True, - dataloader_kwargs: Optional[dict] = None, - ) -> nn.Module: - r"""Return density estimator that approximates the distribution $p(\theta|x)$. 
- - Args: - training_batch_size: Training batch size. - learning_rate: Learning rate for Adam optimizer. - validation_fraction: The fraction of data to use for validation. - stop_after_epochs: The number of epochs to wait for improvement on the - validation set before terminating training. - max_num_epochs: Maximum number of epochs to run. If reached, we stop - training even when the validation loss is still decreasing. Otherwise, - we train until validation loss increases (see also `stop_after_epochs`). - clip_max_norm: Value at which to clip the total gradient norm in order to - prevent exploding gradients. Use None for no clipping. - calibration_kernel: A function to calibrate the loss with respect to the - simulations `x`. See Lueckmann, Gonçalves et al., NeurIPS 2017. - exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞` - during training. Expect errors, silent or explicit, when `False`. - resume_training: Can be used in case training time is limited, e.g. on a - cluster. If `True`, the split between train and validation set, the - optimizer, the number of epochs, and the best validation log-prob will - be restored from the last time `.train()` was called. - force_first_round_loss: If `True`, train with maximum likelihood, - i.e., potentially ignoring the correction for using a proposal - distribution different from the prior. - discard_prior_samples: Whether to discard samples simulated in round 1, i.e. - from the prior. Training may be sped up by ignoring such less targeted - samples. - retrain_from_scratch: Whether to retrain the conditional density - estimator for the posterior from scratch each round. - show_train_summary: Whether to print the number of epochs and validation - loss after the training. - dataloader_kwargs: Additional or updated kwargs to be passed to the training - and validation dataloaders (like, e.g., a collate_fn) - - Returns: - Density estimator that approximates the distribution $p(\theta|x)$. 
- """ - if self._round == 0 and self._neural_net is not None: - assert force_first_round_loss, ( - "You have already trained this neural network. After you had trained " - "the network, you again appended simulations with `append_simulations" - "(theta, x)`, but you did not provide a proposal. If the new " - "simulations are sampled from the prior, you can set " - "`.train(..., force_first_round_loss=True`). However, if the new " - "simulations were not sampled from the prior, you should pass the " - "proposal, i.e. `append_simulations(theta, x, proposal)`. If " - "your samples are not sampled from the prior and you do not pass a " - "proposal and you set `force_first_round_loss=True`, the result of " - "SNPE will not be the true posterior. Instead, it will be the proposal " - "posterior, which (usually) is more narrow than the true posterior." - ) - - # Calibration kernels proposed in Lueckmann, Gonçalves et al., 2017. - if calibration_kernel is None: - calibration_kernel = lambda x: ones([len(x)], device=self._device) - - # Starting index for the training set (1 = discard round-0 samples). - start_idx = int(discard_prior_samples and self._round > 0) - - # For non-atomic loss, we can not reuse samples from previous rounds as of now. - # SNPE-A can, by construction of the algorithm, only use samples from the last - # round. SNPE-A is the only algorithm that has an attribute `_ran_final_round`, - # so this is how we check for whether or not we are using SNPE-A. - if self.use_non_atomic_loss or hasattr(self, "_ran_final_round"): - start_idx = self._round - - # Set the proposal to the last proposal that was passed by the user. For - # atomic SNPE, it does not matter what the proposal is. For non-atomic - # SNPE, we only use the latest data that was passed, i.e. the one from the - # last proposal. 
- proposal = self._proposal_roundwise[-1] - - train_loader, val_loader = self.get_dataloaders_all( - start_idx, - training_batch_size, - validation_fraction, - resume_training, - dataloader_kwargs=dataloader_kwargs, - ) - # First round or if retraining from scratch: - # Call the `self._build_neural_net` with the rounds' thetas and xs as - # arguments, which will build the neural network. - # This is passed into NeuralPosterior, to create a neural posterior which - # can `sample()` and `log_prob()`. The network is accessible via `.net`. - if self._neural_net is None or retrain_from_scratch: - - #Get theta,x from dataset to initialize NN - test_theta = self._dataset.datasets[0].tensors[0][:100] - test_x = self._dataset.datasets[0].tensors[1][:100] - - self._neural_net = self._build_neural_net( - test_theta, test_x - ) - # If data on training device already move net as well. - if ( - not self._device == "cpu" - and f"{test_x.device.type}:{test_x.device.index}" == self._device - ): - self._neural_net.to(self._device) - - test_posterior_net_for_multi_d_x(self._neural_net, test_theta, test_x) - self._x_shape = x_shape_from_simulation(test_x) - - # Move entire net to device for training. - self._neural_net.to(self._device) - - if not resume_training: - self.optimizer = optim.Adam( - list(self._neural_net.parameters()), lr=learning_rate - ) - self.epoch, self._val_log_prob = 0, float("-Inf") - - while self.epoch <= max_num_epochs and not self._converged( - self.epoch, stop_after_epochs - ): - - # Train for a single epoch. - self._neural_net.train() - train_log_probs_sum = 0 - epoch_start_time = time.time() - for batch in train_loader: - self.optimizer.zero_grad() - # Get batches on current device. 
- theta_batch, x_batch, masks_batch = ( - batch[0].to(self._device), - batch[1].to(self._device), - batch[2].to(self._device), - ) - - train_losses = self._loss( - theta_batch, x_batch, masks_batch, proposal, calibration_kernel - ) - train_loss = torch.mean(train_losses) - train_log_probs_sum -= train_losses.sum().item() - - train_loss.backward() - if clip_max_norm is not None: - clip_grad_norm_( - self._neural_net.parameters(), max_norm=clip_max_norm - ) - self.optimizer.step() - - self.epoch += 1 - - train_log_prob_average = train_log_probs_sum / ( - len(train_loader) * train_loader.batch_size # type: ignore - ) - self._summary["train_log_probs"].append(train_log_prob_average) - - # Calculate validation performance. - self._neural_net.eval() - val_log_prob_sum = 0 - - with torch.no_grad(): - for batch in val_loader: - theta_batch, x_batch, masks_batch = ( - batch[0].to(self._device), - batch[1].to(self._device), - batch[2].to(self._device), - ) - # Take negative loss here to get validation log_prob. - val_losses = self._loss( - theta_batch, - x_batch, - masks_batch, - proposal, - calibration_kernel, - ) - val_log_prob_sum -= val_losses.sum().item() - - # Take mean over all validation samples. - self._val_log_prob = val_log_prob_sum / ( - len(val_loader) * val_loader.batch_size # type: ignore - ) - # Log validation log prob for every epoch. - self._summary["validation_log_probs"].append(self._val_log_prob) - self._summary["epoch_durations_sec"].append(time.time() - epoch_start_time) - - self._maybe_show_progress(self._show_progress_bars, self.epoch) - - self._report_convergence_at_end(self.epoch, stop_after_epochs, max_num_epochs) - - # Update summary. - self._summary["epochs"].append(self.epoch) - self._summary["best_validation_log_probs"].append(self._best_val_log_prob) - - # Update tensorboard and summary dict. - self._summarize(round_=self._round, x_o=None, theta_bank=None, x_bank=None) - - # Update description for progress bar. 
- if show_train_summary: - print(self._describe_round(self._round, self._summary)) - - # Avoid keeping the gradients in the resulting network, which can - # cause memory leakage when benchmarking. - self._neural_net.zero_grad(set_to_none=True) - - return deepcopy(self._neural_net) - - def build_posterior( - self, - density_estimator: Optional[nn.Module] = None, - prior: Optional[Distribution] = None, - sample_with: str = "rejection", - mcmc_method: str = "slice_np", - vi_method: str = "rKL", - mcmc_parameters: Dict[str, Any] = {}, - vi_parameters: Dict[str, Any] = {}, - rejection_sampling_parameters: Dict[str, Any] = {}, - ) -> Union[MCMCPosterior, RejectionPosterior, VIPosterior, DirectPosterior]: - r"""Build posterior from the neural density estimator. - - For SNPE, the posterior distribution that is returned here implements the - following functionality over the raw neural density estimator: - - correct the calculation of the log probability such that it compensates for - the leakage. - - reject samples that lie outside of the prior bounds. - - alternatively, if leakage is very high (which can happen for multi-round - SNPE), sample from the posterior with MCMC. - - Args: - density_estimator: The density estimator that the posterior is based on. - If `None`, use the latest neural density estimator that was trained. - prior: Prior distribution. - sample_with: Method to use for sampling from the posterior. Must be one of - [`mcmc` | `rejection` | `vi`]. - mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`, - `hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy - implementation of slice sampling; select `hmc`, `nuts` or `slice` for - Pyro-based sampling. - vi_method: Method used for VI, one of [`rKL`, `fKL`, `IW`, `alpha`]. Note - some of the methods admit a `mode seeking` property (e.g. rKL) whereas - some admit a `mass covering` one (e.g fKL). - mcmc_parameters: Additional kwargs passed to `MCMCPosterior`. 
- vi_parameters: Additional kwargs passed to `VIPosterior`. - rejection_sampling_parameters: Additional kwargs passed to - `RejectionPosterior` or `DirectPosterior`. By default, - `DirectPosterior` is used. Only if `rejection_sampling_parameters` - contains `proposal`, a `RejectionPosterior` is instantiated. - - Returns: - Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods - (the returned log-probability is unnormalized). - """ - if prior is None: - assert self._prior is not None, ( - "You did not pass a prior. You have to pass the prior either at " - "initialization `inference = SNPE(prior)` or to " - "`.build_posterior(prior=prior)`." - ) - prior = self._prior - else: - utils.check_prior(prior) - - if density_estimator is None: - posterior_estimator = self._neural_net - # If internal net is used device is defined. - device = self._device - else: - posterior_estimator = density_estimator - # Otherwise, infer it from the device of the net parameters. - device = next(density_estimator.parameters()).device.type - - potential_fn, theta_transform = posterior_estimator_based_potential( - posterior_estimator=posterior_estimator, prior=prior, x_o=None - ) - - if sample_with == "rejection": - if "proposal" in rejection_sampling_parameters.keys(): - self._posterior = RejectionPosterior( - potential_fn=potential_fn, - device=device, - x_shape=self._x_shape, - **rejection_sampling_parameters, - ) - else: - self._posterior = DirectPosterior( - posterior_estimator=posterior_estimator, - prior=prior, - x_shape=self._x_shape, - device=device, - ) - elif sample_with == "mcmc": - self._posterior = MCMCPosterior( - potential_fn=potential_fn, - theta_transform=theta_transform, - proposal=prior, - method=mcmc_method, - device=device, - x_shape=self._x_shape, - **mcmc_parameters, - ) - elif sample_with == "vi": - self._posterior = VIPosterior( - potential_fn=potential_fn, - theta_transform=theta_transform, - prior=prior, # type: ignore - vi_method=vi_method, - 
device=device, - x_shape=self._x_shape, - **vi_parameters, - ) - else: - raise NotImplementedError - - # Store models at end of each round. - self._model_bank.append(deepcopy(self._posterior)) - - return deepcopy(self._posterior) - - @abstractmethod - def _log_prob_proposal_posterior( - self, - theta: Tensor, - x: Tensor, - masks: Tensor, - proposal: Optional[Any], - ) -> Tensor: - raise NotImplementedError - - def _loss( - self, - theta: Tensor, - x: Tensor, - masks: Tensor, - proposal: Optional[Any], - calibration_kernel: Callable, - ) -> Tensor: - """Return loss with proposal correction (`round_>0`) or without it (`round_=0`). - - The loss is the negative log prob. Irrespective of the round or SNPE method - (A, B, or C), it can be weighted with a calibration kernel. - - Returns: - Calibration kernel-weighted negative log prob. - """ - if self._round == 0: - # Use posterior log prob (without proposal correction) for first round. - log_prob = self._neural_net.log_prob(theta, x) - else: - log_prob = self._log_prob_proposal_posterior(theta, x, masks, proposal) - - return -(calibration_kernel(x) * log_prob) - - def _check_proposal(self, proposal): - """ - Check for validity of the provided proposal distribution. - - If the proposal is a `NeuralPosterior`, we check if the default_x is set. - If the proposal is **not** a `NeuralPosterior`, we warn since it is likely that - the user simply passed the prior, but this would still trigger atomic loss. - """ - if proposal is not None: - check_if_proposal_has_default_x(proposal) - - if isinstance(proposal, RestrictedPrior): - if proposal._prior is not self._prior: - warn( - "The proposal you passed is a `RestrictedPrior`, but the " - "proposal distribution it uses is not the prior (it can be " - "accessed via `RestrictedPrior._prior`). We do not " - "recommend to mix the `RestrictedPrior` with multi-round " - "SNPE." 
- ) - elif ( - not isinstance(proposal, NeuralPosterior) - and proposal is not self._prior - ): - warn( - "The proposal you passed is neither the prior nor a " - "`NeuralPosterior` object. If you are an expert user and did so " - "for research purposes, this is fine. If not, you might be doing " - "something wrong: feel free to create an issue on Github." - ) - elif self._round > 0: - raise ValueError( - "A proposal was passed but no prior was passed at initialisation. When " - "running multi-round inference, a prior needs to be specified upon " - "initialisation. Potential fix: setting the `._prior` attribute or " - "re-initialisation. If the samples passed to `append_simulations()` " - "were sampled from the prior, single-round inference can be performed " - "with `append_simulations(..., proprosal=None)`." - ) +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Affero General Public License v3, see . +import time +from abc import ABC, abstractmethod +from copy import deepcopy +from typing import Any, Callable, Dict, Optional, Union +from warnings import warn + +import torch +from torch import Tensor, nn, ones, optim +from torch.distributions import Distribution +from torch.nn.utils.clip_grad import clip_grad_norm_ +from torch.utils import data +from torch.utils.tensorboard.writer import SummaryWriter + +from sbi import utils as utils +from sbi.inference import NeuralInference, check_if_proposal_has_default_x +from sbi.inference.posteriors import ( + DirectPosterior, + MCMCPosterior, + RejectionPosterior, + VIPosterior, +) +from sbi.inference.posteriors.base_posterior import NeuralPosterior +from sbi.inference.potentials import posterior_estimator_based_potential +from sbi.utils import ( + RestrictedPrior, + check_estimator_arg, + test_posterior_net_for_multi_d_x, + validate_theta_and_x, + x_shape_from_simulation, + handle_invalid_x, + warn_if_zscoring_changes_data, + warn_on_invalid_x, + 
warn_on_invalid_x_for_snpec_leakage, +) +from sbi.utils.sbiutils import ImproperEmpirical, mask_sims_from_prior + + +class PosteriorEstimator(NeuralInference, ABC): + def __init__( + self, + prior: Optional[Distribution] = None, + density_estimator: Union[str, Callable] = "maf", + device: str = "cpu", + logging_level: Union[int, str] = "WARNING", + summary_writer: Optional[SummaryWriter] = None, + show_progress_bars: bool = True, + ): + """Base class for Sequential Neural Posterior Estimation methods. + + Args: + density_estimator: If it is a string, use a pre-configured network of the + provided type (one of nsf, maf, mdn, made). Alternatively, a function + that builds a custom neural network can be provided. The function will + be called with the first batch of simulations (theta, x), which can + thus be used for shape inference and potentially for z-scoring. It + needs to return a PyTorch `nn.Module` implementing the density + estimator. The density estimator needs to provide the methods + `.log_prob` and `.sample()`. + + See docstring of `NeuralInference` class for all other arguments. + """ + + super().__init__( + prior=prior, + device=device, + logging_level=logging_level, + summary_writer=summary_writer, + show_progress_bars=show_progress_bars, + ) + + # As detailed in the docstring, `density_estimator` is either a string or + # a callable. The function creating the neural network is attached to + # `_build_neural_net`. It will be called in the first round and receive + # thetas and xs as inputs, so that they can be used for shape inference and + # potentially for z-scoring. + check_estimator_arg(density_estimator) + if isinstance(density_estimator, str): + self._build_neural_net = utils.posterior_nn(model=density_estimator) + else: + self._build_neural_net = density_estimator + + self._proposal_roundwise = [] + self.use_non_atomic_loss = False + + # Extra SNPE-specific fields summary_writer. 
+ self._summary.update({"rejection_sampling_acceptance_rates": []}) # type:ignore + + def append_simulations( + self, + theta: Tensor, + x: Tensor, + proposal: Optional[DirectPosterior] = None, + exclude_invalid_x: bool = True, + warn_on_invalid: bool = True, + warn_if_zscoring: bool = True, + return_self: bool = True, + data_device: str = None, + ) -> "PosteriorEstimator": + r"""Store parameters and simulation outputs to use them for later training. + + Data are stored as entries in lists for each type of variable (parameter/data). + + Stores $\theta$, $x$, prior_masks (indicating if simulations are coming from the + prior or not) and an index indicating which round the batch of simulations came + from. + + Args: + theta: Parameter sets. + x: Simulation outputs. + proposal: The distribution that the parameters $\theta$ were sampled from. + Pass `None` if the parameters were sampled from the prior. If not + `None`, it will trigger a different loss-function. + exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞` + during training. Expect errors, silent or explicit, when `False`. + warn_on_invalid: Whether to warn if data is invalid + warn_if_zscoring: Whether to test if z-scoring causes duplicates + return_self: Whether to return a instance of the class, allows chaining + with `.train()`. Setting `False` decreases memory overhead. + data_device: Where to store the data, default is on the same device where + the training is happening. If training a large dataset on a GPU with not + much VRAM can set to 'cpu' to store data on system memory instead. + + Returns: + NeuralInference object (returned so that this function is chainable). 
+ """ + + # Add ability to specify device data is saved on + if data_device is None: data_device = self._device + theta, x = validate_theta_and_x(theta, x, training_device=data_device) + + + is_valid_x, num_nans, num_infs = handle_invalid_x(x, exclude_invalid_x) + + # Check for problematic z-scoring + if warn_if_zscoring: + warn_if_zscoring_changes_data(x[is_valid_x]) + if warn_on_invalid: + warn_on_invalid_x(num_nans, num_infs, exclude_invalid_x) + warn_on_invalid_x_for_snpec_leakage( + num_nans, num_infs, exclude_invalid_x, type(self).__name__, self._round + ) + + x = x[is_valid_x] + theta = theta[is_valid_x] + + + self._check_proposal(proposal) + + if ( + proposal is None + or proposal is self._prior + or ( + isinstance(proposal, RestrictedPrior) and proposal._prior is self._prior + ) + ): + # The `_data_round_index` will later be used to infer if one should train + # with MLE loss or with atomic loss (see, in `train()`: + # self._round = max(self._data_round_index)) + self._data_round_index.append(0) + prior_masks = mask_sims_from_prior(0, theta.size(0)) + else: + if not self._data_round_index: + # This catches a pretty specific case: if, in the first round, one + # passes data that does not come from the prior. 
+ self._data_round_index.append(1) + else: + self._data_round_index.append(max(self._data_round_index) + 1) + prior_masks = mask_sims_from_prior(1, theta.size(0)) + + + if self._dataset is None: + #If first round, set up ConcatDataset + self._dataset = data.ConcatDataset( [data.TensorDataset(theta,x,prior_masks),] ) + else: + #Otherwise append to Dataset + self._dataset = data.ConcatDataset( self._dataset.datasets + [data.TensorDataset(theta,x,prior_masks),] ) + + self._num_sims_per_round.append(theta.size(0)) + self._proposal_roundwise.append(proposal) + + if self._prior is None or isinstance(self._prior, ImproperEmpirical): + if proposal is not None: + raise ValueError( + "You had not passed a prior at initialization, but now you " + "passed a proposal. If you want to run multi-round SNPE, you have " + "to specify a prior (set the `.prior` argument or re-initialize " + "the object with a prior distribution). If the samples you passed " + "to `append_simulations()` were sampled from the prior, you can " + "run single-round inference with " + "`append_simulations(..., proposal=None)`." + ) + theta_prior = self.get_simulations()[0] + self._prior = ImproperEmpirical(theta_prior, ones(theta_prior.shape[0])) + + #Add ability to not return self + if return_self: + return self + + def train( + self, + training_batch_size: int = 50, + learning_rate: float = 5e-4, + validation_fraction: float = 0.1, + stop_after_epochs: int = 20, + max_num_epochs: int = 2**31 - 1, + clip_max_norm: Optional[float] = 5.0, + calibration_kernel: Optional[Callable] = None, + resume_training: bool = False, + force_first_round_loss: bool = False, + discard_prior_samples: bool = False, + retrain_from_scratch: bool = False, + show_train_summary: bool = False, + dataloader_kwargs: Optional[dict] = None, + ) -> nn.Module: + r"""Return density estimator that approximates the distribution $p(\theta|x)$. + + Args: + training_batch_size: Training batch size. 
+ learning_rate: Learning rate for Adam optimizer. + validation_fraction: The fraction of data to use for validation. + stop_after_epochs: The number of epochs to wait for improvement on the + validation set before terminating training. + max_num_epochs: Maximum number of epochs to run. If reached, we stop + training even when the validation loss is still decreasing. Otherwise, + we train until validation loss increases (see also `stop_after_epochs`). + clip_max_norm: Value at which to clip the total gradient norm in order to + prevent exploding gradients. Use None for no clipping. + calibration_kernel: A function to calibrate the loss with respect to the + simulations `x`. See Lueckmann, Gonçalves et al., NeurIPS 2017. + resume_training: Can be used in case training time is limited, e.g. on a + cluster. If `True`, the split between train and validation set, the + optimizer, the number of epochs, and the best validation log-prob will + be restored from the last time `.train()` was called. + force_first_round_loss: If `True`, train with maximum likelihood, + i.e., potentially ignoring the correction for using a proposal + distribution different from the prior. + discard_prior_samples: Whether to discard samples simulated in round 1, i.e. + from the prior. Training may be sped up by ignoring such less targeted + samples. + retrain_from_scratch: Whether to retrain the conditional density + estimator for the posterior from scratch each round. + show_train_summary: Whether to print the number of epochs and validation + loss after the training. + dataloader_kwargs: Additional or updated kwargs to be passed to the training + and validation dataloaders (like, e.g., a collate_fn) + + Returns: + Density estimator that approximates the distribution $p(\theta|x)$. + """ + if self._round == 0 and self._neural_net is not None: + assert force_first_round_loss, ( + "You have already trained this neural network. 
After you had trained " + "the network, you again appended simulations with `append_simulations" + "(theta, x)`, but you did not provide a proposal. If the new " + "simulations are sampled from the prior, you can set " + "`.train(..., force_first_round_loss=True`). However, if the new " + "simulations were not sampled from the prior, you should pass the " + "proposal, i.e. `append_simulations(theta, x, proposal)`. If " + "your samples are not sampled from the prior and you do not pass a " + "proposal and you set `force_first_round_loss=True`, the result of " + "SNPE will not be the true posterior. Instead, it will be the proposal " + "posterior, which (usually) is more narrow than the true posterior." + ) + + # Calibration kernels proposed in Lueckmann, Gonçalves et al., 2017. + if calibration_kernel is None: + calibration_kernel = lambda x: ones([len(x)], device=self._device) + + # Starting index for the training set (1 = discard round-0 samples). + start_idx = int(discard_prior_samples and self._round > 0) + + # For non-atomic loss, we can not reuse samples from previous rounds as of now. + # SNPE-A can, by construction of the algorithm, only use samples from the last + # round. SNPE-A is the only algorithm that has an attribute `_ran_final_round`, + # so this is how we check for whether or not we are using SNPE-A. + if self.use_non_atomic_loss or hasattr(self, "_ran_final_round"): + start_idx = self._round + + # Set the proposal to the last proposal that was passed by the user. For + # atomic SNPE, it does not matter what the proposal is. For non-atomic + # SNPE, we only use the latest data that was passed, i.e. the one from the + # last proposal. 
+ proposal = self._proposal_roundwise[-1] + + train_loader, val_loader = self.get_dataloaders( + start_idx, + training_batch_size, + validation_fraction, + resume_training, + dataloader_kwargs=dataloader_kwargs, + ) + # First round or if retraining from scratch: + # Call the `self._build_neural_net` with the rounds' thetas and xs as + # arguments, which will build the neural network. + # This is passed into NeuralPosterior, to create a neural posterior which + # can `sample()` and `log_prob()`. The network is accessible via `.net`. + if self._neural_net is None or retrain_from_scratch: + + #Get theta,x from dataset to initialize NN + test_theta = self._dataset.datasets[0].tensors[0][:100] + test_x = self._dataset.datasets[0].tensors[1][:100] + + self._neural_net = self._build_neural_net( + test_theta, test_x + ) + # If data on training device already move net as well. + if ( + not self._device == "cpu" + and f"{test_x.device.type}:{test_x.device.index}" == self._device + ): + self._neural_net.to(self._device) + + test_posterior_net_for_multi_d_x(self._neural_net, test_theta, test_x) + self._x_shape = x_shape_from_simulation(test_x) + + # Move entire net to device for training. + self._neural_net.to(self._device) + + if not resume_training: + self.optimizer = optim.Adam( + list(self._neural_net.parameters()), lr=learning_rate + ) + self.epoch, self._val_log_prob = 0, float("-Inf") + + while self.epoch <= max_num_epochs and not self._converged( + self.epoch, stop_after_epochs + ): + + # Train for a single epoch. + self._neural_net.train() + train_log_probs_sum = 0 + epoch_start_time = time.time() + for batch in train_loader: + self.optimizer.zero_grad() + # Get batches on current device. 
+ theta_batch, x_batch, masks_batch = ( + batch[0].to(self._device), + batch[1].to(self._device), + batch[2].to(self._device), + ) + + train_losses = self._loss( + theta_batch, x_batch, masks_batch, proposal, calibration_kernel + ) + train_loss = torch.mean(train_losses) + train_log_probs_sum -= train_losses.sum().item() + + train_loss.backward() + if clip_max_norm is not None: + clip_grad_norm_( + self._neural_net.parameters(), max_norm=clip_max_norm + ) + self.optimizer.step() + + self.epoch += 1 + + train_log_prob_average = train_log_probs_sum / ( + len(train_loader) * train_loader.batch_size # type: ignore + ) + self._summary["train_log_probs"].append(train_log_prob_average) + + # Calculate validation performance. + self._neural_net.eval() + val_log_prob_sum = 0 + + with torch.no_grad(): + for batch in val_loader: + theta_batch, x_batch, masks_batch = ( + batch[0].to(self._device), + batch[1].to(self._device), + batch[2].to(self._device), + ) + # Take negative loss here to get validation log_prob. + val_losses = self._loss( + theta_batch, + x_batch, + masks_batch, + proposal, + calibration_kernel, + ) + val_log_prob_sum -= val_losses.sum().item() + + # Take mean over all validation samples. + self._val_log_prob = val_log_prob_sum / ( + len(val_loader) * val_loader.batch_size # type: ignore + ) + # Log validation log prob for every epoch. + self._summary["validation_log_probs"].append(self._val_log_prob) + self._summary["epoch_durations_sec"].append(time.time() - epoch_start_time) + + self._maybe_show_progress(self._show_progress_bars, self.epoch) + + self._report_convergence_at_end(self.epoch, stop_after_epochs, max_num_epochs) + + # Update summary. + self._summary["epochs"].append(self.epoch) + self._summary["best_validation_log_probs"].append(self._best_val_log_prob) + + # Update tensorboard and summary dict. + self._summarize(round_=self._round, x_o=None, theta_bank=None, x_bank=None) + + # Update description for progress bar. 
+ if show_train_summary: + print(self._describe_round(self._round, self._summary)) + + # Avoid keeping the gradients in the resulting network, which can + # cause memory leakage when benchmarking. + self._neural_net.zero_grad(set_to_none=True) + + return deepcopy(self._neural_net) + + def build_posterior( + self, + density_estimator: Optional[nn.Module] = None, + prior: Optional[Distribution] = None, + sample_with: str = "rejection", + mcmc_method: str = "slice_np", + vi_method: str = "rKL", + mcmc_parameters: Dict[str, Any] = {}, + vi_parameters: Dict[str, Any] = {}, + rejection_sampling_parameters: Dict[str, Any] = {}, + ) -> Union[MCMCPosterior, RejectionPosterior, VIPosterior, DirectPosterior]: + r"""Build posterior from the neural density estimator. + + For SNPE, the posterior distribution that is returned here implements the + following functionality over the raw neural density estimator: + - correct the calculation of the log probability such that it compensates for + the leakage. + - reject samples that lie outside of the prior bounds. + - alternatively, if leakage is very high (which can happen for multi-round + SNPE), sample from the posterior with MCMC. + + Args: + density_estimator: The density estimator that the posterior is based on. + If `None`, use the latest neural density estimator that was trained. + prior: Prior distribution. + sample_with: Method to use for sampling from the posterior. Must be one of + [`mcmc` | `rejection` | `vi`]. + mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`, + `hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy + implementation of slice sampling; select `hmc`, `nuts` or `slice` for + Pyro-based sampling. + vi_method: Method used for VI, one of [`rKL`, `fKL`, `IW`, `alpha`]. Note + some of the methods admit a `mode seeking` property (e.g. rKL) whereas + some admit a `mass covering` one (e.g fKL). + mcmc_parameters: Additional kwargs passed to `MCMCPosterior`. 
+ vi_parameters: Additional kwargs passed to `VIPosterior`. + rejection_sampling_parameters: Additional kwargs passed to + `RejectionPosterior` or `DirectPosterior`. By default, + `DirectPosterior` is used. Only if `rejection_sampling_parameters` + contains `proposal`, a `RejectionPosterior` is instantiated. + + Returns: + Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods + (the returned log-probability is unnormalized). + """ + if prior is None: + assert self._prior is not None, ( + "You did not pass a prior. You have to pass the prior either at " + "initialization `inference = SNPE(prior)` or to " + "`.build_posterior(prior=prior)`." + ) + prior = self._prior + else: + utils.check_prior(prior) + + if density_estimator is None: + posterior_estimator = self._neural_net + # If internal net is used device is defined. + device = self._device + else: + posterior_estimator = density_estimator + # Otherwise, infer it from the device of the net parameters. + device = next(density_estimator.parameters()).device.type + + potential_fn, theta_transform = posterior_estimator_based_potential( + posterior_estimator=posterior_estimator, prior=prior, x_o=None + ) + + if sample_with == "rejection": + if "proposal" in rejection_sampling_parameters.keys(): + self._posterior = RejectionPosterior( + potential_fn=potential_fn, + device=device, + x_shape=self._x_shape, + **rejection_sampling_parameters, + ) + else: + self._posterior = DirectPosterior( + posterior_estimator=posterior_estimator, + prior=prior, + x_shape=self._x_shape, + device=device, + ) + elif sample_with == "mcmc": + self._posterior = MCMCPosterior( + potential_fn=potential_fn, + theta_transform=theta_transform, + proposal=prior, + method=mcmc_method, + device=device, + x_shape=self._x_shape, + **mcmc_parameters, + ) + elif sample_with == "vi": + self._posterior = VIPosterior( + potential_fn=potential_fn, + theta_transform=theta_transform, + prior=prior, # type: ignore + vi_method=vi_method, + 
device=device, + x_shape=self._x_shape, + **vi_parameters, + ) + else: + raise NotImplementedError + + # Store models at end of each round. + self._model_bank.append(deepcopy(self._posterior)) + + return deepcopy(self._posterior) + + @abstractmethod + def _log_prob_proposal_posterior( + self, + theta: Tensor, + x: Tensor, + masks: Tensor, + proposal: Optional[Any], + ) -> Tensor: + raise NotImplementedError + + def _loss( + self, + theta: Tensor, + x: Tensor, + masks: Tensor, + proposal: Optional[Any], + calibration_kernel: Callable, + ) -> Tensor: + """Return loss with proposal correction (`round_>0`) or without it (`round_=0`). + + The loss is the negative log prob. Irrespective of the round or SNPE method + (A, B, or C), it can be weighted with a calibration kernel. + + Returns: + Calibration kernel-weighted negative log prob. + """ + if self._round == 0: + # Use posterior log prob (without proposal correction) for first round. + log_prob = self._neural_net.log_prob(theta, x) + else: + log_prob = self._log_prob_proposal_posterior(theta, x, masks, proposal) + + return -(calibration_kernel(x) * log_prob) + + def _check_proposal(self, proposal): + """ + Check for validity of the provided proposal distribution. + + If the proposal is a `NeuralPosterior`, we check if the default_x is set. + If the proposal is **not** a `NeuralPosterior`, we warn since it is likely that + the user simply passed the prior, but this would still trigger atomic loss. + """ + if proposal is not None: + check_if_proposal_has_default_x(proposal) + + if isinstance(proposal, RestrictedPrior): + if proposal._prior is not self._prior: + warn( + "The proposal you passed is a `RestrictedPrior`, but the " + "proposal distribution it uses is not the prior (it can be " + "accessed via `RestrictedPrior._prior`). We do not " + "recommend to mix the `RestrictedPrior` with multi-round " + "SNPE." 
+ ) + elif ( + not isinstance(proposal, NeuralPosterior) + and proposal is not self._prior + ): + warn( + "The proposal you passed is neither the prior nor a " + "`NeuralPosterior` object. If you are an expert user and did so " + "for research purposes, this is fine. If not, you might be doing " + "something wrong: feel free to create an issue on Github." + ) + elif self._round > 0: + raise ValueError( + "A proposal was passed but no prior was passed at initialisation. When " + "running multi-round inference, a prior needs to be specified upon " + "initialisation. Potential fix: setting the `._prior` attribute or " + "re-initialisation. If the samples passed to `append_simulations()` " + "were sampled from the prior, single-round inference can be performed " + "with `append_simulations(..., proprosal=None)`." + ) diff --git a/sbi/inference/snpe/snpe_c.py b/sbi/inference/snpe/snpe_c.py index f09428ccb..b9cd11984 100644 --- a/sbi/inference/snpe/snpe_c.py +++ b/sbi/inference/snpe/snpe_c.py @@ -1,629 +1,625 @@ -# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed -# under the Affero General Public License v3, see . 
- - -from typing import Callable, Dict, Optional, Union - -import torch -from pyknos.mdn.mdn import MultivariateGaussianMDN as mdn -from pyknos.nflows.transforms import CompositeTransform -from torch import Tensor, eye, nn, ones -from torch.distributions import Distribution, MultivariateNormal, Uniform - -from sbi import utils as utils -from sbi.inference.posteriors.direct_posterior import DirectPosterior -from sbi.inference.snpe.snpe_base import PosteriorEstimator -from sbi.types import TensorboardSummaryWriter -from sbi.utils import ( - batched_mixture_mv, - batched_mixture_vmv, - check_dist_class, - clamp_and_warn, - del_entries, - repeat_rows, -) - - -class SNPE_C(PosteriorEstimator): - def __init__( - self, - prior: Optional[Distribution] = None, - density_estimator: Union[str, Callable] = "maf", - device: str = "cpu", - logging_level: Union[int, str] = "WARNING", - summary_writer: Optional[TensorboardSummaryWriter] = None, - show_progress_bars: bool = True, - ): - r"""SNPE-C / APT [1]. - - [1] _Automatic Posterior Transformation for Likelihood-free Inference_, - Greenberg et al., ICML 2019, https://arxiv.org/abs/1905.07488. - - This class implements two loss variants of SNPE-C: the non-atomic and the atomic - version. The atomic loss of SNPE-C can be used for any density estimator, - i.e. also for normalizing flows. However, it suffers from leakage issues. On - the other hand, the non-atomic loss can only be used only if the proposal - distribution is a mixture of Gaussians, the density estimator is a mixture of - Gaussians, and the prior is either Gaussian or Uniform. It does not suffer from - leakage issues. At the beginning of each round, we print whether the non-atomic - or the atomic version is used. - - In this codebase, we will automatically switch to the non-atomic loss if the - following criteria are fulfilled:
- - proposal is a `DirectPosterior` with density_estimator `mdn`, as built - with `utils.sbi.posterior_nn()`.
- - the density estimator is a `mdn`, as built with - `utils.sbi.posterior_nn()`.
- - `isinstance(prior, MultivariateNormal)` (from `torch.distributions`) or - `isinstance(prior, sbi.utils.BoxUniform)` - - Note that custom implementations of any of these densities (or estimators) will - not trigger the non-atomic loss, and the algorithm will fall back onto using - the atomic loss. - - Args: - prior: A probability distribution that expresses prior knowledge about the - parameters, e.g. which ranges are meaningful for them. - density_estimator: If it is a string, use a pre-configured network of the - provided type (one of nsf, maf, mdn, made). Alternatively, a function - that builds a custom neural network can be provided. The function will - be called with the first batch of simulations (theta, x), which can - thus be used for shape inference and potentially for z-scoring. It - needs to return a PyTorch `nn.Module` implementing the density - estimator. The density estimator needs to provide the methods - `.log_prob` and `.sample()`. - device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}". - logging_level: Minimum severity of messages to log. One of the strings - INFO, WARNING, DEBUG, ERROR and CRITICAL. - summary_writer: A tensorboard `SummaryWriter` to control, among others, log - file location (default is `/logs`.) - show_progress_bars: Whether to show a progressbar during training. 
- """ - - kwargs = del_entries(locals(), entries=("self", "__class__")) - super().__init__(**kwargs) - - def train( - self, - num_atoms: int = 10, - training_batch_size: int = 50, - learning_rate: float = 5e-4, - validation_fraction: float = 0.1, - stop_after_epochs: int = 20, - max_num_epochs: int = 2**31 - 1, - clip_max_norm: Optional[float] = 5.0, - calibration_kernel: Optional[Callable] = None, - exclude_invalid_x: bool = True, - resume_training: bool = False, - force_first_round_loss: bool = False, - discard_prior_samples: bool = False, - use_combined_loss: bool = False, - retrain_from_scratch: bool = False, - show_train_summary: bool = False, - dataloader_kwargs: Optional[Dict] = None, - warn_if_zscoring: Optional[bool] = True - ) -> nn.Module: - r"""Return density estimator that approximates the distribution $p(\theta|x)$. - - Args: - num_atoms: Number of atoms to use for classification. - training_batch_size: Training batch size. - learning_rate: Learning rate for Adam optimizer. - validation_fraction: The fraction of data to use for validation. - stop_after_epochs: The number of epochs to wait for improvement on the - validation set before terminating training. - max_num_epochs: Maximum number of epochs to run. If reached, we stop - training even when the validation loss is still decreasing. Otherwise, - we train until validation loss increases (see also `stop_after_epochs`). - clip_max_norm: Value at which to clip the total gradient norm in order to - prevent exploding gradients. Use None for no clipping. - calibration_kernel: A function to calibrate the loss with respect to the - simulations `x`. See Lueckmann, Gonçalves et al., NeurIPS 2017. - exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞` - during training. Expect errors, silent or explicit, when `False`. - resume_training: Can be used in case training time is limited, e.g. on a - cluster. 
If `True`, the split between train and validation set, the - optimizer, the number of epochs, and the best validation log-prob will - be restored from the last time `.train()` was called. - force_first_round_loss: If `True`, train with maximum likelihood, - i.e., potentially ignoring the correction for using a proposal - distribution different from the prior. - discard_prior_samples: Whether to discard samples simulated in round 1, i.e. - from the prior. Training may be sped up by ignoring such less targeted - samples. - use_combined_loss: Whether to train the neural net also on prior samples - using maximum likelihood in addition to training it on all samples using - atomic loss. The extra MLE loss helps prevent density leaking with - bounded priors. - retrain_from_scratch: Whether to retrain the conditional density - estimator for the posterior from scratch each round. - show_train_summary: Whether to print the number of epochs and validation - loss and leakage after the training. - dataloader_kwargs: Additional or updated kwargs to be passed to the training - and validation dataloaders (like, e.g., a collate_fn) - - Returns: - Density estimator that approximates the distribution $p(\theta|x)$. - """ - - # WARNING: sneaky trick ahead. We proxy the parent's `train` here, - # requiring the signature to have `num_atoms`, save it for use below, and - # continue. It's sneaky because we are using the object (self) as a namespace - # to pass arguments between functions, and that's implicit state management. - self._num_atoms = num_atoms - self._use_combined_loss = use_combined_loss - kwargs = del_entries( - locals(), entries=("self", "__class__", "num_atoms", "use_combined_loss") - ) - - self._round = max(self._data_round_index) - - if self._round > 0: - # Set the proposal to the last proposal that was passed by the user. For - # atomic SNPE, it does not matter what the proposal is. For non-atomic - # SNPE, we only use the latest data that was passed, i.e. 
the one from the - # last proposal. - proposal = self._proposal_roundwise[-1] - self.use_non_atomic_loss = ( - isinstance(proposal.posterior_estimator._distribution, mdn) - and isinstance(self._neural_net._distribution, mdn) - and check_dist_class( - self._prior, class_to_check=(Uniform, MultivariateNormal) - )[0] - ) - - algorithm = "non-atomic" if self.use_non_atomic_loss else "atomic" - print(f"Using SNPE-C with {algorithm} loss") - - if self.use_non_atomic_loss: - # Take care of z-scoring, pre-compute and store prior terms. - self._set_state_for_mog_proposal() - - return super().train(**kwargs) - - def _set_state_for_mog_proposal(self) -> None: - """Set state variables that are used at each training step of non-atomic SNPE-C. - - Three things are computed: - 1) Check if z-scoring was requested. To do so, we check if the `_transform` - argument of the net had been a `CompositeTransform`. See pyknos mdn.py. - 2) Define a (potentially standardized) prior. It's standardized if z-scoring - had been requested. - 3) Compute (Precision * mean) for the prior. This quantity is used at every - training step if the prior is Gaussian. - """ - - self.z_score_theta = isinstance(self._neural_net._transform, CompositeTransform) - - self._set_maybe_z_scored_prior() - - if isinstance(self._maybe_z_scored_prior, MultivariateNormal): - self.prec_m_prod_prior = torch.mv( - self._maybe_z_scored_prior.precision_matrix, # type: ignore - self._maybe_z_scored_prior.loc, # type: ignore - ) - - def _set_maybe_z_scored_prior(self) -> None: - r"""Compute and store potentially standardized prior (if z-scoring was done). - - The proposal posterior is: - $pp(\theta|x) = 1/Z * q(\theta|x) * prop(\theta) / p(\theta)$ - - Let's denote z-scored theta by `a`: a = (theta - mean) / std - Then pp'(a|x) = 1/Z_2 * q'(a|x) * prop'(a) / p'(a)$ - - The ' indicates that the evaluation occurs in standardized space. The constant - scaling factor has been absorbed into Z_2. 
- From the above equation, we see that we need to evaluate the prior **in - standardized space**. We build the standardized prior in this function. - - The standardize transform that is applied to the samples theta does not use - the exact prior mean and std (due to implementation issues). Hence, the z-scored - prior will not be exactly have mean=0 and std=1. - """ - - if self.z_score_theta: - scale = self._neural_net._transform._transforms[0]._scale - shift = self._neural_net._transform._transforms[0]._shift - - # Following the definintion of the linear transform in - # `standardizing_transform` in `sbiutils.py`: - # shift=-mean / std - # scale=1 / std - # Solving these equations for mean and std: - estim_prior_std = 1 / scale - estim_prior_mean = -shift * estim_prior_std - - # Compute the discrepancy of the true prior mean and std and the mean and - # std that was empirically estimated from samples. - # N(theta|m,s) = N((theta-m_e)/s_e|(m-m_e)/s_e, s/s_e) - # Above: m,s are true prior mean and std. m_e,s_e are estimated prior mean - # and std (estimated from samples and used to build standardize transform). - almost_zero_mean = (self._prior.mean - estim_prior_mean) / estim_prior_std - almost_one_std = torch.sqrt(self._prior.variance) / estim_prior_std - - if isinstance(self._prior, MultivariateNormal): - self._maybe_z_scored_prior = MultivariateNormal( - almost_zero_mean, torch.diag(almost_one_std) - ) - else: - range_ = torch.sqrt(almost_one_std * 3.0) - self._maybe_z_scored_prior = utils.BoxUniform( - almost_zero_mean - range_, almost_zero_mean + range_ - ) - else: - self._maybe_z_scored_prior = self._prior - - def _log_prob_proposal_posterior( - self, - theta: Tensor, - x: Tensor, - masks: Tensor, - proposal: DirectPosterior, - ) -> Tensor: - """Return the log-probability of the proposal posterior. - - If the proposal is a MoG, the density estimator is a MoG, and the prior is - either Gaussian or uniform, we use non-atomic loss. 
Else, use atomic loss (which - suffers from leakage). - - Args: - theta: Batch of parameters θ. - x: Batch of data. - masks: Mask that is True for prior samples in the batch in order to train - them with prior loss. - proposal: Proposal distribution. - - Returns: Log-probability of the proposal posterior. - """ - - if self.use_non_atomic_loss: - return self._log_prob_proposal_posterior_mog(theta, x, proposal) - else: - return self._log_prob_proposal_posterior_atomic(theta, x, masks) - - def _log_prob_proposal_posterior_atomic( - self, theta: Tensor, x: Tensor, masks: Tensor - ): - """Return log probability of the proposal posterior for atomic proposals. - - We have two main options when evaluating the proposal posterior. - (1) Generate atoms from the proposal prior. - (2) Generate atoms from a more targeted distribution, such as the most - recent posterior. - If we choose the latter, it is likely beneficial not to do this in the first - round, since we would be sampling from a randomly-initialized neural density - estimator. - - Args: - theta: Batch of parameters θ. - x: Batch of data. - masks: Mask that is True for prior samples in the batch in order to train - them with prior loss. - - Returns: - Log-probability of the proposal posterior. - """ - - batch_size = theta.shape[0] - - num_atoms = int( - clamp_and_warn("num_atoms", self._num_atoms, min_val=2, max_val=batch_size) - ) - - # Each set of parameter atoms is evaluated using the same x, - # so we repeat rows of the data x, e.g. [1, 2] -> [1, 1, 2, 2] - repeated_x = repeat_rows(x, num_atoms) - - # To generate the full set of atoms for a given item in the batch, - # we sample without replacement num_atoms - 1 times from the rest - # of the theta in the batch. 
- probs = ones(batch_size, batch_size) * (1 - eye(batch_size)) / (batch_size - 1) - - choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False) - contrasting_theta = theta[choices] - - # We can now create our sets of atoms from the contrasting parameter sets - # we have generated. - atomic_theta = torch.cat((theta[:, None, :], contrasting_theta), dim=1).reshape( - batch_size * num_atoms, -1 - ) - - # Evaluate large batch giving (batch_size * num_atoms) log prob posterior evals. - log_prob_posterior = self._neural_net.log_prob(atomic_theta, repeated_x) - utils.assert_all_finite(log_prob_posterior, "posterior eval") - log_prob_posterior = log_prob_posterior.reshape(batch_size, num_atoms) - - # Get (batch_size * num_atoms) log prob prior evals. - log_prob_prior = self._prior.log_prob(atomic_theta) - log_prob_prior = log_prob_prior.reshape(batch_size, num_atoms) - utils.assert_all_finite(log_prob_prior, "prior eval") - - # Compute unnormalized proposal posterior. - unnormalized_log_prob = log_prob_posterior - log_prob_prior - - # Normalize proposal posterior across discrete set of atoms. - log_prob_proposal_posterior = unnormalized_log_prob[:, 0] - torch.logsumexp( - unnormalized_log_prob, dim=-1 - ) - utils.assert_all_finite(log_prob_proposal_posterior, "proposal posterior eval") - - # XXX This evaluates the posterior on _all_ prior samples - if self._use_combined_loss: - log_prob_posterior_non_atomic = self._neural_net.log_prob(theta, x) - masks = masks.reshape(-1) - log_prob_proposal_posterior = ( - masks * log_prob_posterior_non_atomic + log_prob_proposal_posterior - ) - - return log_prob_proposal_posterior - - def _log_prob_proposal_posterior_mog( - self, theta: Tensor, x: Tensor, proposal: DirectPosterior - ) -> Tensor: - """Return log-probability of the proposal posterior for MoG proposal. - - For MoG proposals and MoG density estimators, this can be done in closed form - and does not require atomic loss (i.e. 
there will be no leakage issues). - - Notation: - - m are mean vectors. - prec are precision matrices. - cov are covariance matrices. - - _p at the end indicates that it is the proposal. - _d indicates that it is the density estimator. - _pp indicates the proposal posterior. - - All tensors will have shapes (batch_dim, num_components, ...) - - Args: - theta: Batch of parameters θ. - x: Batch of data. - proposal: Proposal distribution. - - Returns: - Log-probability of the proposal posterior. - """ - - # Evaluate the proposal. MDNs do not have functionality to run the embedding_net - # and then get the mixture_components (**without** calling log_prob()). Hence, - # we call them separately here. - encoded_x = proposal.posterior_estimator._embedding_net(proposal.default_x) - dist = ( - proposal.posterior_estimator._distribution - ) # defined to avoid ugly black formatting. - logits_p, m_p, prec_p, _, _ = dist.get_mixture_components(encoded_x) - norm_logits_p = logits_p - torch.logsumexp(logits_p, dim=-1, keepdim=True) - - # Evaluate the density estimator. - encoded_x = self._neural_net._embedding_net(x) - dist = self._neural_net._distribution # defined to avoid black formatting. - logits_d, m_d, prec_d, _, _ = dist.get_mixture_components(encoded_x) - norm_logits_d = logits_d - torch.logsumexp(logits_d, dim=-1, keepdim=True) - - # z-score theta if it z-scoring had been requested. - theta = self._maybe_z_score_theta(theta) - - # Compute the MoG parameters of the proposal posterior. - logits_pp, m_pp, prec_pp, cov_pp = self._automatic_posterior_transformation( - norm_logits_p, m_p, prec_p, norm_logits_d, m_d, prec_d - ) - - # Compute the log_prob of theta under the product. 
- log_prob_proposal_posterior = utils.mog_log_prob( - theta, logits_pp, m_pp, prec_pp - ) - utils.assert_all_finite(log_prob_proposal_posterior, "proposal posterior eval") - - return log_prob_proposal_posterior - - def _automatic_posterior_transformation( - self, - logits_p: Tensor, - means_p: Tensor, - precisions_p: Tensor, - logits_d: Tensor, - means_d: Tensor, - precisions_d: Tensor, - ): - r"""Returns the MoG parameters of the proposal posterior. - - The proposal posterior is: - $pp(\theta|x) = 1/Z * q(\theta|x) * prop(\theta) / p(\theta)$ - In words: proposal posterior = posterior estimate * proposal / prior. - - If the posterior estimate and the proposal are MoG and the prior is either - Gaussian or uniform, we can solve this in closed-form. The is implemented in - this function. - - This function implements Appendix A1 from Greenberg et al. 2019. - - We have to build L*K components. How do we do this? - Example: proposal has two components, density estimator has three components. - Let's call the two components of the proposal i,j and the three components - of the density estimator x,y,z. We have to multiply every component of the - proposal with every component of the density estimator. So, what we do is: - 1) for the proposal, build: i,i,i,j,j,j. Done with torch.repeat_interleave() - 2) for the density estimator, build: x,y,z,x,y,z. Done with torch.repeat() - 3) Multiply them with simple matrix operations. - - Args: - logits_p: Component weight of each Gaussian of the proposal. - means_p: Mean of each Gaussian of the proposal. - precisions_p: Precision matrix of each Gaussian of the proposal. - logits_d: Component weight for each Gaussian of the density estimator. - means_d: Mean of each Gaussian of the density estimator. - precisions_d: Precision matrix of each Gaussian of the density estimator. - - Returns: (Component weight, mean, precision matrix, covariance matrix) of each - Gaussian of the proposal posterior. 
Has L*K terms (proposal has L terms, - density estimator has K terms). - """ - - precisions_pp, covariances_pp = self._precisions_proposal_posterior( - precisions_p, precisions_d - ) - - means_pp = self._means_proposal_posterior( - covariances_pp, means_p, precisions_p, means_d, precisions_d - ) - - logits_pp = self._logits_proposal_posterior( - means_pp, - precisions_pp, - covariances_pp, - logits_p, - means_p, - precisions_p, - logits_d, - means_d, - precisions_d, - ) - - return logits_pp, means_pp, precisions_pp, covariances_pp - - def _precisions_proposal_posterior( - self, precisions_p: Tensor, precisions_d: Tensor - ): - """Return the precisions and covariances of the proposal posterior. - - Args: - precisions_p: Precision matrices of the proposal distribution. - precisions_d: Precision matrices of the density estimator. - - Returns: (Precisions, Covariances) of the proposal posterior. L*K terms. - """ - - num_comps_p = precisions_p.shape[1] - num_comps_d = precisions_d.shape[1] - - precisions_p_rep = precisions_p.repeat_interleave(num_comps_d, dim=1) - precisions_d_rep = precisions_d.repeat(1, num_comps_p, 1, 1) - - precisions_pp = precisions_p_rep + precisions_d_rep - if isinstance(self._maybe_z_scored_prior, MultivariateNormal): - precisions_pp -= self._maybe_z_scored_prior.precision_matrix - - covariances_pp = torch.inverse(precisions_pp) - - return precisions_pp, covariances_pp - - def _means_proposal_posterior( - self, - covariances_pp: Tensor, - means_p: Tensor, - precisions_p: Tensor, - means_d: Tensor, - precisions_d: Tensor, - ): - """Return the means of the proposal posterior. - - means_pp = C_ix * (P_i * m_i + P_x * m_x - P_o * m_o). - - Args: - covariances_pp: Covariance matrices of the proposal posterior. - means_p: Means of the proposal distribution. - precisions_p: Precision matrices of the proposal distribution. - means_d: Means of the density estimator. - precisions_d: Precision matrices of the density estimator. 
- - Returns: Means of the proposal posterior. L*K terms. - """ - - num_comps_p = precisions_p.shape[1] - num_comps_d = precisions_d.shape[1] - - # First, compute the product P_i * m_i and P_j * m_j - prec_m_prod_p = batched_mixture_mv(precisions_p, means_p) - prec_m_prod_d = batched_mixture_mv(precisions_d, means_d) - - # Repeat them to allow for matrix operations: same trick as for the precisions. - prec_m_prod_p_rep = prec_m_prod_p.repeat_interleave(num_comps_d, dim=1) - prec_m_prod_d_rep = prec_m_prod_d.repeat(1, num_comps_p, 1) - - # Means = C_ij * (P_i * m_i + P_x * m_x - P_o * m_o). - summed_cov_m_prod_rep = prec_m_prod_p_rep + prec_m_prod_d_rep - if isinstance(self._maybe_z_scored_prior, MultivariateNormal): - summed_cov_m_prod_rep -= self.prec_m_prod_prior - - means_pp = batched_mixture_mv(covariances_pp, summed_cov_m_prod_rep) - - return means_pp - - @staticmethod - def _logits_proposal_posterior( - means_pp: Tensor, - precisions_pp: Tensor, - covariances_pp: Tensor, - logits_p: Tensor, - means_p: Tensor, - precisions_p: Tensor, - logits_d: Tensor, - means_d: Tensor, - precisions_d: Tensor, - ): - """Return the component weights (i.e. logits) of the proposal posterior. - - Args: - means_pp: Means of the proposal posterior. - precisions_pp: Precision matrices of the proposal posterior. - covariances_pp: Covariance matrices of the proposal posterior. - logits_p: Component weights (i.e. logits) of the proposal distribution. - means_p: Means of the proposal distribution. - precisions_p: Precision matrices of the proposal distribution. - logits_d: Component weights (i.e. logits) of the density estimator. - means_d: Means of the density estimator. - precisions_d: Precision matrices of the density estimator. - - Returns: Component weights of the proposal posterior. L*K terms. 
- """ - - num_comps_p = precisions_p.shape[1] - num_comps_d = precisions_d.shape[1] - - # Compute log(alpha_i * beta_j) - logits_p_rep = logits_p.repeat_interleave(num_comps_d, dim=1) - logits_d_rep = logits_d.repeat(1, num_comps_p) - logit_factors = logits_p_rep + logits_d_rep - - # Compute sqrt(det()/(det()*det())) - logdet_covariances_pp = torch.logdet(covariances_pp) - logdet_covariances_p = -torch.logdet(precisions_p) - logdet_covariances_d = -torch.logdet(precisions_d) - - # Repeat the proposal and density estimator terms such that there are LK terms. - # Same trick as has been used above. - logdet_covariances_p_rep = logdet_covariances_p.repeat_interleave( - num_comps_d, dim=1 - ) - logdet_covariances_d_rep = logdet_covariances_d.repeat(1, num_comps_p) - - log_sqrt_det_ratio = 0.5 * ( - logdet_covariances_pp - - (logdet_covariances_p_rep + logdet_covariances_d_rep) - ) - - # Compute for proposal, density estimator, and proposal posterior: - # mu_i.T * P_i * mu_i - exponent_p = batched_mixture_vmv(precisions_p, means_p) - exponent_d = batched_mixture_vmv(precisions_d, means_d) - exponent_pp = batched_mixture_vmv(precisions_pp, means_pp) - - # Extend proposal and density estimator exponents to get LK terms. - exponent_p_rep = exponent_p.repeat_interleave(num_comps_d, dim=1) - exponent_d_rep = exponent_d.repeat(1, num_comps_p) - exponent = -0.5 * (exponent_p_rep + exponent_d_rep - exponent_pp) - - logits_pp = logit_factors + log_sqrt_det_ratio + exponent - - return logits_pp - - def _maybe_z_score_theta(self, theta: Tensor) -> Tensor: - """Return potentially standardized theta if z-scoring was requested.""" - - if self.z_score_theta: - theta, _ = self._neural_net._transform(theta) - - return theta +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Affero General Public License v3, see . 
+ + +from typing import Callable, Dict, Optional, Union + +import torch +from pyknos.mdn.mdn import MultivariateGaussianMDN as mdn +from pyknos.nflows.transforms import CompositeTransform +from torch import Tensor, eye, nn, ones +from torch.distributions import Distribution, MultivariateNormal, Uniform + +from sbi import utils as utils +from sbi.inference.posteriors.direct_posterior import DirectPosterior +from sbi.inference.snpe.snpe_base import PosteriorEstimator +from sbi.types import TensorboardSummaryWriter +from sbi.utils import ( + batched_mixture_mv, + batched_mixture_vmv, + check_dist_class, + clamp_and_warn, + del_entries, + repeat_rows, +) + + +class SNPE_C(PosteriorEstimator): + def __init__( + self, + prior: Optional[Distribution] = None, + density_estimator: Union[str, Callable] = "maf", + device: str = "cpu", + logging_level: Union[int, str] = "WARNING", + summary_writer: Optional[TensorboardSummaryWriter] = None, + show_progress_bars: bool = True, + ): + r"""SNPE-C / APT [1]. + + [1] _Automatic Posterior Transformation for Likelihood-free Inference_, + Greenberg et al., ICML 2019, https://arxiv.org/abs/1905.07488. + + This class implements two loss variants of SNPE-C: the non-atomic and the atomic + version. The atomic loss of SNPE-C can be used for any density estimator, + i.e. also for normalizing flows. However, it suffers from leakage issues. On + the other hand, the non-atomic loss can only be used only if the proposal + distribution is a mixture of Gaussians, the density estimator is a mixture of + Gaussians, and the prior is either Gaussian or Uniform. It does not suffer from + leakage issues. At the beginning of each round, we print whether the non-atomic + or the atomic version is used. + + In this codebase, we will automatically switch to the non-atomic loss if the + following criteria are fulfilled:
+ - proposal is a `DirectPosterior` with density_estimator `mdn`, as built + with `utils.sbi.posterior_nn()`.
+ - the density estimator is a `mdn`, as built with + `utils.sbi.posterior_nn()`.
+ - `isinstance(prior, MultivariateNormal)` (from `torch.distributions`) or + `isinstance(prior, sbi.utils.BoxUniform)` + + Note that custom implementations of any of these densities (or estimators) will + not trigger the non-atomic loss, and the algorithm will fall back onto using + the atomic loss. + + Args: + prior: A probability distribution that expresses prior knowledge about the + parameters, e.g. which ranges are meaningful for them. + density_estimator: If it is a string, use a pre-configured network of the + provided type (one of nsf, maf, mdn, made). Alternatively, a function + that builds a custom neural network can be provided. The function will + be called with the first batch of simulations (theta, x), which can + thus be used for shape inference and potentially for z-scoring. It + needs to return a PyTorch `nn.Module` implementing the density + estimator. The density estimator needs to provide the methods + `.log_prob` and `.sample()`. + device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}". + logging_level: Minimum severity of messages to log. One of the strings + INFO, WARNING, DEBUG, ERROR and CRITICAL. + summary_writer: A tensorboard `SummaryWriter` to control, among others, log + file location (default is `/logs`.) + show_progress_bars: Whether to show a progressbar during training. 
+ """ + + kwargs = del_entries(locals(), entries=("self", "__class__")) + super().__init__(**kwargs) + + def train( + self, + num_atoms: int = 10, + training_batch_size: int = 50, + learning_rate: float = 5e-4, + validation_fraction: float = 0.1, + stop_after_epochs: int = 20, + max_num_epochs: int = 2**31 - 1, + clip_max_norm: Optional[float] = 5.0, + calibration_kernel: Optional[Callable] = None, + resume_training: bool = False, + force_first_round_loss: bool = False, + discard_prior_samples: bool = False, + use_combined_loss: bool = False, + retrain_from_scratch: bool = False, + show_train_summary: bool = False, + dataloader_kwargs: Optional[Dict] = None, + ) -> nn.Module: + r"""Return density estimator that approximates the distribution $p(\theta|x)$. + + Args: + num_atoms: Number of atoms to use for classification. + training_batch_size: Training batch size. + learning_rate: Learning rate for Adam optimizer. + validation_fraction: The fraction of data to use for validation. + stop_after_epochs: The number of epochs to wait for improvement on the + validation set before terminating training. + max_num_epochs: Maximum number of epochs to run. If reached, we stop + training even when the validation loss is still decreasing. Otherwise, + we train until validation loss increases (see also `stop_after_epochs`). + clip_max_norm: Value at which to clip the total gradient norm in order to + prevent exploding gradients. Use None for no clipping. + calibration_kernel: A function to calibrate the loss with respect to the + simulations `x`. See Lueckmann, Gonçalves et al., NeurIPS 2017. + resume_training: Can be used in case training time is limited, e.g. on a + cluster. If `True`, the split between train and validation set, the + optimizer, the number of epochs, and the best validation log-prob will + be restored from the last time `.train()` was called. 
+ force_first_round_loss: If `True`, train with maximum likelihood, + i.e., potentially ignoring the correction for using a proposal + distribution different from the prior. + discard_prior_samples: Whether to discard samples simulated in round 1, i.e. + from the prior. Training may be sped up by ignoring such less targeted + samples. + use_combined_loss: Whether to train the neural net also on prior samples + using maximum likelihood in addition to training it on all samples using + atomic loss. The extra MLE loss helps prevent density leaking with + bounded priors. + retrain_from_scratch: Whether to retrain the conditional density + estimator for the posterior from scratch each round. + show_train_summary: Whether to print the number of epochs and validation + loss and leakage after the training. + dataloader_kwargs: Additional or updated kwargs to be passed to the training + and validation dataloaders (like, e.g., a collate_fn) + + Returns: + Density estimator that approximates the distribution $p(\theta|x)$. + """ + + # WARNING: sneaky trick ahead. We proxy the parent's `train` here, + # requiring the signature to have `num_atoms`, save it for use below, and + # continue. It's sneaky because we are using the object (self) as a namespace + # to pass arguments between functions, and that's implicit state management. + self._num_atoms = num_atoms + self._use_combined_loss = use_combined_loss + kwargs = del_entries( + locals(), entries=("self", "__class__", "num_atoms", "use_combined_loss") + ) + + self._round = max(self._data_round_index) + + if self._round > 0: + # Set the proposal to the last proposal that was passed by the user. For + # atomic SNPE, it does not matter what the proposal is. For non-atomic + # SNPE, we only use the latest data that was passed, i.e. the one from the + # last proposal. 
+ proposal = self._proposal_roundwise[-1] + self.use_non_atomic_loss = ( + isinstance(proposal.posterior_estimator._distribution, mdn) + and isinstance(self._neural_net._distribution, mdn) + and check_dist_class( + self._prior, class_to_check=(Uniform, MultivariateNormal) + )[0] + ) + + algorithm = "non-atomic" if self.use_non_atomic_loss else "atomic" + print(f"Using SNPE-C with {algorithm} loss") + + if self.use_non_atomic_loss: + # Take care of z-scoring, pre-compute and store prior terms. + self._set_state_for_mog_proposal() + + return super().train(**kwargs) + + def _set_state_for_mog_proposal(self) -> None: + """Set state variables that are used at each training step of non-atomic SNPE-C. + + Three things are computed: + 1) Check if z-scoring was requested. To do so, we check if the `_transform` + argument of the net had been a `CompositeTransform`. See pyknos mdn.py. + 2) Define a (potentially standardized) prior. It's standardized if z-scoring + had been requested. + 3) Compute (Precision * mean) for the prior. This quantity is used at every + training step if the prior is Gaussian. + """ + + self.z_score_theta = isinstance(self._neural_net._transform, CompositeTransform) + + self._set_maybe_z_scored_prior() + + if isinstance(self._maybe_z_scored_prior, MultivariateNormal): + self.prec_m_prod_prior = torch.mv( + self._maybe_z_scored_prior.precision_matrix, # type: ignore + self._maybe_z_scored_prior.loc, # type: ignore + ) + + def _set_maybe_z_scored_prior(self) -> None: + r"""Compute and store potentially standardized prior (if z-scoring was done). + + The proposal posterior is: + $pp(\theta|x) = 1/Z * q(\theta|x) * prop(\theta) / p(\theta)$ + + Let's denote z-scored theta by `a`: a = (theta - mean) / std + Then pp'(a|x) = 1/Z_2 * q'(a|x) * prop'(a) / p'(a)$ + + The ' indicates that the evaluation occurs in standardized space. The constant + scaling factor has been absorbed into Z_2. 
+ From the above equation, we see that we need to evaluate the prior **in + standardized space**. We build the standardized prior in this function. + + The standardize transform that is applied to the samples theta does not use + the exact prior mean and std (due to implementation issues). Hence, the z-scored + prior will not be exactly have mean=0 and std=1. + """ + + if self.z_score_theta: + scale = self._neural_net._transform._transforms[0]._scale + shift = self._neural_net._transform._transforms[0]._shift + + # Following the definintion of the linear transform in + # `standardizing_transform` in `sbiutils.py`: + # shift=-mean / std + # scale=1 / std + # Solving these equations for mean and std: + estim_prior_std = 1 / scale + estim_prior_mean = -shift * estim_prior_std + + # Compute the discrepancy of the true prior mean and std and the mean and + # std that was empirically estimated from samples. + # N(theta|m,s) = N((theta-m_e)/s_e|(m-m_e)/s_e, s/s_e) + # Above: m,s are true prior mean and std. m_e,s_e are estimated prior mean + # and std (estimated from samples and used to build standardize transform). + almost_zero_mean = (self._prior.mean - estim_prior_mean) / estim_prior_std + almost_one_std = torch.sqrt(self._prior.variance) / estim_prior_std + + if isinstance(self._prior, MultivariateNormal): + self._maybe_z_scored_prior = MultivariateNormal( + almost_zero_mean, torch.diag(almost_one_std) + ) + else: + range_ = torch.sqrt(almost_one_std * 3.0) + self._maybe_z_scored_prior = utils.BoxUniform( + almost_zero_mean - range_, almost_zero_mean + range_ + ) + else: + self._maybe_z_scored_prior = self._prior + + def _log_prob_proposal_posterior( + self, + theta: Tensor, + x: Tensor, + masks: Tensor, + proposal: DirectPosterior, + ) -> Tensor: + """Return the log-probability of the proposal posterior. + + If the proposal is a MoG, the density estimator is a MoG, and the prior is + either Gaussian or uniform, we use non-atomic loss. 
Else, use atomic loss (which + suffers from leakage). + + Args: + theta: Batch of parameters θ. + x: Batch of data. + masks: Mask that is True for prior samples in the batch in order to train + them with prior loss. + proposal: Proposal distribution. + + Returns: Log-probability of the proposal posterior. + """ + + if self.use_non_atomic_loss: + return self._log_prob_proposal_posterior_mog(theta, x, proposal) + else: + return self._log_prob_proposal_posterior_atomic(theta, x, masks) + + def _log_prob_proposal_posterior_atomic( + self, theta: Tensor, x: Tensor, masks: Tensor + ): + """Return log probability of the proposal posterior for atomic proposals. + + We have two main options when evaluating the proposal posterior. + (1) Generate atoms from the proposal prior. + (2) Generate atoms from a more targeted distribution, such as the most + recent posterior. + If we choose the latter, it is likely beneficial not to do this in the first + round, since we would be sampling from a randomly-initialized neural density + estimator. + + Args: + theta: Batch of parameters θ. + x: Batch of data. + masks: Mask that is True for prior samples in the batch in order to train + them with prior loss. + + Returns: + Log-probability of the proposal posterior. + """ + + batch_size = theta.shape[0] + + num_atoms = int( + clamp_and_warn("num_atoms", self._num_atoms, min_val=2, max_val=batch_size) + ) + + # Each set of parameter atoms is evaluated using the same x, + # so we repeat rows of the data x, e.g. [1, 2] -> [1, 1, 2, 2] + repeated_x = repeat_rows(x, num_atoms) + + # To generate the full set of atoms for a given item in the batch, + # we sample without replacement num_atoms - 1 times from the rest + # of the theta in the batch. 
+ probs = ones(batch_size, batch_size) * (1 - eye(batch_size)) / (batch_size - 1) + + choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False) + contrasting_theta = theta[choices] + + # We can now create our sets of atoms from the contrasting parameter sets + # we have generated. + atomic_theta = torch.cat((theta[:, None, :], contrasting_theta), dim=1).reshape( + batch_size * num_atoms, -1 + ) + + # Evaluate large batch giving (batch_size * num_atoms) log prob posterior evals. + log_prob_posterior = self._neural_net.log_prob(atomic_theta, repeated_x) + utils.assert_all_finite(log_prob_posterior, "posterior eval") + log_prob_posterior = log_prob_posterior.reshape(batch_size, num_atoms) + + # Get (batch_size * num_atoms) log prob prior evals. + log_prob_prior = self._prior.log_prob(atomic_theta) + log_prob_prior = log_prob_prior.reshape(batch_size, num_atoms) + utils.assert_all_finite(log_prob_prior, "prior eval") + + # Compute unnormalized proposal posterior. + unnormalized_log_prob = log_prob_posterior - log_prob_prior + + # Normalize proposal posterior across discrete set of atoms. + log_prob_proposal_posterior = unnormalized_log_prob[:, 0] - torch.logsumexp( + unnormalized_log_prob, dim=-1 + ) + utils.assert_all_finite(log_prob_proposal_posterior, "proposal posterior eval") + + # XXX This evaluates the posterior on _all_ prior samples + if self._use_combined_loss: + log_prob_posterior_non_atomic = self._neural_net.log_prob(theta, x) + masks = masks.reshape(-1) + log_prob_proposal_posterior = ( + masks * log_prob_posterior_non_atomic + log_prob_proposal_posterior + ) + + return log_prob_proposal_posterior + + def _log_prob_proposal_posterior_mog( + self, theta: Tensor, x: Tensor, proposal: DirectPosterior + ) -> Tensor: + """Return log-probability of the proposal posterior for MoG proposal. + + For MoG proposals and MoG density estimators, this can be done in closed form + and does not require atomic loss (i.e. 
there will be no leakage issues). + + Notation: + + m are mean vectors. + prec are precision matrices. + cov are covariance matrices. + + _p at the end indicates that it is the proposal. + _d indicates that it is the density estimator. + _pp indicates the proposal posterior. + + All tensors will have shapes (batch_dim, num_components, ...) + + Args: + theta: Batch of parameters θ. + x: Batch of data. + proposal: Proposal distribution. + + Returns: + Log-probability of the proposal posterior. + """ + + # Evaluate the proposal. MDNs do not have functionality to run the embedding_net + # and then get the mixture_components (**without** calling log_prob()). Hence, + # we call them separately here. + encoded_x = proposal.posterior_estimator._embedding_net(proposal.default_x) + dist = ( + proposal.posterior_estimator._distribution + ) # defined to avoid ugly black formatting. + logits_p, m_p, prec_p, _, _ = dist.get_mixture_components(encoded_x) + norm_logits_p = logits_p - torch.logsumexp(logits_p, dim=-1, keepdim=True) + + # Evaluate the density estimator. + encoded_x = self._neural_net._embedding_net(x) + dist = self._neural_net._distribution # defined to avoid black formatting. + logits_d, m_d, prec_d, _, _ = dist.get_mixture_components(encoded_x) + norm_logits_d = logits_d - torch.logsumexp(logits_d, dim=-1, keepdim=True) + + # z-score theta if it z-scoring had been requested. + theta = self._maybe_z_score_theta(theta) + + # Compute the MoG parameters of the proposal posterior. + logits_pp, m_pp, prec_pp, cov_pp = self._automatic_posterior_transformation( + norm_logits_p, m_p, prec_p, norm_logits_d, m_d, prec_d + ) + + # Compute the log_prob of theta under the product. 
+ log_prob_proposal_posterior = utils.mog_log_prob( + theta, logits_pp, m_pp, prec_pp + ) + utils.assert_all_finite(log_prob_proposal_posterior, "proposal posterior eval") + + return log_prob_proposal_posterior + + def _automatic_posterior_transformation( + self, + logits_p: Tensor, + means_p: Tensor, + precisions_p: Tensor, + logits_d: Tensor, + means_d: Tensor, + precisions_d: Tensor, + ): + r"""Returns the MoG parameters of the proposal posterior. + + The proposal posterior is: + $pp(\theta|x) = 1/Z * q(\theta|x) * prop(\theta) / p(\theta)$ + In words: proposal posterior = posterior estimate * proposal / prior. + + If the posterior estimate and the proposal are MoG and the prior is either + Gaussian or uniform, we can solve this in closed-form. The is implemented in + this function. + + This function implements Appendix A1 from Greenberg et al. 2019. + + We have to build L*K components. How do we do this? + Example: proposal has two components, density estimator has three components. + Let's call the two components of the proposal i,j and the three components + of the density estimator x,y,z. We have to multiply every component of the + proposal with every component of the density estimator. So, what we do is: + 1) for the proposal, build: i,i,i,j,j,j. Done with torch.repeat_interleave() + 2) for the density estimator, build: x,y,z,x,y,z. Done with torch.repeat() + 3) Multiply them with simple matrix operations. + + Args: + logits_p: Component weight of each Gaussian of the proposal. + means_p: Mean of each Gaussian of the proposal. + precisions_p: Precision matrix of each Gaussian of the proposal. + logits_d: Component weight for each Gaussian of the density estimator. + means_d: Mean of each Gaussian of the density estimator. + precisions_d: Precision matrix of each Gaussian of the density estimator. + + Returns: (Component weight, mean, precision matrix, covariance matrix) of each + Gaussian of the proposal posterior. 
Has L*K terms (proposal has L terms, + density estimator has K terms). + """ + + precisions_pp, covariances_pp = self._precisions_proposal_posterior( + precisions_p, precisions_d + ) + + means_pp = self._means_proposal_posterior( + covariances_pp, means_p, precisions_p, means_d, precisions_d + ) + + logits_pp = self._logits_proposal_posterior( + means_pp, + precisions_pp, + covariances_pp, + logits_p, + means_p, + precisions_p, + logits_d, + means_d, + precisions_d, + ) + + return logits_pp, means_pp, precisions_pp, covariances_pp + + def _precisions_proposal_posterior( + self, precisions_p: Tensor, precisions_d: Tensor + ): + """Return the precisions and covariances of the proposal posterior. + + Args: + precisions_p: Precision matrices of the proposal distribution. + precisions_d: Precision matrices of the density estimator. + + Returns: (Precisions, Covariances) of the proposal posterior. L*K terms. + """ + + num_comps_p = precisions_p.shape[1] + num_comps_d = precisions_d.shape[1] + + precisions_p_rep = precisions_p.repeat_interleave(num_comps_d, dim=1) + precisions_d_rep = precisions_d.repeat(1, num_comps_p, 1, 1) + + precisions_pp = precisions_p_rep + precisions_d_rep + if isinstance(self._maybe_z_scored_prior, MultivariateNormal): + precisions_pp -= self._maybe_z_scored_prior.precision_matrix + + covariances_pp = torch.inverse(precisions_pp) + + return precisions_pp, covariances_pp + + def _means_proposal_posterior( + self, + covariances_pp: Tensor, + means_p: Tensor, + precisions_p: Tensor, + means_d: Tensor, + precisions_d: Tensor, + ): + """Return the means of the proposal posterior. + + means_pp = C_ix * (P_i * m_i + P_x * m_x - P_o * m_o). + + Args: + covariances_pp: Covariance matrices of the proposal posterior. + means_p: Means of the proposal distribution. + precisions_p: Precision matrices of the proposal distribution. + means_d: Means of the density estimator. + precisions_d: Precision matrices of the density estimator. 
+ + Returns: Means of the proposal posterior. L*K terms. + """ + + num_comps_p = precisions_p.shape[1] + num_comps_d = precisions_d.shape[1] + + # First, compute the product P_i * m_i and P_j * m_j + prec_m_prod_p = batched_mixture_mv(precisions_p, means_p) + prec_m_prod_d = batched_mixture_mv(precisions_d, means_d) + + # Repeat them to allow for matrix operations: same trick as for the precisions. + prec_m_prod_p_rep = prec_m_prod_p.repeat_interleave(num_comps_d, dim=1) + prec_m_prod_d_rep = prec_m_prod_d.repeat(1, num_comps_p, 1) + + # Means = C_ij * (P_i * m_i + P_x * m_x - P_o * m_o). + summed_cov_m_prod_rep = prec_m_prod_p_rep + prec_m_prod_d_rep + if isinstance(self._maybe_z_scored_prior, MultivariateNormal): + summed_cov_m_prod_rep -= self.prec_m_prod_prior + + means_pp = batched_mixture_mv(covariances_pp, summed_cov_m_prod_rep) + + return means_pp + + @staticmethod + def _logits_proposal_posterior( + means_pp: Tensor, + precisions_pp: Tensor, + covariances_pp: Tensor, + logits_p: Tensor, + means_p: Tensor, + precisions_p: Tensor, + logits_d: Tensor, + means_d: Tensor, + precisions_d: Tensor, + ): + """Return the component weights (i.e. logits) of the proposal posterior. + + Args: + means_pp: Means of the proposal posterior. + precisions_pp: Precision matrices of the proposal posterior. + covariances_pp: Covariance matrices of the proposal posterior. + logits_p: Component weights (i.e. logits) of the proposal distribution. + means_p: Means of the proposal distribution. + precisions_p: Precision matrices of the proposal distribution. + logits_d: Component weights (i.e. logits) of the density estimator. + means_d: Means of the density estimator. + precisions_d: Precision matrices of the density estimator. + + Returns: Component weights of the proposal posterior. L*K terms. 
+ """ + + num_comps_p = precisions_p.shape[1] + num_comps_d = precisions_d.shape[1] + + # Compute log(alpha_i * beta_j) + logits_p_rep = logits_p.repeat_interleave(num_comps_d, dim=1) + logits_d_rep = logits_d.repeat(1, num_comps_p) + logit_factors = logits_p_rep + logits_d_rep + + # Compute sqrt(det()/(det()*det())) + logdet_covariances_pp = torch.logdet(covariances_pp) + logdet_covariances_p = -torch.logdet(precisions_p) + logdet_covariances_d = -torch.logdet(precisions_d) + + # Repeat the proposal and density estimator terms such that there are LK terms. + # Same trick as has been used above. + logdet_covariances_p_rep = logdet_covariances_p.repeat_interleave( + num_comps_d, dim=1 + ) + logdet_covariances_d_rep = logdet_covariances_d.repeat(1, num_comps_p) + + log_sqrt_det_ratio = 0.5 * ( + logdet_covariances_pp + - (logdet_covariances_p_rep + logdet_covariances_d_rep) + ) + + # Compute for proposal, density estimator, and proposal posterior: + # mu_i.T * P_i * mu_i + exponent_p = batched_mixture_vmv(precisions_p, means_p) + exponent_d = batched_mixture_vmv(precisions_d, means_d) + exponent_pp = batched_mixture_vmv(precisions_pp, means_pp) + + # Extend proposal and density estimator exponents to get LK terms. 
+ exponent_p_rep = exponent_p.repeat_interleave(num_comps_d, dim=1) + exponent_d_rep = exponent_d.repeat(1, num_comps_p) + exponent = -0.5 * (exponent_p_rep + exponent_d_rep - exponent_pp) + + logits_pp = logit_factors + log_sqrt_det_ratio + exponent + + return logits_pp + + def _maybe_z_score_theta(self, theta: Tensor) -> Tensor: + """Return potentially standardized theta if z-scoring was requested.""" + + if self.z_score_theta: + theta, _ = self._neural_net._transform(theta) + + return theta diff --git a/sbi/inference/snre/snre_a.py b/sbi/inference/snre/snre_a.py index 1416abd1c..ce06a96e9 100644 --- a/sbi/inference/snre/snre_a.py +++ b/sbi/inference/snre/snre_a.py @@ -55,7 +55,6 @@ def train( stop_after_epochs: int = 20, max_num_epochs: int = 2**31 - 1, clip_max_norm: Optional[float] = 5.0, - exclude_invalid_x: bool = True, resume_training: bool = False, discard_prior_samples: bool = False, retrain_from_scratch: bool = False, @@ -75,8 +74,6 @@ def train( we train until validation loss increases (see also `stop_after_epochs`). clip_max_norm: Value at which to clip the total gradient norm in order to prevent exploding gradients. Use None for no clipping. - exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞` - during training. Expect errors, silent or explicit, when `False`. resume_training: Can be used in case training time is limited, e.g. on a cluster. 
If `True`, the split between train and validation set, the optimizer, the number of epochs, and the best validation log-prob will diff --git a/sbi/inference/snre/snre_b.py b/sbi/inference/snre/snre_b.py index d2db9212a..bb53cd09d 100644 --- a/sbi/inference/snre/snre_b.py +++ b/sbi/inference/snre/snre_b.py @@ -56,7 +56,6 @@ def train( stop_after_epochs: int = 20, max_num_epochs: int = 2**31 - 1, clip_max_norm: Optional[float] = 5.0, - exclude_invalid_x: bool = True, resume_training: bool = False, discard_prior_samples: bool = False, retrain_from_scratch: bool = False, @@ -77,8 +76,6 @@ def train( we train until validation loss increases (see also `stop_after_epochs`). clip_max_norm: Value at which to clip the total gradient norm in order to prevent exploding gradients. Use None for no clipping. - exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞` - during training. Expect errors, silent or explicit, when `False`. resume_training: Can be used in case training time is limited, e.g. on a cluster. If `True`, the split between train and validation set, the optimizer, the number of epochs, and the best validation log-prob will diff --git a/sbi/inference/snre/snre_base.py b/sbi/inference/snre/snre_base.py index 09898527e..698a57612 100644 --- a/sbi/inference/snre/snre_base.py +++ b/sbi/inference/snre/snre_base.py @@ -17,8 +17,11 @@ check_estimator_arg, check_prior, clamp_and_warn, + handle_invalid_x, validate_theta_and_x, x_shape_from_simulation, + warn_if_zscoring_changes_data, + warn_on_invalid_x, ) from sbi.utils.sbiutils import mask_sims_from_prior @@ -82,6 +85,11 @@ def append_simulations( theta: Tensor, x: Tensor, from_round: int = 0, + exclude_invalid_x: bool = True, + warn_on_invalid: bool = True, + warn_if_zscoring: bool = True, + return_self: bool = True, + data_device: str = None, ) -> "RatioEstimator": r"""Store parameters and simulation outputs to use them for later training. 
@@ -98,19 +106,48 @@ def append_simulations( With default settings, this is not used at all for `SNRE`. Only when the user later on requests `.train(discard_prior_samples=True)`, we use these indices to find which training data stemmed from the prior. - + exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞` + during training. Expect errors, silent or explicit, when `False`. + warn_on_invalid: Whether to warn if data is invalid + warn_if_zscoring: Whether to test if z-scoring causes duplicates + return_self: Whether to return an instance of the class, allows chaining + with `.train()`. Setting `False` decreases memory overhead. + data_device: Where to store the data, default is on the same device where + the training is happening. When training a large dataset on a GPU with little + VRAM, set this to 'cpu' to store data in system memory instead. Returns: NeuralInference object (returned so that this function is chainable). """ theta, x = validate_theta_and_x(theta, x, training_device=self._device) - self._theta_roundwise.append(theta) - self._x_roundwise.append(x) - self._prior_masks.append(mask_sims_from_prior(int(from_round), theta.size(0))) - self._data_round_index.append(int(from_round)) + is_valid_x, num_nans, num_infs = handle_invalid_x(x, exclude_invalid_x) + + # Check for problematic z-scoring + if warn_if_zscoring: + warn_if_zscoring_changes_data(x[is_valid_x]) + if warn_on_invalid: + warn_on_invalid_x(num_nans, num_infs, exclude_invalid_x) + + x = x[is_valid_x] + theta = theta[is_valid_x] + + if data_device is None: data_device = self._device + theta, x = validate_theta_and_x(theta, x, training_device=data_device) + prior_masks = mask_sims_from_prior(int(from_round), theta.size(0)) - return self + if self._dataset is None: + #If first round, set up ConcatDataset + self._dataset = data.ConcatDataset( [data.TensorDataset(theta,x,prior_masks),] ) + else: + #Otherwise append to Dataset + self._dataset = data.ConcatDataset( 
self._dataset.datasets + [data.TensorDataset(theta,x,prior_masks),] ) + + self._num_sims_per_round.append(theta.size(0)) + self._data_round_index.append(int(from_round) ) + + if return_self: + return self def train( self, @@ -154,15 +191,9 @@ def train( start_idx = int(discard_prior_samples and self._round > 0) # Load data from most recent round. self._round = max(self._data_round_index) - theta, x, _ = self.get_simulations( - start_idx, exclude_invalid_x, warn_on_invalid=True - ) - - # Dataset is shared for training and validation loaders. - dataset = data.TensorDataset(theta, x) train_loader, val_loader = self.get_dataloaders( - dataset, + start_idx, training_batch_size, validation_fraction, resume_training, @@ -183,11 +214,15 @@ def train( # This is passed into NeuralPosterior, to create a neural posterior which # can `sample()` and `log_prob()`. The network is accessible via `.net`. if self._neural_net is None or retrain_from_scratch: + + #Get theta,x from dataset to initialize NN + test_theta = self._dataset.datasets[0].tensors[0][:100] + test_x = self._dataset.datasets[0].tensors[1][:100] + self._neural_net = self._build_neural_net( - theta[self.train_indices], x[self.train_indices] + test_theta, test_x ) - self._x_shape = x_shape_from_simulation(x) - + self._x_shape = x_shape_from_simulation(test_x) self._neural_net.to(self._device) if not resume_training: @@ -260,8 +295,8 @@ def train( self._summarize( round_=self._round, x_o=None, - theta_bank=theta, - x_bank=x, + theta_bank=None, + x_bank=None, ) # Update description for progress bar. 
diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index bde9d7375..4f26cf848 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -348,6 +348,29 @@ def get_simulations_since_round( [t for t, r in zip(data, data_round_indices) if r >= starting_round_index] ) +def get_simulations_indcies( + num_sims_per_round: List, data_round_indices: List, starting_round_index: int +) -> Tensor: + """ + Returns indices for all simulations from rounds >= `starting_round_index`. Used in + `get_dataloaders` and `get_simulations`. + + Args: + num_sims_per_round: Number of simulations per round + data_round_indices: List with same length as data, each entry is an integer that + indicates which round the data is from. + starting_round_index: From which round onwards to return the data. We start + counting from 0. + """ + inds = [] + for j, (n,r) in enumerate(zip(num_sims_per_round,data_round_indices)): + + #Where to start counting + s_ind = sum(num_sims_per_round[:j]) + + if r >= starting_round_index: + inds.append(torch.arange(s_ind,s_ind + n) ) + return torch.cat(inds) def mask_sims_from_prior(round_: int, num_simulations: int) -> Tensor: """Returns Tensor True where simulated from prior parameters. diff --git a/tests/linearGaussian_snpe_test.py b/tests/linearGaussian_snpe_test.py index 2d84aa759..b43bd0144 100644 --- a/tests/linearGaussian_snpe_test.py +++ b/tests/linearGaussian_snpe_test.py @@ -1,583 +1,583 @@ -# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed -# under the Affero General Public License v3, see . 
- -from __future__ import annotations - -import numpy as np -import pytest -import torch -from scipy.stats import gaussian_kde -from torch import eye, ones, zeros -from torch.distributions import MultivariateNormal - -from sbi import analysis as analysis -from sbi import utils as utils -from sbi.analysis import ConditionedMDN, conditonal_potential -from sbi.inference import ( - SNPE_A, - SNPE_B, - SNPE_C, - DirectPosterior, - MCMCPosterior, - RejectionPosterior, - posterior_estimator_based_potential, - prepare_for_sbi, - simulate_for_sbi, -) -from sbi.simulators.linear_gaussian import ( - linear_gaussian, - samples_true_posterior_linear_gaussian_mvn_prior_different_dims, - samples_true_posterior_linear_gaussian_uniform_prior, - true_posterior_linear_gaussian_mvn_prior, -) -from tests.sbiutils_test import conditional_of_mvn -from tests.test_utils import ( - check_c2st, - get_dkl_gaussian_prior, - get_normalization_uniform_prior, - get_prob_outside_uniform_prior, -) - - -@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) -@pytest.mark.parametrize( - "num_dim, prior_str, num_trials", - ( - (2, "gaussian", 1), - (2, "uniform", 1), - (1, "gaussian", 1), - # no iid x in snpe. 
- pytest.param(1, "gaussian", 2, marks=pytest.mark.xfail), - pytest.param(2, "gaussian", 2, marks=pytest.mark.xfail), - ), -) -def test_c2st_snpe_on_linearGaussian( - snpe_method, num_dim: int, prior_str: str, num_trials: int -): - """Test whether SNPE infers well a simple example with available ground truth.""" - - x_o = zeros(num_trials, num_dim) - num_samples = 1000 - num_simulations = 2600 - - # likelihood_mean will be likelihood_shift+theta - likelihood_shift = -1.0 * ones(num_dim) - likelihood_cov = 0.3 * eye(num_dim) - - if prior_str == "gaussian": - prior_mean = zeros(num_dim) - prior_cov = eye(num_dim) - prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) - gt_posterior = true_posterior_linear_gaussian_mvn_prior( - x_o, likelihood_shift, likelihood_cov, prior_mean, prior_cov - ) - target_samples = gt_posterior.sample((num_samples,)) - else: - prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim)) - target_samples = samples_true_posterior_linear_gaussian_uniform_prior( - x_o, likelihood_shift, likelihood_cov, prior=prior, num_samples=num_samples - ) - - simulator, prior = prepare_for_sbi( - lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior - ) - - inference = snpe_method(prior, show_progress_bars=False) - - theta, x = simulate_for_sbi( - simulator, prior, num_simulations, simulation_batch_size=1000 - ) - posterior_estimator = inference.append_simulations(theta, x).train( - training_batch_size=100 - ) - posterior = DirectPosterior( - prior=prior, posterior_estimator=posterior_estimator - ).set_default_x(x_o) - samples = posterior.sample((num_samples,)) - - # Compute the c2st and assert it is near chance level of 0.5. - check_c2st(samples, target_samples, alg="snpe_c") - - map_ = posterior.map(num_init_samples=1_000, show_progress_bars=False) - - # Checks for log_prob() - if prior_str == "gaussian": - # For the Gaussian prior, we compute the KLd between ground truth and posterior. 
- dkl = get_dkl_gaussian_prior( - posterior, x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov - ) - - max_dkl = 0.15 - - assert ( - dkl < max_dkl - ), f"D-KL={dkl} is more than 2 stds above the average performance." - - assert ((map_ - gt_posterior.mean) ** 2).sum() < 0.5 - - elif prior_str == "uniform": - # Check whether the returned probability outside of the support is zero. - posterior_prob = get_prob_outside_uniform_prior(posterior, prior, num_dim) - assert ( - posterior_prob == 0.0 - ), "The posterior probability outside of the prior support is not zero" - - # Check whether normalization (i.e. scaling up the density due - # to leakage into regions without prior support) scales up the density by the - # correct factor. - ( - posterior_likelihood_unnorm, - posterior_likelihood_norm, - acceptance_prob, - ) = get_normalization_uniform_prior(posterior, prior, x=x_o) - # The acceptance probability should be *exactly* the ratio of the unnormalized - # and the normalized likelihood. However, we allow for an error margin of 1%, - # since the estimation of the acceptance probability is random (based on - # rejection sampling). - assert ( - acceptance_prob * 0.99 - < posterior_likelihood_unnorm / posterior_likelihood_norm - < acceptance_prob * 1.01 - ), "Normalizing the posterior density using the acceptance probability failed." - - assert ((map_ - ones(num_dim)) ** 2).sum() < 0.5 - - -def test_c2st_snpe_on_linearGaussian_different_dims(): - """Test whether SNPE B/C infer well a simple example with available ground truth. - - This example has different number of parameters theta than number of x. Also - this implicitly tests simulation_batch_size=1. It also impleictly tests whether the - prior can be `None` and whether we can stop and resume training. 
- - """ - - theta_dim = 3 - x_dim = 2 - discard_dims = theta_dim - x_dim - - x_o = zeros(1, x_dim) - num_samples = 1000 - - # likelihood_mean will be likelihood_shift+theta - likelihood_shift = -1.0 * ones(x_dim) - likelihood_cov = 0.3 * eye(x_dim) - - prior_mean = zeros(theta_dim) - prior_cov = eye(theta_dim) - prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) - target_samples = samples_true_posterior_linear_gaussian_mvn_prior_different_dims( - x_o, - likelihood_shift, - likelihood_cov, - prior_mean, - prior_cov, - num_discarded_dims=discard_dims, - num_samples=num_samples, - ) - - simulator, prior = prepare_for_sbi( - lambda theta: linear_gaussian( - theta, likelihood_shift, likelihood_cov, num_discarded_dims=discard_dims - ), - prior, - ) - # Test whether prior can be `None`. - inference = SNPE_C(prior=None, density_estimator="maf", show_progress_bars=False) - - # type: ignore - theta, x = simulate_for_sbi(simulator, prior, 2000, simulation_batch_size=1) - - inference = inference.append_simulations(theta, x) - posterior_estimator = inference.train( - max_num_epochs=10 - ) # Test whether we can stop and resume. - posterior_estimator = inference.train( - resume_training=True, force_first_round_loss=True - ) - posterior = DirectPosterior( - prior=prior, posterior_estimator=posterior_estimator - ).set_default_x(x_o) - samples = posterior.sample((num_samples,)) - - # Compute the c2st and assert it is near chance level of 0.5. - check_c2st(samples, target_samples, alg="snpe_c") - - -# Test multi-round SNPE. -@pytest.mark.slow -@pytest.mark.parametrize( - "method_str", - ( - "snpe_a", - pytest.param( - "snpe_b", - marks=pytest.mark.xfail( - raises=NotImplementedError, reason="""SNPE-B not implemented""" - ), - ), - "snpe_c", - "snpe_c_non_atomic", - ), -) -def test_c2st_multi_round_snpe_on_linearGaussian(method_str: str): - """Test whether SNPE B/C infer well a simple example with available ground truth. - . 
- """ - - num_dim = 2 - x_o = zeros((1, num_dim)) - num_samples = 1000 - - # likelihood_mean will be likelihood_shift+theta - likelihood_shift = -1.0 * ones(num_dim) - likelihood_cov = 0.3 * eye(num_dim) - - prior_mean = zeros(num_dim) - prior_cov = eye(num_dim) - prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) - - gt_posterior = true_posterior_linear_gaussian_mvn_prior( - x_o, likelihood_shift, likelihood_cov, prior_mean, prior_cov - ) - target_samples = gt_posterior.sample((num_samples,)) - - if method_str == "snpe_c_non_atomic": - # Test whether SNPE works properly with structured z-scoring. - density_estimator = utils.posterior_nn( - "mdn", z_score_x="structured", num_components=5 - ) - method_str = "snpe_c" - elif method_str == "snpe_a": - density_estimator = "mdn_snpe_a" - else: - density_estimator = "maf" - - simulator, prior = prepare_for_sbi( - lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior - ) - creation_args = dict( - prior=prior, - density_estimator=density_estimator, - show_progress_bars=False, - ) - - if method_str == "snpe_b": - inference = SNPE_B(**creation_args) - theta, x = simulate_for_sbi(simulator, prior, 500, simulation_batch_size=10) - posterior_estimator = inference.append_simulations(theta, x).train() - posterior1 = DirectPosterior( - prior=prior, posterior_estimator=posterior_estimator - ).set_default_x(x_o) - theta, x = simulate_for_sbi( - simulator, posterior1, 1000, simulation_batch_size=10 - ) - posterior_estimator = inference.append_simulations( - theta, x, proposal=posterior1 - ).train() - posterior = DirectPosterior( - prior=prior, posterior_estimator=posterior_estimator - ).set_default_x(x_o) - elif method_str == "snpe_c": - inference = SNPE_C(**creation_args) - theta, x = simulate_for_sbi(simulator, prior, 900, simulation_batch_size=50) - posterior_estimator = inference.append_simulations(theta, x).train() - posterior1 = DirectPosterior( - prior=prior, 
posterior_estimator=posterior_estimator - ).set_default_x(x_o) - theta = posterior1.sample((1000,)) - x = simulator(theta) - _ = inference.append_simulations(theta, x, proposal=posterior1).train() - posterior = inference.build_posterior().set_default_x(x_o) - elif method_str == "snpe_a": - inference = SNPE_A(**creation_args) - proposal = prior - final_round = False - num_rounds = 3 - for r in range(num_rounds): - if r == 2: - final_round = True - theta, x = simulate_for_sbi( - simulator, proposal, 500, simulation_batch_size=50 - ) - inference = inference.append_simulations(theta, x, proposal=proposal) - _ = inference.train(max_num_epochs=200, final_round=final_round) - posterior = inference.build_posterior().set_default_x(x_o) - proposal = posterior - - samples = posterior.sample((num_samples,)) - - # Compute the c2st and assert it is near chance level of 0.5. - check_c2st(samples, target_samples, alg=method_str) - - -# Testing rejection and mcmc sampling methods. -@pytest.mark.slow -@pytest.mark.parametrize( - "sample_with, mcmc_method, prior_str", - ( - ("mcmc", "slice_np", "gaussian"), - ("mcmc", "slice", "gaussian"), - # XXX (True, "slice", "uniform"), - # XXX takes very long. fix when refactoring pyro sampling - ("rejection", "rejection", "uniform"), - ), -) -def test_api_snpe_c_posterior_correction(sample_with, mcmc_method, prior_str): - """Test that leakage correction applied to sampling works, with both MCMC and - rejection. 
- - """ - - num_dim = 2 - x_o = zeros(1, num_dim) - - # likelihood_mean will be likelihood_shift+theta - likelihood_shift = -1.0 * ones(num_dim) - likelihood_cov = 0.3 * eye(num_dim) - - if prior_str == "gaussian": - prior_mean = zeros(num_dim) - prior_cov = eye(num_dim) - prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) - else: - prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim)) - - simulator, prior = prepare_for_sbi( - lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior - ) - inference = SNPE_C(prior, show_progress_bars=False) - - theta, x = simulate_for_sbi(simulator, prior, 1000) - posterior_estimator = inference.append_simulations(theta, x).train() - potential_fn, theta_transform = posterior_estimator_based_potential( - posterior_estimator, prior, x_o - ) - if sample_with == "mcmc": - posterior = MCMCPosterior( - potential_fn=potential_fn, - theta_transform=theta_transform, - proposal=prior, - method=mcmc_method, - ) - elif sample_with == "rejection": - posterior = RejectionPosterior( - potential_fn=potential_fn, - proposal=prior, - theta_transform=theta_transform, - ) - - # Posterior should be corrected for leakage even if num_rounds just 1. - samples = posterior.sample((10,)) - - # Evaluate the samples to check correction factor. - _ = posterior.log_prob(samples) - - -@pytest.mark.slow -def test_sample_conditional(): - """ - Test whether sampling from the conditional gives the same results as evaluating. - - This compares samples that get smoothed with a Gaussian kde to evaluating the - conditional log-probability with `eval_conditional_density`. - - `eval_conditional_density` is itself tested in `sbiutils_test.py`. Here, we use - a bimodal posterior to test the conditional. 
- """ - - num_dim = 3 - dim_to_sample_1 = 0 - dim_to_sample_2 = 2 - - x_o = zeros(1, num_dim) - - likelihood_shift = -1.0 * ones(num_dim) - likelihood_cov = 0.1 * eye(num_dim) - - prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim)) - - def simulator(theta): - if torch.rand(1) > 0.5: - return linear_gaussian(theta, likelihood_shift, likelihood_cov) - else: - return linear_gaussian(theta, -likelihood_shift, likelihood_cov) - - # Test whether SNPE works properly with structured z-scoring. - net = utils.posterior_nn("maf", z_score_x="structured", hidden_features=20) - - simulator, prior = prepare_for_sbi(simulator, prior) - - inference = SNPE_C(prior, density_estimator=net, show_progress_bars=False) - - # We need a pretty big dataset to properly model the bimodality. - theta, x = simulate_for_sbi(simulator, prior, 10000) - posterior_estimator = inference.append_simulations(theta, x).train( - max_num_epochs=60 - ) - - posterior = DirectPosterior( - prior=prior, posterior_estimator=posterior_estimator - ).set_default_x(x_o) - samples = posterior.sample((50,)) - - # Evaluate the conditional density be drawing samples and smoothing with a Gaussian - # kde. 
- potential_fn, theta_transform = posterior_estimator_based_potential( - posterior_estimator, prior=prior, x_o=x_o - ) - (conditioned_potential_fn, restricted_tf, restricted_prior,) = conditonal_potential( - potential_fn=potential_fn, - theta_transform=theta_transform, - prior=prior, - condition=samples[0], - dims_to_sample=[dim_to_sample_1, dim_to_sample_2], - ) - mcmc_posterior = MCMCPosterior( - potential_fn=conditioned_potential_fn, - theta_transform=restricted_tf, - proposal=restricted_prior, - ) - cond_samples = mcmc_posterior.sample((500,)) - - _ = analysis.pairplot( - cond_samples, - limits=[[-2, 2], [-2, 2], [-2, 2]], - figsize=(2, 2), - diag="kde", - upper="kde", - ) - - limits = [[-2, 2], [-2, 2], [-2, 2]] - - density = gaussian_kde(cond_samples.numpy().T, bw_method="scott") - - X, Y = np.meshgrid( - np.linspace(limits[0][0], limits[0][1], 50), - np.linspace(limits[1][0], limits[1][1], 50), - ) - positions = np.vstack([X.ravel(), Y.ravel()]) - sample_kde_grid = np.reshape(density(positions).T, X.shape) - - # Evaluate the conditional with eval_conditional_density. - eval_grid = analysis.eval_conditional_density( - posterior, - condition=samples[0], - dim1=dim_to_sample_1, - dim2=dim_to_sample_2, - limits=torch.tensor([[-2, 2], [-2, 2], [-2, 2]]), - ) - - # Compare the two densities. - sample_kde_grid = sample_kde_grid / np.sum(sample_kde_grid) - eval_grid = eval_grid / torch.sum(eval_grid) - - error = np.abs(sample_kde_grid - eval_grid.numpy()) - - max_err = np.max(error) - assert max_err < 0.0027 - - -def test_mdn_conditional_density(num_dim: int = 3, cond_dim: int = 1): - """Test whether the conditional density infered from MDN parameters of a - `DirectPosterior` matches analytical results for MVN. This uses a n-D joint and - conditions on the last m values to generate a conditional. - - Gaussian prior used for easier ground truthing of conditional posterior. - - Args: - num_dim: Dimensionality of the MVM. - cond_dim: Dimensionality of the condition. 
- """ - - assert ( - num_dim > cond_dim - ), "The number of dimensions needs to be greater than that of the condition!" - - x_o = zeros(1, num_dim) - num_samples = 1000 - num_simulations = 2700 - condition = 0.1 * ones(1, num_dim) - - dims = list(range(num_dim)) - dims2sample = dims[-cond_dim:] - dims2condition = dims[:-cond_dim] - - # likelihood_mean will be likelihood_shift+theta - likelihood_shift = -1.0 * ones(num_dim) - likelihood_cov = 0.3 * eye(num_dim) - - prior_mean = zeros(num_dim) - prior_cov = eye(num_dim) - prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) - - joint_posterior = true_posterior_linear_gaussian_mvn_prior( - x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov - ) - joint_cov = joint_posterior.covariance_matrix - joint_mean = joint_posterior.loc - - conditional_mean, conditional_cov = conditional_of_mvn( - joint_mean, joint_cov, condition[0, dims2condition] - ) - conditional_dist_gt = MultivariateNormal(conditional_mean, conditional_cov) - - conditional_samples_gt = conditional_dist_gt.sample((num_samples,)) - - def simulator(theta): - return linear_gaussian(theta, likelihood_shift, likelihood_cov) - - simulator, prior = prepare_for_sbi(simulator, prior) - inference = SNPE_C(density_estimator="mdn", show_progress_bars=False) - - theta, x = simulate_for_sbi( - simulator, prior, num_simulations, simulation_batch_size=1000 - ) - posterior_mdn = inference.append_simulations(theta, x).train( - training_batch_size=100 - ) - conditioned_mdn = ConditionedMDN( - posterior_mdn, x_o, condition=condition, dims_to_sample=[0] - ) - conditional_samples_sbi = conditioned_mdn.sample((num_samples,)) - check_c2st( - conditional_samples_sbi, - conditional_samples_gt, - alg="analytic_mdn_conditioning_of_direct_posterior", - ) - - -@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) -def test_example_posterior(snpe_method: type): - """Return an inferred `NeuralPosterior` for interactive examination.""" - num_dim = 2 - x_o 
= zeros(1, num_dim) - - # likelihood_mean will be likelihood_shift+theta - likelihood_shift = -1.0 * ones(num_dim) - likelihood_cov = 0.3 * eye(num_dim) - - prior_mean = zeros(num_dim) - prior_cov = eye(num_dim) - prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) - - if snpe_method == SNPE_A: - extra_kwargs = dict(final_round=True) - else: - extra_kwargs = dict() - - simulator, prior = prepare_for_sbi( - lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior - ) - inference = snpe_method(prior, show_progress_bars=False) - - theta, x = simulate_for_sbi( - simulator, prior, 1000, simulation_batch_size=10, num_workers=6 - ) - posterior_estimator = inference.append_simulations(theta, x).train(**extra_kwargs) - if snpe_method == SNPE_A: - posterior_estimator = inference.correct_for_proposal() - posterior = DirectPosterior( - prior=prior, posterior_estimator=posterior_estimator - ).set_default_x(x_o) - assert posterior is not None +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Affero General Public License v3, see . 
+ +from __future__ import annotations + +import numpy as np +import pytest +import torch +from scipy.stats import gaussian_kde +from torch import eye, ones, zeros +from torch.distributions import MultivariateNormal + +from sbi import analysis as analysis +from sbi import utils as utils +from sbi.analysis import ConditionedMDN, conditonal_potential +from sbi.inference import ( + SNPE_A, + SNPE_B, + SNPE_C, + DirectPosterior, + MCMCPosterior, + RejectionPosterior, + posterior_estimator_based_potential, + prepare_for_sbi, + simulate_for_sbi, +) +from sbi.simulators.linear_gaussian import ( + linear_gaussian, + samples_true_posterior_linear_gaussian_mvn_prior_different_dims, + samples_true_posterior_linear_gaussian_uniform_prior, + true_posterior_linear_gaussian_mvn_prior, +) +from tests.sbiutils_test import conditional_of_mvn +from tests.test_utils import ( + check_c2st, + get_dkl_gaussian_prior, + get_normalization_uniform_prior, + get_prob_outside_uniform_prior, +) + + +@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) +@pytest.mark.parametrize( + "num_dim, prior_str, num_trials", + ( + (2, "gaussian", 1), + (2, "uniform", 1), + (1, "gaussian", 1), + # no iid x in snpe. 
+ pytest.param(1, "gaussian", 2, marks=pytest.mark.xfail), + pytest.param(2, "gaussian", 2, marks=pytest.mark.xfail), + ), +) +def test_c2st_snpe_on_linearGaussian( + snpe_method, num_dim: int, prior_str: str, num_trials: int +): + """Test whether SNPE infers well a simple example with available ground truth.""" + + x_o = zeros(num_trials, num_dim) + num_samples = 1000 + num_simulations = 2600 + + # likelihood_mean will be likelihood_shift+theta + likelihood_shift = -1.0 * ones(num_dim) + likelihood_cov = 0.3 * eye(num_dim) + + if prior_str == "gaussian": + prior_mean = zeros(num_dim) + prior_cov = eye(num_dim) + prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) + gt_posterior = true_posterior_linear_gaussian_mvn_prior( + x_o, likelihood_shift, likelihood_cov, prior_mean, prior_cov + ) + target_samples = gt_posterior.sample((num_samples,)) + else: + prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim)) + target_samples = samples_true_posterior_linear_gaussian_uniform_prior( + x_o, likelihood_shift, likelihood_cov, prior=prior, num_samples=num_samples + ) + + simulator, prior = prepare_for_sbi( + lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior + ) + + inference = snpe_method(prior, show_progress_bars=False) + + theta, x = simulate_for_sbi( + simulator, prior, num_simulations, simulation_batch_size=1000 + ) + posterior_estimator = inference.append_simulations(theta, x).train( + training_batch_size=100 + ) + posterior = DirectPosterior( + prior=prior, posterior_estimator=posterior_estimator + ).set_default_x(x_o) + samples = posterior.sample((num_samples,)) + + # Compute the c2st and assert it is near chance level of 0.5. + check_c2st(samples, target_samples, alg="snpe_c") + + map_ = posterior.map(num_init_samples=1_000, show_progress_bars=False) + + # Checks for log_prob() + if prior_str == "gaussian": + # For the Gaussian prior, we compute the KLd between ground truth and posterior. 
+ dkl = get_dkl_gaussian_prior( + posterior, x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov + ) + + max_dkl = 0.15 + + assert ( + dkl < max_dkl + ), f"D-KL={dkl} is more than 2 stds above the average performance." + + assert ((map_ - gt_posterior.mean) ** 2).sum() < 0.5 + + elif prior_str == "uniform": + # Check whether the returned probability outside of the support is zero. + posterior_prob = get_prob_outside_uniform_prior(posterior, prior, num_dim) + assert ( + posterior_prob == 0.0 + ), "The posterior probability outside of the prior support is not zero" + + # Check whether normalization (i.e. scaling up the density due + # to leakage into regions without prior support) scales up the density by the + # correct factor. + ( + posterior_likelihood_unnorm, + posterior_likelihood_norm, + acceptance_prob, + ) = get_normalization_uniform_prior(posterior, prior, x=x_o) + # The acceptance probability should be *exactly* the ratio of the unnormalized + # and the normalized likelihood. However, we allow for an error margin of 1%, + # since the estimation of the acceptance probability is random (based on + # rejection sampling). + assert ( + acceptance_prob * 0.99 + < posterior_likelihood_unnorm / posterior_likelihood_norm + < acceptance_prob * 1.01 + ), "Normalizing the posterior density using the acceptance probability failed." + + assert ((map_ - ones(num_dim)) ** 2).sum() < 0.5 + + +def test_c2st_snpe_on_linearGaussian_different_dims(): + """Test whether SNPE B/C infer well a simple example with available ground truth. + + This example has different number of parameters theta than number of x. Also + this implicitly tests simulation_batch_size=1. It also implicitly tests whether the + prior can be `None` and whether we can stop and resume training. 
+ + """ + + theta_dim = 3 + x_dim = 2 + discard_dims = theta_dim - x_dim + + x_o = zeros(1, x_dim) + num_samples = 1000 + + # likelihood_mean will be likelihood_shift+theta + likelihood_shift = -1.0 * ones(x_dim) + likelihood_cov = 0.3 * eye(x_dim) + + prior_mean = zeros(theta_dim) + prior_cov = eye(theta_dim) + prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) + target_samples = samples_true_posterior_linear_gaussian_mvn_prior_different_dims( + x_o, + likelihood_shift, + likelihood_cov, + prior_mean, + prior_cov, + num_discarded_dims=discard_dims, + num_samples=num_samples, + ) + + simulator, prior = prepare_for_sbi( + lambda theta: linear_gaussian( + theta, likelihood_shift, likelihood_cov, num_discarded_dims=discard_dims + ), + prior, + ) + # Test whether prior can be `None`. + inference = SNPE_C(prior=None, density_estimator="maf", show_progress_bars=False) + + # type: ignore + theta, x = simulate_for_sbi(simulator, prior, 2000, simulation_batch_size=1) + + inference = inference.append_simulations(theta, x) + posterior_estimator = inference.train( + max_num_epochs=10 + ) # Test whether we can stop and resume. + posterior_estimator = inference.train( + resume_training=True, force_first_round_loss=True + ) + posterior = DirectPosterior( + prior=prior, posterior_estimator=posterior_estimator + ).set_default_x(x_o) + samples = posterior.sample((num_samples,)) + + # Compute the c2st and assert it is near chance level of 0.5. + check_c2st(samples, target_samples, alg="snpe_c") + + +# Test multi-round SNPE. +@pytest.mark.slow +@pytest.mark.parametrize( + "method_str", + ( + "snpe_a", + pytest.param( + "snpe_b", + marks=pytest.mark.xfail( + raises=NotImplementedError, reason="""SNPE-B not implemented""" + ), + ), + "snpe_c", + "snpe_c_non_atomic", + ), +) +def test_c2st_multi_round_snpe_on_linearGaussian(method_str: str): + """Test whether SNPE B/C infer well a simple example with available ground truth. + . 
+ """ + + num_dim = 2 + x_o = zeros((1, num_dim)) + num_samples = 1000 + + # likelihood_mean will be likelihood_shift+theta + likelihood_shift = -1.0 * ones(num_dim) + likelihood_cov = 0.3 * eye(num_dim) + + prior_mean = zeros(num_dim) + prior_cov = eye(num_dim) + prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) + + gt_posterior = true_posterior_linear_gaussian_mvn_prior( + x_o, likelihood_shift, likelihood_cov, prior_mean, prior_cov + ) + target_samples = gt_posterior.sample((num_samples,)) + + if method_str == "snpe_c_non_atomic": + # Test whether SNPE works properly with structured z-scoring. + density_estimator = utils.posterior_nn( + "mdn", z_score_x="structured", num_components=5 + ) + method_str = "snpe_c" + elif method_str == "snpe_a": + density_estimator = "mdn_snpe_a" + else: + density_estimator = "maf" + + simulator, prior = prepare_for_sbi( + lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior + ) + creation_args = dict( + prior=prior, + density_estimator=density_estimator, + show_progress_bars=False, + ) + + if method_str == "snpe_b": + inference = SNPE_B(**creation_args) + theta, x = simulate_for_sbi(simulator, prior, 500, simulation_batch_size=10) + posterior_estimator = inference.append_simulations(theta, x).train() + posterior1 = DirectPosterior( + prior=prior, posterior_estimator=posterior_estimator + ).set_default_x(x_o) + theta, x = simulate_for_sbi( + simulator, posterior1, 1000, simulation_batch_size=10 + ) + posterior_estimator = inference.append_simulations( + theta, x, proposal=posterior1 + ).train() + posterior = DirectPosterior( + prior=prior, posterior_estimator=posterior_estimator + ).set_default_x(x_o) + elif method_str == "snpe_c": + inference = SNPE_C(**creation_args) + theta, x = simulate_for_sbi(simulator, prior, 900, simulation_batch_size=50) + posterior_estimator = inference.append_simulations(theta, x).train() + posterior1 = DirectPosterior( + prior=prior, 
posterior_estimator=posterior_estimator + ).set_default_x(x_o) + theta = posterior1.sample((1000,)) + x = simulator(theta) + _ = inference.append_simulations(theta, x, proposal=posterior1).train() + posterior = inference.build_posterior().set_default_x(x_o) + elif method_str == "snpe_a": + inference = SNPE_A(**creation_args) + proposal = prior + final_round = False + num_rounds = 3 + for r in range(num_rounds): + if r == 2: + final_round = True + theta, x = simulate_for_sbi( + simulator, proposal, 500, simulation_batch_size=50 + ) + inference = inference.append_simulations(theta, x, proposal=proposal) + _ = inference.train(max_num_epochs=200, final_round=final_round) + posterior = inference.build_posterior().set_default_x(x_o) + proposal = posterior + + samples = posterior.sample((num_samples,)) + + # Compute the c2st and assert it is near chance level of 0.5. + check_c2st(samples, target_samples, alg=method_str) + + +# Testing rejection and mcmc sampling methods. +@pytest.mark.slow +@pytest.mark.parametrize( + "sample_with, mcmc_method, prior_str", + ( + ("mcmc", "slice_np", "gaussian"), + ("mcmc", "slice", "gaussian"), + # XXX (True, "slice", "uniform"), + # XXX takes very long. fix when refactoring pyro sampling + ("rejection", "rejection", "uniform"), + ), +) +def test_api_snpe_c_posterior_correction(sample_with, mcmc_method, prior_str): + """Test that leakage correction applied to sampling works, with both MCMC and + rejection. 
+ + """ + + num_dim = 2 + x_o = zeros(1, num_dim) + + # likelihood_mean will be likelihood_shift+theta + likelihood_shift = -1.0 * ones(num_dim) + likelihood_cov = 0.3 * eye(num_dim) + + if prior_str == "gaussian": + prior_mean = zeros(num_dim) + prior_cov = eye(num_dim) + prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) + else: + prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim)) + + simulator, prior = prepare_for_sbi( + lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior + ) + inference = SNPE_C(prior, show_progress_bars=False) + + theta, x = simulate_for_sbi(simulator, prior, 1000) + posterior_estimator = inference.append_simulations(theta, x).train() + potential_fn, theta_transform = posterior_estimator_based_potential( + posterior_estimator, prior, x_o + ) + if sample_with == "mcmc": + posterior = MCMCPosterior( + potential_fn=potential_fn, + theta_transform=theta_transform, + proposal=prior, + method=mcmc_method, + ) + elif sample_with == "rejection": + posterior = RejectionPosterior( + potential_fn=potential_fn, + proposal=prior, + theta_transform=theta_transform, + ) + + # Posterior should be corrected for leakage even if num_rounds just 1. + samples = posterior.sample((10,)) + + # Evaluate the samples to check correction factor. + _ = posterior.log_prob(samples) + + +@pytest.mark.slow +def test_sample_conditional(): + """ + Test whether sampling from the conditional gives the same results as evaluating. + + This compares samples that get smoothed with a Gaussian kde to evaluating the + conditional log-probability with `eval_conditional_density`. + + `eval_conditional_density` is itself tested in `sbiutils_test.py`. Here, we use + a bimodal posterior to test the conditional. 
+ """ + + num_dim = 3 + dim_to_sample_1 = 0 + dim_to_sample_2 = 2 + + x_o = zeros(1, num_dim) + + likelihood_shift = -1.0 * ones(num_dim) + likelihood_cov = 0.1 * eye(num_dim) + + prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim)) + + def simulator(theta): + if torch.rand(1) > 0.5: + return linear_gaussian(theta, likelihood_shift, likelihood_cov) + else: + return linear_gaussian(theta, -likelihood_shift, likelihood_cov) + + # Test whether SNPE works properly with structured z-scoring. + net = utils.posterior_nn("maf", z_score_x="structured", hidden_features=20) + + simulator, prior = prepare_for_sbi(simulator, prior) + + inference = SNPE_C(prior, density_estimator=net, show_progress_bars=False) + + # We need a pretty big dataset to properly model the bimodality. + theta, x = simulate_for_sbi(simulator, prior, 10000) + posterior_estimator = inference.append_simulations(theta, x).train( + max_num_epochs=60 + ) + + posterior = DirectPosterior( + prior=prior, posterior_estimator=posterior_estimator + ).set_default_x(x_o) + samples = posterior.sample((50,)) + + # Evaluate the conditional density by drawing samples and smoothing with a Gaussian + # kde. 
+ potential_fn, theta_transform = posterior_estimator_based_potential( + posterior_estimator, prior=prior, x_o=x_o + ) + (conditioned_potential_fn, restricted_tf, restricted_prior,) = conditonal_potential( + potential_fn=potential_fn, + theta_transform=theta_transform, + prior=prior, + condition=samples[0], + dims_to_sample=[dim_to_sample_1, dim_to_sample_2], + ) + mcmc_posterior = MCMCPosterior( + potential_fn=conditioned_potential_fn, + theta_transform=restricted_tf, + proposal=restricted_prior, + ) + cond_samples = mcmc_posterior.sample((500,)) + + _ = analysis.pairplot( + cond_samples, + limits=[[-2, 2], [-2, 2], [-2, 2]], + figsize=(2, 2), + diag="kde", + upper="kde", + ) + + limits = [[-2, 2], [-2, 2], [-2, 2]] + + density = gaussian_kde(cond_samples.numpy().T, bw_method="scott") + + X, Y = np.meshgrid( + np.linspace(limits[0][0], limits[0][1], 50), + np.linspace(limits[1][0], limits[1][1], 50), + ) + positions = np.vstack([X.ravel(), Y.ravel()]) + sample_kde_grid = np.reshape(density(positions).T, X.shape) + + # Evaluate the conditional with eval_conditional_density. + eval_grid = analysis.eval_conditional_density( + posterior, + condition=samples[0], + dim1=dim_to_sample_1, + dim2=dim_to_sample_2, + limits=torch.tensor([[-2, 2], [-2, 2], [-2, 2]]), + ) + + # Compare the two densities. + sample_kde_grid = sample_kde_grid / np.sum(sample_kde_grid) + eval_grid = eval_grid / torch.sum(eval_grid) + + error = np.abs(sample_kde_grid - eval_grid.numpy()) + + max_err = np.max(error) + assert max_err < 0.0027 + + +def test_mdn_conditional_density(num_dim: int = 3, cond_dim: int = 1): + """Test whether the conditional density inferred from MDN parameters of a + `DirectPosterior` matches analytical results for MVN. This uses an n-D joint and + conditions on the last m values to generate a conditional. + + Gaussian prior used for easier ground truthing of conditional posterior. + + Args: + num_dim: Dimensionality of the MVN. + cond_dim: Dimensionality of the condition. 
+ """ + + assert ( + num_dim > cond_dim + ), "The number of dimensions needs to be greater than that of the condition!" + + x_o = zeros(1, num_dim) + num_samples = 1000 + num_simulations = 2700 + condition = 0.1 * ones(1, num_dim) + + dims = list(range(num_dim)) + dims2sample = dims[-cond_dim:] + dims2condition = dims[:-cond_dim] + + # likelihood_mean will be likelihood_shift+theta + likelihood_shift = -1.0 * ones(num_dim) + likelihood_cov = 0.3 * eye(num_dim) + + prior_mean = zeros(num_dim) + prior_cov = eye(num_dim) + prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) + + joint_posterior = true_posterior_linear_gaussian_mvn_prior( + x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov + ) + joint_cov = joint_posterior.covariance_matrix + joint_mean = joint_posterior.loc + + conditional_mean, conditional_cov = conditional_of_mvn( + joint_mean, joint_cov, condition[0, dims2condition] + ) + conditional_dist_gt = MultivariateNormal(conditional_mean, conditional_cov) + + conditional_samples_gt = conditional_dist_gt.sample((num_samples,)) + + def simulator(theta): + return linear_gaussian(theta, likelihood_shift, likelihood_cov) + + simulator, prior = prepare_for_sbi(simulator, prior) + inference = SNPE_C(density_estimator="mdn", show_progress_bars=False) + + theta, x = simulate_for_sbi( + simulator, prior, num_simulations, simulation_batch_size=1000 + ) + posterior_mdn = inference.append_simulations(theta, x).train( + training_batch_size=100 + ) + conditioned_mdn = ConditionedMDN( + posterior_mdn, x_o, condition=condition, dims_to_sample=[0] + ) + conditional_samples_sbi = conditioned_mdn.sample((num_samples,)) + check_c2st( + conditional_samples_sbi, + conditional_samples_gt, + alg="analytic_mdn_conditioning_of_direct_posterior", + ) + + +@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) +def test_example_posterior(snpe_method: type): + """Return an inferred `NeuralPosterior` for interactive examination.""" + num_dim = 2 + x_o 
= zeros(1, num_dim) + + # likelihood_mean will be likelihood_shift+theta + likelihood_shift = -1.0 * ones(num_dim) + likelihood_cov = 0.3 * eye(num_dim) + + prior_mean = zeros(num_dim) + prior_cov = eye(num_dim) + prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) + + if snpe_method == SNPE_A: + extra_kwargs = dict(final_round=True) + else: + extra_kwargs = dict() + + simulator, prior = prepare_for_sbi( + lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior + ) + inference = snpe_method(prior, show_progress_bars=False) + + theta, x = simulate_for_sbi( + simulator, prior, 1000, simulation_batch_size=10, num_workers=6 + ) + posterior_estimator = inference.append_simulations(theta, x).train(**extra_kwargs) + if snpe_method == SNPE_A: + posterior_estimator = inference.correct_for_proposal() + posterior = DirectPosterior( + prior=prior, posterior_estimator=posterior_estimator + ).set_default_x(x_o) + assert posterior is not None diff --git a/tutorials/07_conditional_distributions.ipynb b/tutorials/07_conditional_distributions.ipynb index f84014a88..2e46f8cd0 100644 --- a/tutorials/07_conditional_distributions.ipynb +++ b/tutorials/07_conditional_distributions.ipynb @@ -1,16588 +1,16588 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Analysing variability and compensation mechansims with conditional distributions\n", - "\n", - "A central advantage of `sbi` over parameter search methods such as genetic algorithms is that the posterior captures **all** models that can reproduce experimental data. This allows us to analyse whether parameters can be variable or have to be narrowly tuned, and to analyse compensation mechanisms between different parameters. See also [Marder and Taylor, 2011](https://www.nature.com/articles/nn.2735?page=2) for further motivation to identify all models that capture experimental data. 
\n", - "\n", - "In this tutorial, we will show how one can use the posterior distribution to identify whether parameters can be variable or have to be finely tuned, and how we can use the posterior to find potential compensation mechanisms between model parameters. To investigate this, we will extract **conditional distributions** from the posterior inferred with `sbi`." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Note, you can find the original version of this notebook at [https://github.com/mackelab/sbi/blob/main/tutorials/07_conditional_distributions.ipynb](https://github.com/mackelab/sbi/blob/main/tutorials/07_conditional_distributions.ipynb) in the `sbi` repository." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Main syntax" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sbi.analysis import conditional_pairplot, conditional_corrcoeff\n", - "\n", - "# Plot slices through posterior, i.e. conditionals.\n", - "_ = conditional_pairplot(\n", - " density=posterior,\n", - " condition=posterior.sample((1,)),\n", - " limits=torch.tensor([[-2., 2.], [-2., 2.]]),\n", - ")\n", - "\n", - "# Compute the matrix of correlation coefficients of the slices.\n", - "cond_coeff_mat = conditional_corrcoeff(\n", - " density=posterior,\n", - " condition=posterior.sample((1,)),\n", - " limits=torch.tensor([[-2., 2.], [-2., 2.]]),\n", - ")\n", - "plt.imshow(cond_coeff_mat, clim=[-1, 1])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Analysing variability and compensation mechanisms in a toy example\n", - "Below, we use a simple toy example to demonstrate the above described features. For an application of these features to a neuroscience problem, see figure 6 in [Gonçalves, Lueckmann, Deistler et al., 2019](https://arxiv.org/abs/1907.00770)." 
- ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from sbi import utils as utils\n", - "from sbi.analysis import pairplot, conditional_pairplot, conditional_corrcoeff\n", - "import torch\n", - "import numpy as np\n", - "\n", - "import matplotlib.pyplot as plt\n", - "from mpl_toolkits.mplot3d import Axes3D\n", - "from matplotlib import animation, rc\n", - "from IPython.display import HTML, Image\n", - "\n", - "_ = torch.manual_seed(0)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's say we have used SNPE to obtain a posterior distribution over three parameters. In this tutorial, we just load the posterior from a file:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from toy_posterior_for_07_cc import ExamplePosterior\n", - "posterior = ExamplePosterior()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "First, we specify the experimental observation $x_o$ at which we want to evaluate and sample the posterior $p(\\theta|x_o)$:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "x_o = torch.ones(1, 20) # simulator output was 20-dimensional\n", - "posterior.set_default_x(x_o)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As always, we can inspect the posterior marginals with the `pairplot()` function:" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "posterior_samples = posterior.sample((5000,))\n", - "\n", - "fig, ax = pairplot(\n", - " samples=posterior_samples,\n", - " limits=torch.tensor([[-2., 2.]]*3),\n", - " upper=['kde'],\n", - " diag=['kde'],\n", - " figsize=(5,5)\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The 1D and 2D marginals of the posterior fill almost the entire parameter space! Also, the Pearson correlation coefficient matrix of the marginal shows rather weak interactions (low correlations):" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAARsAAADxCAYAAAD7hRNxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWYklEQVR4nO3df6xcZZ3H8fen97Y0u6wWKCmI2EJokBpWkAYwRGUBtRBDu4rabpTiQroY3F1FDSKJbHBJym4iiz+xgQooCygi1ohhkR+LRmEpbKH8WOQKKq2VSvmhhFK4vd/94zxThunMnXM7Z849c+7nZU46c84z85yJuV+eX+f5KiIwM+u3aZN9A2Y2NTjYmFkpHGzMrBQONmZWCgcbMyuFg42ZlcLBxqymJK2WtFnSgx2uS9KXJY1IekDS25quLZf0WDqWF3E/DjZm9XUFsGic6ycC89OxAvgGgKQ9gfOBo4AjgfMl7dHrzTjYmNVURNwJPDNOkcXAVZG5C5glaV/gvcAtEfFMRDwL3ML4QSuX4V6/wMyK85f7z4ztL43lKrvt6VceAl5qOrUqIlZNoLr9gCeb3m9I5zqd74mDjVmFjL00xgHvn52r7P+t2vRSRCzs8y0Vxt0osyoRTJumXEcBNgL7N71/YzrX6XxPHGzMKkbKdxRgDXBqmpU6Gng+IjYBNwPvkbRHGhh+TzrXE3ejzCpEwLSCmgCSrgGOBWZL2kA2wzQdICIuBW4CTgJGgBeBj6Vrz0j6InBP+qoLImK8geZcHGzMqkQwNFxMsyUilnW5HsBZHa6tBlYXciOJg41ZxaimgxsONmYVIsG0ggZkqsbBxqxi3LIxs1IUNUBcNQ42ZhUiuWVjZiUZGvKYjZn1meRulJmVQqiYRxEqx8HGrErcsjGzsniA2Mz6TvIAcVtp+8DrgHnAb4APpZ29WsttB9ant7+LiJN7qdeszurajer1Z30OuDUi5gO3pvftbI2Iw9LhQGPWgQBNU65j0PQabBYDV6bXVwJLevw+s6ktDRDnOQZNr2M2c9JmOwB/AOZ0KDdT0lpgFFgZETe2KyRpBdku72hYR8yYNXWGlA7+iwMn+xZKs23W1sm+hVI9+sDvno6IvfOWr+lzmN2DjaSfAvu0uXRe85uICEnR4WvmRsRGSQcCt0laHxG/bi2UNmteBTBz7xkxb0m+vVjr4OYjLp/sWyjNyOK2aYxq6x1vOPO3ectmm2fVM9p0DTYRcUKna5Kek
rRvRGxKKSA2d/iOjenfxyXdARwO7BRszKa8AjfPqppee35rgEa2vOXAD1sLpH1Md0uvZwPHAA/3WK9ZLQkxTfmOQdNrsFkJvFvSY8AJ6T2SFkq6LJU5BFgr6X7gdrIxGwcbs3YKzq4gaZGkR1OK3Z1miyVdLGldOn4l6bmma9ubrq3p9af1NAIbEVuA49ucXwuckV7/Aji0l3rMpooix2wkDQFfA95NlmjuHklrmv9jHxGfair/j2RDHA1bI+KwQm4Gp3Ixq5xpmpbryOFIYCQiHo+Il4FryZardLIMuKaAn9CWg41ZlShfFypn6yd3Gl1Jc4EDgNuaTs+UtFbSXZKW7OIv2mHqLGQxGwAChodytwFmp/VrDRPN9d1sKXB9RGxvOpdryUpeDjZmFZJtnpU72DzdJdf3RNLoLqUlh1TRS1bcjTKrlEK7UfcA8yUdIGkGWUDZaVZJ0puBPYBfNp0rfMmKWzZmVVJg3qiIGJX0CbI83UPA6oh4SNIFwNqIaASepcC1KUNmwyHANyWNkTVKel6y4mBjViFFP64QETeR5fRuPveFlvf/0uZzhS9ZcbAxqxKJoaGhyb6LvnCwMauQKf0gppmVy8HGzPpOIu/q4IHjYGNWKfkfshw0DjZmFTOI20fk4WBjViESDA97NsrM+qyxeVYdOdiYVYk8G2VmJXGwMbO+E576NrMyyFPfZlYCIYaHpk/2bfRFIe21HDu47ybpunT9bknziqjXrG4EDGko1zFoeg42TTu4nwgsAJZJWtBS7HTg2Yg4CLgYuKjXes1qSWLatKFcx6ApomWTZwf3xcCV6fX1wPFSTRcTmPVomoZyHYOmiDGbdju4H9WpTNo97HlgL+DpAuo3qw2hiexBPFAqNUAsaQWwAmB498GL3Ga9ksT0oRmTfRt9UUSwybODe6PMBknDwOuBLa1flNJQrAKYufeMaL1uVn+q7TqbIn5Vnh3c1wDL0+tTgNtaNlc2MxqpXIobIM4xU3yapD825fQ+o+nackmPpWN562cnqueWTc4d3C8Hvi1pBHiGLCCZ2U5U2LR2nlzfyXUR8YmWz+4JnA8sBAK4N3322V29n0LGbLrt4B4RLwEfLKIuszor+HGFHTPFAJIaM8V5UrK8F7glIp5Jn70FWEQPucDr2Tk0G1gTWmczO+XibhwrWr4sb67vD0h6QNL1khrjr7nzhOdVqdkos6lugrNR3dLv5vEj4JqI2CbpH8jWwx3X43e25ZaNWYU0ulF5jhy6zhRHxJaI2JbeXgYckfezE+VgY1YlxT6u0HWmWNK+TW9PBh5Jr28G3pNyfu8BvCed22XuRplVigp7FCHnTPE/SToZGCWbKT4tffYZSV8kC1gAFzQGi3eVg41ZhWQZMYvrcOSYKT4XOLfDZ1cDq4u6Fwcbs0opbp1N1TjYmFWIJIb9bJSZ9Zv3IDazckgMDeDGWHk42JhVSNaycbAxs76r7xYTDjZmFSLE8DQPEJtZv0nI3SgzK4PHbMys74SYhoONmZXALRsz6zsV+CBm1TjYmFWKGJJno8yszySvszGzktS1G1VICO0lN42ZNZNzfXfSS24aM3st4UV94+klN42ZtfA6m87a5Zc5qk25D0h6J/Ar4FMR8WRrgZT3ZgXAfnvO4Y4jvlfA7Q2GY++dOjn8PnLwoZN9C5UlFftslKRFwCVkexBfFhErW66fDZxBtgfxH4G/j4jfpmvbgfWp6O8i4uRe7qWsYe8fAfMi4q+BW8hy0+wkIlZFxMKIWLjX7rNKujWzKiluzKZpiONEYAGwTNKClmL/CyxMf5vXA//WdG1rRByWjp4CDRQTbHrJTWNmr5GN2eQ5ctgxxBERLwONIY4dIuL2iHgxvb2L7O+3L4oINr3kpjGzJiIbs8lzUFz63YbTgZ80vZ+ZvvcuSUt6/W09j9n0kpvGzFpNaFFfEel3s1qljwALgXc1nZ4bERslHQjcJml9RPx6V+soZFFfL7lpzOxVBQ8Q50qhK
+kE4DzgXU3DHUTExvTv45LuAA4HdjnY1HNdtNkAE0O5jhzyDHEcDnwTODkiNjed30PSbun1bOAYelzO4scVzCqkyKe+cw5x/DuwO/A9SfDqFPchwDcljZE1Sla2Wag7IQ42ZpVS7BYTOYY4TujwuV8AhS6IcrAxqxjVdHTDwcascjTZN9AXDjZmFeI9iM2sRO5GmVkJ5G6UmfWfkLcFNbNyuGVjZn3nAWIzK427UWbWZ8IDxGZWCnkFsZmVxS0bMyuBWzZmVgLl3atm4DjYmFVINkDslo2ZlcCzUWbWfxL4cQUzK0NdWzaFhFBJqyVtlvRgh+uS9GVJI5IekPS2Iuo1q59snU2eI9e3SYskPZr+9j7X5vpukq5L1++WNK/p2rnp/KOS3tvrLyuqvXYFsGic6ycC89OxAvhGQfWa1U5R2RVypt89HXg2Ig4CLgYuSp9dQJaN4S1kf9tfV840nJ0UEmwi4k6y5HOdLAauisxdwKyWLJlmxquPK+T5Xw5d0++m91em19cDxytLs7AYuDYitkXEE8BI+r5dVtZIVK40oJJWNFKJbnnhuZJuzaxKNIGjkPS7O8pExCjwPLBXzs9OSKUGiCNiFbAK4K1z3xyTfDtm5Yt05FNY+t0ylNWyyZUG1MwCRb4jhzx/dzvKSBoGXg9syfnZCSkr2KwBTk2zUkcDz0fEppLqNhssYzmP7rqm303vl6fXpwC3RUSk80vTbNUBZJM7/9PDryqmGyXpGuBYsj7kBuB8YDpARFxKlpHvJLJBpheBjxVRr1kt5Wu15PiaXOl3Lwe+LWmEbJJnafrsQ5K+S5bfexQ4KyK293I/hQSbiFjW5XoAZxVRl1mtBajA0coc6XdfAj7Y4bMXAhcWdS+VGiA2MyYyQDxQHGzMqqagblTVONiYVU09Y42DjVmlBHmntQeOg41Z1dQz1jjYmFWOg42ZlcLdKDMrQ5HrbKrEwcasSib2IOZAcbAxq5SAsXpGGwcbswoR9e1G1XMbdzOrHLdszKrGs1Fm1nceIDazssgDxGZWinrGGgcbs0oJPPVtZmUIoqYDxJ76Nqua4jY870jSnpJukfRY+nePNmUOk/RLSQ+ltNkfbrp2haQnJK1Lx2Hd6nSwMauQCIixyHX06HPArRExH7g1vW/1InBqRDRS8P6HpFlN1z8bEYelY123Ct2NMquY2N5jsyWfxWQZUSBLv3sHcM5r7iPiV02vfy9pM7A38NyuVFhIy0bSakmbJT3Y4fqxkp5vanJ9oV05symvMUCc5+iefnc8c5pyt/0BmDNeYUlHAjOAXzedvjB1ry6WtFu3Cotq2VwBfBW4apwyP4uI9xVUn1lNTWiAeNz0u5J+CuzT5tJ5r6kxIqTOT2RJ2hf4NrA8IhrNrnPJgtQMspTZ5wAXjHezReWNulPSvCK+q2HbrK2MLG7bUKqljxx86GTfQmm+85/rJ/sWqq2gXlREnNDpmqSnJO0bEZtSMNncodzrgB8D50XEXU3f3WgVbZP0LeAz3e6nzAHit0u6X9JPJL2lXQFJKxpNwue2vFDirZlVR0TkOnrUnHZ3OfDD1gIpZe8PgKsi4vqWa/umfwUsAbq2DMoKNvcBcyPircBXgBvbFYqIVRGxMCIWztpr95JuzaxCJjZm04uVwLslPQackN4jaaGky1KZDwHvBE5rM8V9taT1wHpgNvCv3SosZTYqIv7U9PomSV+XNDsini6jfrNBUsZsVERsAY5vc34tcEZ6/R3gOx0+f9xE6ywl2EjaB3gqDUQdSdai2lJG3WaDJKKQNTSVVEiwkXQN2Zz9bEkbgPOB6QARcSlwCvBxSaPAVmBp1HVNtlmvSllmU76iZqOWdbn+VbKpcTProq7/HfYKYrMq8VPfZlaWkh5XKJ2DjVmVBB6zMbMyeDbKzMriAWIz67u0n00dOdiYVY2DjZn1W0R4NsrMyuFgY2b95zEbMyuHu1FmVoYAxhxszKzPIoKxV7ZP9m30hYONWcW4G
2Vm/VfjbpQzYppVSr5smL3OWOVJv5vKbW/af3hN0/kDJN0taUTSdWlz9HE52JhVSWTdqDxHj/Kk3wXY2pRi9+Sm8xcBF0fEQcCzwOndKnSwMauQAGJsLNfRo8VkaXdJ/y7J+8GUvuU4oJHeJdfnPWZjViURRP7ZqNmS1ja9XxURq3J+Nm/63ZmpjlFgZUTcCOwFPBcRo6nMBmC/bhU62JhVSUxoNqqM9LtzI2KjpAOB21KuqOfz3mCznoONpP3JcnzPIWsFroqIS1rKCLgEOAl4ETgtIu7rtW6z+okiukjZNxWQfjciNqZ/H5d0B3A48H1glqTh1Lp5I7Cx2/0UMWYzCnw6IhYARwNnSVrQUuZEYH46VgDfKKBes/oJYHvkO3qTJ/3uHpJ2S69nA8cAD6c0TLeTpWjq+PlWPQebiNjUaKVExJ+BR9i5/7aYLF9wpOTksxq5gs3stUoaIM6TfvcQYK2k+8mCy8qIeDhdOwc4W9II2RjO5d0qLHTMRtI8smbW3S2X9gOebHrfGFDahJntUNZ+NjnT7/4COLTD5x8HjpxInYUFG0m7k/XlPtmc23uC37GCrJvFnP32LOrWzAZHMJHZqIFSyDobSdPJAs3VEXFDmyIbgf2b3rcdUIqIVRGxMCIWztpr9yJuzWzARFndqNL1HGzSTNPlwCMR8aUOxdYApypzNPB80xy/mTWUt4K4dEV0o44BPgqsl7Qunfs88CaAiLgUuIls2nuEbOr7YwXUa1ZDxU19V03PwSYifg6oS5kAzuq1LrPaa0x915BXEJtVSEQw9vJo94IDyMHGrEom9rjCQHGwMauSgBh1sDGzvnN2BTMrQbhlY2aliHCwMbMSBIxtq+fjCg42ZlVS0oOYk8HBxqxCPGZjZuXwmI2ZlcXdKDPrP3ejzKwMEcHYtno+G+UkdWZVksZs8hy9yJN+V9LfNKXeXSfpJUlL0rUrJD3RdO2wbnU62JhVSYXS70bE7Y3Uu2QZMF8E/qupyGebUvOu61ahg41ZlaQxm363bJh4+t1TgJ9ExIu7WqGDjVmllNONIn/63YalwDUt5y6U9ICkixv5pcbjAWKzComxCT2uMG6u74LS75JyvB0K3Nx0+lyyIDUDWEWWR+qC8W7WwcasUib0uMK4ub6LSL+bfAj4QUS80vTdjVbRNknfAj7T7WbdjTKrkvLGbLqm322yjJYuVCOjbcqusgR4sFuFbtmYVUl5jyusBL4r6XTgt2StFyQtBM6MiDPS+3lkOd/+u+XzV0vamyzZwTrgzG4V9hxsJO0PXEU2wBRk/cZLWsocSxY5n0inboiIcft3ZlNRlLQHcZ70u+n9b8hSZbeWO26idRbRshkFPh0R90n6K+BeSbc0JSBv+FlEvK+A+sxqzc9GdZAGijal13+W9AhZJGwNNmbWRUQwOubNs7pK/bvDgbvbXH67pPuB3wOfiYiH2nx+BbAivX3hHW8489Ei7y+n2cDTk1DvZJlKv3eyfuvciRTeHm7ZjEvS7sD3gU9GxJ9aLt8HzI2IFySdBNwIzG/9jrRGYFXr+TJJWjvedGLdTKXfOwi/NQjGahpsCpn6ljSdLNBcHRE3tF6PiD9FxAvp9U3AdEmzi6jbrG7GInIdg6aI2SgBlwOPRMSXOpTZB3gqrVQ8kizIbem1brM6qmvLpohu1DHAR4H1ktalc58H3gQQEZeSPcT1cUmjwFZgaURlQ/OkduMmwVT6vZX/rRH17Uapun/zZlPPQcNz4kuv+3Cusouf/cq9VR+DauYVxGaVEp6NMrP+CxjIwd88/CBmE0mLJD0qaUTSTjuX1Ymk1ZI2S+r6AN2gk7S/pNslPSzpIUn/PNn31FFkA8R5jkHjYJNIGgK+BpwILACWSVowuXfVV1cAiyb7JkrSeKRmAXA0cFZ1/7+N2gYbd6NedSQwEhGPA0i6lmzrxFo+dhERd6YV37U3SI/UBPhxhSlgP+DJpvcbgKMm6V6sT7o8UjPpwgPEZoOvy
yM11RBe1DcVbCTbJKjhjemc1UC3R2qqos6zUQ42r7oHmC/pALIgsxT4u8m9JStCnkdqqqO+K4g9G5VExCjwCbId5B8BvttuG4y6kHQN8EvgYEkb0vaQddV4pOa4pgyOJ032TbWTtWw8G1V76Yn0myb7PsoQEcsm+x7KEhE/J9srt/IigldqOhvllo1ZxZTRspH0wbTAcSxtct6pXNuFrpIOkHR3On+dpBnd6nSwMauYkvazeRB4P3BnpwJdFrpeBFwcEQcBzwJdu+EONmYVEiWtII6IRyKi27a7Oxa6RsTLwLXA4jTgfhxwfSqXJ1e4x2zMqmQDz918dtyQdxfLmeOl3y1Ap4WuewHPpUmVxvmd0r20crAxq5CIKOx5tfFyfUfEeBkw+8LBxqymxsv1nVOnha5bgFmShlPrJtcCWI/ZmFknOxa6ptmmpcCatKXv7WTb/UL3XOGAg43ZlCTpbyVtAN4O/FjSzen8GyTdBF0Xup4DnC1phGwM5/KudXoPYjMrg1s2ZlYKBxszK4WDjZmVwsHGzErhYGNmpXCwMbNSONiYWSn+H4jRLO0e+4NOAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "corr_matrix_marginal = np.corrcoef(posterior_samples.T)\n", - "fig, ax = plt.subplots(1,1, figsize=(4, 4))\n", - "im = plt.imshow(corr_matrix_marginal, clim=[-1, 1], cmap='PiYG')\n", - "_ = fig.colorbar(im)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "It might be tempting to conclude that the experimental data barely constrains our parameters and that almost all parameter combinations can reproduce the experimental data. As we will show below, this is not the case." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Because our toy posterior has only three parameters, we can plot posterior samples in a 3D plot:" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "rc('animation', html='html5')\n", - "\n", - "# First set up the figure, the axis, and the plot element we want to animate\n", - "fig = plt.figure(figsize=(6,6))\n", - "ax = fig.add_subplot(111, projection='3d')\n", - "\n", - "ax.set_xlim((-2, 2))\n", - "ax.set_ylim((-2, 2))\n", - "\n", - "def init():\n", - " line, = ax.plot([], [], lw=2)\n", - " line.set_data([], [])\n", - " return (line,)\n", - "\n", - "def animate(angle):\n", - " num_samples_vis = 1000\n", - " line = ax.scatter(posterior_samples[:num_samples_vis, 0], posterior_samples[:num_samples_vis, 1], posterior_samples[:num_samples_vis, 2], zdir='z', s=15, c='#2171b5', depthshade=False)\n", - " ax.view_init(20, angle)\n", - " return (line,)\n", - "\n", - "anim = animation.FuncAnimation(fig, animate, init_func=init,\n", - " frames=range(0,360,5), interval=150, blit=True)\n", - "\n", - "plt.close()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 7, - "metadata": {}, - 
"output_type": "execute_result" - } - ], - "source": [ - "HTML(anim.to_html5_video())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Clearly, the range of admissible parameters is constrained to a narrow region in parameter space, which had not been evident from the marginals." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "If the posterior has more than three dimensions, inspecting all dimensions at once will not be possible anymore. One way to still reveal structures in high-dimensional posteriors is to inspect 2D-slices through the posterior. In `sbi`, this can be done with the `conditional_pairplot()` function, which computes the conditional distributions within the posterior. We can slice (i.e. condition) the posterior at any location, given by the `condition`. In the plot below, for all upper diagonal plots, we keep all but two parameters constant at values sampled from the posterior, and inspect what combinations of the remaining two parameters can reproduce experimental data. For the plots on the diagonal (the 1D conditionals), we keep all but one parameter constant." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "condition = posterior.sample((1,))\n", - "\n", - "_ = conditional_pairplot(\n", - " density=posterior,\n", - " condition=condition,\n", - " limits=torch.tensor([[-2., 2.]]*3),\n", - " figsize=(5,5)\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This plot looks completely different from the marginals obtained with `pairplot()`. As it can be seen on the diagonal plots, if all parameters but one are kept constant, the remaining parameter has to be tuned to a narrow region in parameter space. In addition, the upper diagonal plots show strong correlations: deviations in one parameter can be compensated through changes in another parameter." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can summarize these correlations in a conditional correlation matrix, which computes the Pearson correlation coefficient of each of these pairwise plots. 
This matrix (below) shows strong correlations between many parameters, which can be interpreted as potential compensation mechansims:" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAARsAAADxCAYAAAD7hRNxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWV0lEQVR4nO3df6xcZZ3H8ffn3tvSdVGLFAEBKYZmlxp2QRvAkCgLiIUYWhG13awUF3JXg7urqEEkkRWXWHYTu/gTG6iAEsCtiDViWOTHglFYKlsoPxapoNJaqZRSJIXC7f3uH+eZMkzv3Dm3c+bcM+d+XuakM+c8M88ZoV+eX+f5KiIwM+u1gcm+ATObGhxszKwUDjZmVgoHGzMrhYONmZXCwcbMSuFgY1ZTklZI2iTpwTbXJekrktZJekDS25quLZH0WDqWFHE/DjZm9XUlMH+c6ycDc9IxDHwTQNIbgAuBo4GjgAsl7dXtzTjYmNVURNwJPDNOkQXA1ZG5G5gpaX/gPcAtEfFMRGwBbmH8oJXLULdfYGbF+fODZsSOF0dzld3+9MsPAS82nVoeEcsnUN0BwJNN79enc+3Od8XBxqxCRl8c5ZDTZuUq+3/LN74YEfN6fEuFcTfKrEoEAwPKdRRgA3BQ0/sD07l257viYGNWMVK+owCrgDPSrNQxwNaI2AjcDJwkaa80MHxSOtcVd6PMKkTAQEFNAEnXAscBsyStJ5thmgYQEZcBNwGnAOuAbcBH0rVnJH0RuDd91UURMd5Acy4ONmZVIhgcKqbZEhGLO1wP4Jw211YAKwq5kcTBxqxiVNPBDQcbswqRYKCgAZmqcbAxqxi3bMysFEUNEFeNg41ZhUhu2ZhZSQYHPWZjZj0muRtlZqUQKuZRhMpxsDGrErdszKwsHiA2s56TPEA8prR94PXAbOA3wAfTzl6t5XYAa9Pb30XEqd3Ua1Znde1GdfuzPgvcGhFzgFvT+7G8EBFHpMOBxqwNARpQrqPfdBtsFgBXpddXAQu7/D6zqS0NEOc5+k23Yzb7ps12AP4A7Num3AxJq4ERYGlE3DhWIUnDZLu8oyG9ffrMqTOkdOCWmZN9C6V56o3PTvYtlOr5jS8/HRH75C1f0+cwOwcbST8F9hvj0gXNbyIiJEWbrzk4IjZIegtwm6S1EfHr1kJps+blADP2mR6zF+bbi7UOvrTytMm+hdIsG1412bdQqru+8ORv85bNNs+qZ7TpGGwi4sR21yQ9JWn/iNiYUkBsavMdG9Kfj0u6AzgS2CXYmE15BW6eVTXd9vxWAY1seUuAH7YWSPuY7pFezwKOBR7usl6zWhJiQPmOftNtsFkKvFvSY8CJ6T2S5km6PJU5DFgt6X7gdrIxGwcbs7EUnF1B0nxJj6YUu7vMFktaJmlNOn4l6dmmazuarnXd9+1qBDYiNgMnjHF+NXB2ev1z4PBu6jGbKoocs5E0CHwdeDdZorl7Ja1q/o99RHyyqfw/kg1xNLwQEUcUcjM4lYtZ5QxoINeRw1HAuoh4PCJeAq4jW67SzmLg2gJ+wpgcbMyqRPm6UDlbP7nT6Eo6GDgEuK3p9AxJqyXdLWnhbv6inabOQhazPiBgaDB3G2BWWr/WMNFc380WASsjYkfTuVxLVvJysDGrkGzzrNzB5ukOub4nkkZ3ES05pIpesuJulFmlFNqNuheYI+kQSdPJAsous0qS/hLYC/hF07nCl6y4ZWNWJQXmjYqIEUkfJ8vT
PQisiIiHJF0ErI6IRuBZBFyXMmQ2HAZ8S9IoWaOk6yUrDjZmFVL04woRcRNZTu/mc59vef8vY3yu8CUrDjZmVSIxODg42XfREw42ZhUypR/ENLNyOdiYWc9J5F0d3HccbMwqJf9Dlv3GwcasYvpx+4g8HGzMKkSCoSHPRplZjzU2z6ojBxuzKpFno8ysJA42ZtZzwlPfZlYGeerbzEogxNDgtMm+jZ4opL2WYwf3PSRdn67fI2l2EfWa1Y2AQQ3mOvpN18GmaQf3k4G5wGJJc1uKnQVsiYhDgWXAJd3Wa1ZLEgMDg7mOflNEyybPDu4LgKvS65XACVJNFxOYdWlAg7mOflPEmM1YO7gf3a5M2j1sK7A38HQB9ZvVhtBE9iDuK5UaIJY0DAwDDO3Zf5HbrFuSmDY4fbJvoyeKCDZ5dnBvlFkvaQh4PbC59YtSGorlADP2mR6t183qT7VdZ1PEr8qzg/sqYEl6fTpwW8vmymZGI5VLcQPEOWaKz5T0x6ac3mc3XVsi6bF0LGn97ER13bLJuYP7FcB3JK0DniELSGa2CxU2rZ0n13dyfUR8vOWzbwAuBOYBAfwyfXbL7t5PIWM2nXZwj4gXgQ8UUZdZnRX8uMLOmWIASY2Z4jwpWd4D3BIRz6TP3gLMp4tc4PXsHJr1rQmts5mVcnE3juGWL8ub6/v9kh6QtFJSY/w1d57wvCo1G2U21U1wNqpT+t08fgRcGxHbJf0D2Xq447v8zjG5ZWNWIY1uVJ4jh44zxRGxOSK2p7eXA2/P+9mJcrAxq5JiH1foOFMsaf+mt6cCj6TXNwMnpZzfewEnpXO7zd0os0pRYY8i5Jwp/idJpwIjZDPFZ6bPPiPpi2QBC+CixmDx7nKwMauQLCNmcR2OHDPF5wPnt/nsCmBFUffiYGNWKcWts6kaBxuzCpHEkJ+NMrNe8x7EZlYOicE+3BgrDwcbswrJWjYONmbWc/XdYsLBxqxChBga8ACxmfWahNyNMrMyeMzGzHpOiAEcbMysBG7ZmFnPqcAHMavGwcasUsSgPBtlZj0meZ2NmZWkrt2oQkJoN7lpzKyZnOu7nW5y05jZqwkv6htPN7lpzKyF19m0N1Z+maPHKPd+Se8EfgV8MiKebC2Q8t4MA+yjPfnSytMKuL3+cP7pN0z2LZTmNVvrOQBaBKnYZ6MkzQcuJduD+PKIWNpy/VzgbLI9iP8I/H1E/DZd2wGsTUV/FxGndnMvZf1T/xEwOyL+CriFLDfNLiJieUTMi4h5rxv4s5JuzaxKihuzaRriOBmYCyyWNLel2P8C89LfzZXAvzVdeyEijkhHV4EGigk23eSmMbNXycZs8hw57BziiIiXgMYQx04RcXtEbEtv7yb7+9sTRQSbbnLTmFkTkY3Z5DkoLv1uw1nAT5rez0jfe7ekhd3+tq7HbLrJTWNmrSa0qK+I9LtZrdLfAfOAdzWdPjgiNkh6C3CbpLUR8evdraOQRX3d5KYxs1cUPECcK4WupBOBC4B3NQ13EBEb0p+PS7oDOBLY7WDjaQGzihGDuY4c8gxxHAl8Czg1IjY1nd9L0h7p9SzgWLpczuLHFcwqpMinvnMOcfw7sCfwn5LglSnuw4BvSRola5QsHWOh7oQ42JhVSrFbTOQY4jixzed+Dhxe2I3gYGNWOarp6IaDjVnlaLJvoCccbMwqxHsQm1mJ3I0ysxLI3Sgz6z0hbwtqZuVwy8bMes4DxGZWGnejzKzHhAeIzawU8gpiMyuLWzZmVgK3bMysBMq7V03fcbAxq5BsgNgtGzMrgWejzKz3JPDjCmZWhrq2bAoJoZJWSNok6cE21yXpK5LWSXpA0tuKqNesfrJ1NnmOXN8mzZf0aPq799kxru8h6fp0/R5Js5uunZ/OPyrpPd3+sqLaa1cC88e5fjIwJx3DwDcLqtesdorKrpAz/e5ZwJaIOBRYBlySPjuXLBvDW8n+bn9DOdNwtlNIsImIO8mSz7WzALg6MncDM1uy
ZJoZrzyukOd/OXRMv5veX5VerwROUJZmYQFwXURsj4gngHXp+3ZbWSNRudKAShpupBJ9bvSFkm7NrEo0gaOQ9Ls7y0TECLAV2DvnZyekUgPEEbEcWA5w6NAbY5Jvx6x8kY58Cku/W4ayWja50oCaWaDId+SQ5+/dzjKShoDXA5tzfnZCygo2q4Az0qzUMcDWiNhYUt1m/WU059FZx/S76f2S9Pp04LaIiHR+UZqtOoRscud/uvhVxXSjJF0LHEfWh1wPXAhMA4iIy8gy8p1CNsi0DfhIEfWa1VK+VkuOr8mVfvcK4DuS1pFN8ixKn31I0vfI8nuPAOdExI5u7qeQYBMRiztcD+CcIuoyq7UAFThamSP97ovAB9p89mLg4qLupVIDxGbGRAaI+4qDjVnVFNSNqhoHG7OqqWescbAxq5Qg77R233GwMauaesYaBxuzynGwMbNSuBtlZmUocp1NlTjYmFXJxB7E7CsONmaVEjBaz2jjYGNWIaK+3ah6buNuZpXjlo1Z1Xg2ysx6zgPEZlYWeYDYzEpRz1jjYGNWKYGnvs2sDEHUdIDYU99mVVPchudtSXqDpFskPZb+3GuMMkdI+oWkh1La7A81XbtS0hOS1qTjiE51OtiYVUgExGjkOrr0WeDWiJgD3Jret9oGnBERjRS8/yFpZtP1z0TEEelY06lCd6PMKiZ2dNlsyWcBWUYUyNLv3gGc96r7iPhV0+vfS9oE7AM8uzsVFtKykbRC0iZJD7a5fpykrU1Nrs+PVc5symsMEOc5OqffHc++Tbnb/gDsO15hSUcB04FfN52+OHWvlknao1OFRbVsrgS+Blw9Tpm7IuK9BdVnVlMTGiAeN/2upJ8C+41x6YJX1RgRUvsnsiTtD3wHWBIRjWbX+WRBajpZyuzzgIvGu9mi8kbdKWl2Ed/V8NQbn2XZcGvyvvp6zdapM3y27flSugn9q6D/eyLixHbXJD0laf+I2JiCyaY25V4H/Bi4ICLubvruRqtou6RvA5/udD9l/hv+Dkn3S/qJpLeOVUDScKNJ+PI2/wtpU1NE5Dq61Jx2dwnww9YCKWXvD4CrI2Jly7X9058CFgJjDqE0KyvY3AccHBF/DXwVuHGsQhGxPCLmRcS8aa+ZOv+lN9tpYmM23VgKvFvSY8CJ6T2S5km6PJX5IPBO4MwxprivkbQWWAvMAv61U4WlzEZFxHNNr2+S9A1JsyLi6TLqN+snZcxGRcRm4IQxzq8Gzk6vvwt8t83nj59onaUEG0n7AU+lgaijyFpUm8uo26yfRBSyhqaSCgk2kq4lm7OfJWk9cCEwDSAiLgNOBz4maQR4AVgUdV2Tbdatmg5XFjUbtbjD9a+RTY2bWQd1/e+wVxCbVYmf+jazspT0uELpHGzMqiTwmI2ZlcGzUWZWFg8Qm1nPpf1s6sjBxqxqHGzMrNciwrNRZlYOBxsz6z2P2ZhZOdyNMrMyBDDqYGNmPRYRjL68Y7JvoyccbMwqxt0oM+u9GnejvNGvWaXky4bZ7YxVnvS7qdyOpv2HVzWdP0TSPZLWSbo+bY4+LgcbsyqJrBuV5+hSnvS7AC80pdg9ten8JcCyiDgU2AKc1alCBxuzCgkgRkdzHV1aQJZ2l/TnwrwfTOlbjgca6V1yfd5jNmZVEkHkn42aJWl10/vlEbE852fzpt+dkeoYAZZGxI3A3sCzETGSyqwHDuhUoYONWZXEhGajyki/e3BEbJD0FuC2lCtqa94bbNZ1sJF0EFmO733JWoHLI+LSljICLgVOAbYBZ0bEfd3WbVY/UUQXKfumAtLvRsSG9Ofjku4AjgS+D8yUNJRaNwcCGzrdTxFjNiPApyJiLnAMcI6kuS1lTgbmpGMY+GYB9ZrVTwA7It/RnTzpd/eStEd6PQs4Fng4pWG6nSxFU9vPt+o62ETExkYrJSL+BDzCrv23BWT5giMlJ5/ZyBVsZq9W0gBxnvS7hwGrJd1PFlyWRsTD6dp5wLmS1pGN4VzR
qcJCx2wkzSZrZt3TcukA4Mmm940BpY2Y2U5l7WeTM/3uz4HD23z+ceCoidRZWLCRtCdZX+4Tzbm9J/gdw2TdLPZ4/WBRt2bWP4KJzEb1laLS704jCzTXRMQNYxTZABzU9H7MAaU0bbcc4LVvml7PTT3MxlXcAHHVdD1mk2aargAeiYgvtym2CjhDmWOArU1z/GbWUN4K4tIV0bI5FvgwsFbSmnTuc8CbASLiMuAmsmnvdWRT3x8poF6zGqpvy6brYBMRPwPUoUwA53Rbl1ntNaa+a8griM0qJCIYfWmkc8E+5GBjViUTe1yhrzjYmFVJQIw42JhZzzm7gpmVINyyMbNSRDjYmFkJAka3+3EFM+u1kh7EnAwONmYV4jEbMyuHx2zMrCzuRplZ77kbZWZliAhGt9fz2SgnqTOrkjRmk+foRp70u5L+pin17hpJL0pamK5dKemJpmtHdKrTwcasSiqUfjcibm+k3iXLgLkN+K+mIp9pSs27plOFDjZmVZLGbHrdsmHi6XdPB34SEdt2t0IHG7NKKacbRf70uw2LgGtbzl0s6QFJyxr5pcbjAWKzConRCT2uMG6u74LS75JyvB0O3Nx0+nyyIDWdLEnBecBF492sg41ZpUzocYVxc30XkX43+SDwg4h4uem7G62i7ZK+DXy60826G2VWJeWN2XRMv9tkMS1dqEZG25RdZSHwYKcK3bIxq5LyHldYCnxP0lnAb8laL0iaB3w0Is5O72eT5Xz775bPXyNpH7JkB2uAj3aqsOtgI+kg4GqyAaYg6zde2lLmOLLI+UQ6dUNEjNu/M5uKoqQ9iPOk303vf0OWKru13PETrbOIls0I8KmIuE/Sa4FfSrqlKQF5w10R8d4C6jOrNT8b1UYaKNqYXv9J0iNkkbA12JhZBxHByKg3z+oo9e+OBO4Z4/I7JN0P/B74dEQ8NMbnh4Hh9Pb5u77w5KNF3l9Os4CnJ6HeyTKVfu9k/daDJ1J4R7hlMy5JewLfBz4REc+1XL4PODginpd0CnAjMKf1O9IageWt58skafV404l1M5V+bz/81iAYrWmwKWTqW9I0skBzTUTc0Ho9Ip6LiOfT65uAaZJmFVG3Wd2MRuQ6+k0Rs1ECrgAeiYgvtymzH/BUWql4FFmQ29xt3WZ1VNeWTRHdqGOBDwNrJa1J5z4HvBkgIi4je4jrY5JGgBeARRGVDc2T2o2bBFPp91b+t0bUtxul6v6dN5t6Dh3aN778ug/lKrtgy1d/WfUxqGZeQWxWKeHZKDPrvYC+HPzNww9iNpE0X9KjktZJ2mXnsjqRtELSJkkdH6Drd5IOknS7pIclPSTpnyf7ntqKbIA4z9FvHGwSSYPA14GTgbnAYklzJ/eueupKYP5k30RJGo/UzAWOAc6p7j/bqG2wcTfqFUcB6yLicQBJ15FtnVjLxy4i4s604rv2+umRmgA/rjAFHAA82fR+PXD0JN2L9UiHR2omXXiA2Kz/dXikphrCi/qmgg1kmwQ1HJjOWQ10eqSmKuo8G+Vg84p7gTmSDiELMouAv53cW7Ii5Hmkpjrqu4LYs1FJRIwAHyfbQf4R4HtjbYNRF5KuBX4B/IWk9Wl7yLpqPFJzfFMGx1Mm+6bGkrVsPBtVe+mJ9Jsm+z7KEBGLJ/seyhIRPyPbK7fyIoKXazob5ZaNWcWU0bKR9IG0wHE0bXLertyYC10lHSLpnnT+eknTO9XpYGNWMSXtZ/MgcBpwZ7sCHRa6XgIsi4hDgS1Ax264g41ZhURJK4gj4pGI6LTt7s6FrhHxEnAdsCANuB8PrEzl8uQK95iNWZWs59mbz40b8u5iOWO89LsFaLfQdW/g2TSp0ji/S7qXVg42ZhUSEYU9rzZeru+IGC8DZk842JjV1Hi5vnNqt9B1MzBT0lBq3eRaAOsxGzNrZ+dC1zTbtAhYlbb0vZ1su1/onCsccLAxm5IkvU/SeuAdwI8l3ZzOv0nSTdBxoet5wLmS1pGN4VzRsU7v
QWxmZXDLxsxK4WBjZqVwsDGzUjjYmFkpHGzMrBQONmZWCgcbMyvF/wPZqxE8c4CsDAAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "cond_coeff_mat = conditional_corrcoeff(\n", - " density=posterior,\n", - " condition=condition,\n", - " limits=torch.tensor([[-2., 2.]]*3),\n", - ")\n", - "fig, ax = plt.subplots(1,1, figsize=(4,4))\n", - "im = plt.imshow(cond_coeff_mat, clim=[-1, 1], cmap='PiYG')\n", - "_ = fig.colorbar(im)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "So far, we have investigated the conditional distribution only at a specific `condition` sampled from the posterior. In many applications, it makes sense to repeat the above analyses with a different `condition` (another sample from the posterior), which can be interpreted as slicing the posterior at a different location. Note that `conditional_corrcoeff()` can directly compute the matrix for several `conditions` and then outputs the average over them. This can be done by passing a batch of $N$ conditions as the `condition` argument." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Sampling conditional distributions\n", - "\n", - "So far, we have demonstrated how one can plot 2D conditional distributions with `conditional_pairplot()` and how one can compute the pairwise conditional correlation coefficient with `conditional_corrcoeff()`. In some cases, it can be useful to keep a subset of parameters fixed and to vary **more than two** parameters. This can be done by sampling the conditonal posterior $p(\\theta_i | \\theta_{j \\neq i}, x_o)$. As of `sbi` `v0.18.0`, this functionality requires using the [sampler interface](https://www.mackelab.org/sbi/tutorial/11_sampler_interface/). In this tutorial, we demonstrate this functionality on a linear gaussian simulator with four parameters. We would like to fix the forth parameter to $\\theta_4=0.2$ and sample the first three parameters given that value, i.e. 
we want to sample $p(\\theta_1, \\theta_2, \\theta_3 | \\theta_4 = 0.2, x_o)$. For an application in neuroscience, see [Deistler, Gonçalves, Macke, 2021](https://www.biorxiv.org/content/10.1101/2021.07.30.454484v4.abstract)." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In this tutorial, we will use SNPE, but the same also works for SNLE and SNRE. First, we define the prior and the simulator and train the deep neural density estimator:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "4683b28ba0614e87b33854d207e30da2", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Running 1000 simulations.: 0%| | 0/1000 [00:00" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "from sbi.analysis import pairplot\n", - "\n", - "_ = pairplot(cond_samples, limits=[[-2, 2], [-2, 2], [-2, 2], [-2, 2]], figsize=(4, 4))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.5" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Analysing variability and compensation mechansims with conditional distributions\n", + "\n", + "A central advantage of `sbi` over parameter search methods such as genetic algorithms is that the posterior captures **all** models that can reproduce experimental data. This allows us to analyse whether parameters can be variable or have to be narrowly tuned, and to analyse compensation mechanisms between different parameters. 
See also [Marder and Taylor, 2011](https://www.nature.com/articles/nn.2735?page=2) for further motivation to identify all models that capture experimental data. \n", + "\n", + "In this tutorial, we will show how one can use the posterior distribution to identify whether parameters can be variable or have to be finely tuned, and how we can use the posterior to find potential compensation mechanisms between model parameters. To investigate this, we will extract **conditional distributions** from the posterior inferred with `sbi`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note, you can find the original version of this notebook at [https://github.com/mackelab/sbi/blob/main/tutorials/07_conditional_distributions.ipynb](https://github.com/mackelab/sbi/blob/main/tutorials/07_conditional_distributions.ipynb) in the `sbi` repository." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Main syntax" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sbi.analysis import conditional_pairplot, conditional_corrcoeff\n", + "\n", + "# Plot slices through posterior, i.e. conditionals.\n", + "_ = conditional_pairplot(\n", + " density=posterior,\n", + " condition=posterior.sample((1,)),\n", + " limits=torch.tensor([[-2., 2.], [-2., 2.]]),\n", + ")\n", + "\n", + "# Compute the matrix of correlation coefficients of the slices.\n", + "cond_coeff_mat = conditional_corrcoeff(\n", + " density=posterior,\n", + " condition=posterior.sample((1,)),\n", + " limits=torch.tensor([[-2., 2.], [-2., 2.]]),\n", + ")\n", + "plt.imshow(cond_coeff_mat, clim=[-1, 1])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Analysing variability and compensation mechanisms in a toy example\n", + "Below, we use a simple toy example to demonstrate the above described features. 
For an application of these features to a neuroscience problem, see figure 6 in [Gonçalves, Lueckmann, Deistler et al., 2019](https://arxiv.org/abs/1907.00770)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from sbi import utils as utils\n", + "from sbi.analysis import pairplot, conditional_pairplot, conditional_corrcoeff\n", + "import torch\n", + "import numpy as np\n", + "\n", + "import matplotlib.pyplot as plt\n", + "from mpl_toolkits.mplot3d import Axes3D\n", + "from matplotlib import animation, rc\n", + "from IPython.display import HTML, Image\n", + "\n", + "_ = torch.manual_seed(0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's say we have used SNPE to obtain a posterior distribution over three parameters. In this tutorial, we just load the posterior from a file:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from toy_posterior_for_07_cc import ExamplePosterior\n", + "posterior = ExamplePosterior()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we specify the experimental observation $x_o$ at which we want to evaluate and sample the posterior $p(\\theta|x_o)$:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "x_o = torch.ones(1, 20) # simulator output was 20-dimensional\n", + "posterior.set_default_x(x_o)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As always, we can inspect the posterior marginals with the `pairplot()` function:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "posterior_samples = posterior.sample((5000,))\n", + "\n", + "fig, ax = pairplot(\n", + " samples=posterior_samples,\n", + " limits=torch.tensor([[-2., 2.]]*3),\n", + " upper=['kde'],\n", + " diag=['kde'],\n", + " figsize=(5,5)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The 1D and 2D marginals of the posterior fill almost the entire parameter space! Also, the Pearson correlation coefficient matrix of the marginal shows rather weak interactions (low correlations):" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAARsAAADxCAYAAAD7hRNxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWYklEQVR4nO3df6xcZZ3H8fen97Y0u6wWKCmI2EJokBpWkAYwRGUBtRBDu4rabpTiQroY3F1FDSKJbHBJym4iiz+xgQooCygi1ohhkR+LRmEpbKH8WOQKKq2VSvmhhFK4vd/94zxThunMnXM7Z849c+7nZU46c84z85yJuV+eX+f5KiIwM+u3aZN9A2Y2NTjYmFkpHGzMrBQONmZWCgcbMyuFg42ZlcLBxqymJK2WtFnSgx2uS9KXJY1IekDS25quLZf0WDqWF3E/DjZm9XUFsGic6ycC89OxAvgGgKQ9gfOBo4AjgfMl7dHrzTjYmNVURNwJPDNOkcXAVZG5C5glaV/gvcAtEfFMRDwL3ML4QSuX4V6/wMyK85f7z4ztL43lKrvt6VceAl5qOrUqIlZNoLr9gCeb3m9I5zqd74mDjVmFjL00xgHvn52r7P+t2vRSRCzs8y0Vxt0osyoRTJumXEcBNgL7N71/YzrX6XxPHGzMKkbKdxRgDXBqmpU6Gng+IjYBNwPvkbRHGhh+TzrXE3ejzCpEwLSCmgCSrgGOBWZL2kA2wzQdICIuBW4CTgJGgBeBj6Vrz0j6InBP+qoLImK8geZcHGzMqkQwNFxMsyUilnW5HsBZHa6tBlYXciOJg41ZxaimgxsONmYVIsG0ggZkqsbBxqxi3LIxs1IUNUBcNQ42ZhUiuWVjZiUZGvKYjZn1meRulJmVQqiYRxEqx8HGrErcsjGzsniA2Mz6TvIAcVtp+8DrgHnAb4APpZ29WsttB9ant7+LiJN7qdeszurajer1Z30OuDUi5gO3pvftbI2Iw9LhQGPWgQBNU65j0PQabBYDV6bXVwJLevw+s6ktDRDnOQZNr2M2c9JmOwB/AOZ0KDdT0lpgFFgZETe2KyRpBdku72hYR8yYNXWGlA7+iwMn+xZKs23W1sm+hVI9+sDvno6IvfOWr+lzmN2DjaSfAvu0uXRe85uICEnR4WvmRsRGSQcCt0laHxG/bi2UNmteBTBz7xkxb0m+vVjr4OYjLp/sWyjNyOK2aYxq6x1vOPO3ectmm2fVM9p0DTYRcUKna5Kek
rRvRGxKKSA2d/iOjenfxyXdARwO7BRszKa8AjfPqppee35rgEa2vOXAD1sLpH1Md0uvZwPHAA/3WK9ZLQkxTfmOQdNrsFkJvFvSY8AJ6T2SFkq6LJU5BFgr6X7gdrIxGwcbs3YKzq4gaZGkR1OK3Z1miyVdLGldOn4l6bmma9ubrq3p9af1NAIbEVuA49ucXwuckV7/Aji0l3rMpooix2wkDQFfA95NlmjuHklrmv9jHxGfair/j2RDHA1bI+KwQm4Gp3Ixq5xpmpbryOFIYCQiHo+Il4FryZardLIMuKaAn9CWg41ZlShfFypn6yd3Gl1Jc4EDgNuaTs+UtFbSXZKW7OIv2mHqLGQxGwAChodytwFmp/VrDRPN9d1sKXB9RGxvOpdryUpeDjZmFZJtnpU72DzdJdf3RNLoLqUlh1TRS1bcjTKrlEK7UfcA8yUdIGkGWUDZaVZJ0puBPYBfNp0rfMmKWzZmVVJg3qiIGJX0CbI83UPA6oh4SNIFwNqIaASepcC1KUNmwyHANyWNkTVKel6y4mBjViFFP64QETeR5fRuPveFlvf/0uZzhS9ZcbAxqxKJoaGhyb6LvnCwMauQKf0gppmVy8HGzPpOIu/q4IHjYGNWKfkfshw0DjZmFTOI20fk4WBjViESDA97NsrM+qyxeVYdOdiYVYk8G2VmJXGwMbO+E576NrMyyFPfZlYCIYaHpk/2bfRFIe21HDu47ybpunT9bknziqjXrG4EDGko1zFoeg42TTu4nwgsAJZJWtBS7HTg2Yg4CLgYuKjXes1qSWLatKFcx6ApomWTZwf3xcCV6fX1wPFSTRcTmPVomoZyHYOmiDGbdju4H9WpTNo97HlgL+DpAuo3qw2hiexBPFAqNUAsaQWwAmB498GL3Ga9ksT0oRmTfRt9UUSwybODe6PMBknDwOuBLa1flNJQrAKYufeMaL1uVn+q7TqbIn5Vnh3c1wDL0+tTgNtaNlc2MxqpXIobIM4xU3yapD825fQ+o+nackmPpWN562cnqueWTc4d3C8Hvi1pBHiGLCCZ2U5U2LR2nlzfyXUR8YmWz+4JnA8sBAK4N3322V29n0LGbLrt4B4RLwEfLKIuszor+HGFHTPFAJIaM8V5UrK8F7glIp5Jn70FWEQPucDr2Tk0G1gTWmczO+XibhwrWr4sb67vD0h6QNL1khrjr7nzhOdVqdkos6lugrNR3dLv5vEj4JqI2CbpH8jWwx3X43e25ZaNWYU0ulF5jhy6zhRHxJaI2JbeXgYckfezE+VgY1YlxT6u0HWmWNK+TW9PBh5Jr28G3pNyfu8BvCed22XuRplVigp7FCHnTPE/SToZGCWbKT4tffYZSV8kC1gAFzQGi3eVg41ZhWQZMYvrcOSYKT4XOLfDZ1cDq4u6Fwcbs0opbp1N1TjYmFWIJIb9bJSZ9Zv3IDazckgMDeDGWHk42JhVSNaycbAxs76r7xYTDjZmFSLE8DQPEJtZv0nI3SgzK4PHbMys74SYhoONmZXALRsz6zsV+CBm1TjYmFWKGJJno8yszySvszGzktS1G1VICO0lN42ZNZNzfXfSS24aM3st4UV94+klN42ZtfA6m87a5Zc5qk25D0h6J/Ar4FMR8WRrgZT3ZgXAfnvO4Y4jvlfA7Q2GY++dOjn8PnLwoZN9C5UlFftslKRFwCVkexBfFhErW66fDZxBtgfxH4G/j4jfpmvbgfWp6O8i4uRe7qWsYe8fAfMi4q+BW8hy0+wkIlZFxMKIWLjX7rNKujWzKiluzKZpiONEYAGwTNKClmL/CyxMf5vXA//WdG1rRByWjp4CDRQTbHrJTWNmr5GN2eQ5ctgxxBERLwONIY4dIuL2iHgxvb2L7O+3L4oINr3kpjGzJiIbs8lzUFz63YbTgZ80vZ+ZvvcuSUt6/W09j9n0kpvGzFpNaFFfEel3s1qljwALgXc1nZ4bERslHQjcJml9RPx6V+soZFFfL7lpzOxVBQ8Q50qhK
+kE4DzgXU3DHUTExvTv45LuAA4HdjnY1HNdtNkAE0O5jhzyDHEcDnwTODkiNjed30PSbun1bOAYelzO4scVzCqkyKe+cw5x/DuwO/A9SfDqFPchwDcljZE1Sla2Wag7IQ42ZpVS7BYTOYY4TujwuV8AhS6IcrAxqxjVdHTDwcascjTZN9AXDjZmFeI9iM2sRO5GmVkJ5G6UmfWfkLcFNbNyuGVjZn3nAWIzK427UWbWZ8IDxGZWCnkFsZmVxS0bMyuBWzZmVgLl3atm4DjYmFVINkDslo2ZlcCzUWbWfxL4cQUzK0NdWzaFhFBJqyVtlvRgh+uS9GVJI5IekPS2Iuo1q59snU2eI9e3SYskPZr+9j7X5vpukq5L1++WNK/p2rnp/KOS3tvrLyuqvXYFsGic6ycC89OxAvhGQfWa1U5R2RVypt89HXg2Ig4CLgYuSp9dQJaN4S1kf9tfV840nJ0UEmwi4k6y5HOdLAauisxdwKyWLJlmxquPK+T5Xw5d0++m91em19cDxytLs7AYuDYitkXEE8BI+r5dVtZIVK40oJJWNFKJbnnhuZJuzaxKNIGjkPS7O8pExCjwPLBXzs9OSKUGiCNiFbAK4K1z3xyTfDtm5Yt05FNY+t0ylNWyyZUG1MwCRb4jhzx/dzvKSBoGXg9syfnZCSkr2KwBTk2zUkcDz0fEppLqNhssYzmP7rqm303vl6fXpwC3RUSk80vTbNUBZJM7/9PDryqmGyXpGuBYsj7kBuB8YDpARFxKlpHvJLJBpheBjxVRr1kt5Wu15PiaXOl3Lwe+LWmEbJJnafrsQ5K+S5bfexQ4KyK293I/hQSbiFjW5XoAZxVRl1mtBajA0coc6XdfAj7Y4bMXAhcWdS+VGiA2MyYyQDxQHGzMqqagblTVONiYVU09Y42DjVmlBHmntQeOg41Z1dQz1jjYmFWOg42ZlcLdKDMrQ5HrbKrEwcasSib2IOZAcbAxq5SAsXpGGwcbswoR9e1G1XMbdzOrHLdszKrGs1Fm1nceIDazssgDxGZWinrGGgcbs0oJPPVtZmUIoqYDxJ76Nqua4jY870jSnpJukfRY+nePNmUOk/RLSQ+ltNkfbrp2haQnJK1Lx2Hd6nSwMauQCIixyHX06HPArRExH7g1vW/1InBqRDRS8P6HpFlN1z8bEYelY123Ct2NMquY2N5jsyWfxWQZUSBLv3sHcM5r7iPiV02vfy9pM7A38NyuVFhIy0bSakmbJT3Y4fqxkp5vanJ9oV05symvMUCc5+iefnc8c5pyt/0BmDNeYUlHAjOAXzedvjB1ry6WtFu3Cotq2VwBfBW4apwyP4uI9xVUn1lNTWiAeNz0u5J+CuzT5tJ5r6kxIqTOT2RJ2hf4NrA8IhrNrnPJgtQMspTZ5wAXjHezReWNulPSvCK+q2HbrK2MLG7bUKqljxx86GTfQmm+85/rJ/sWqq2gXlREnNDpmqSnJO0bEZtSMNncodzrgB8D50XEXU3f3WgVbZP0LeAz3e6nzAHit0u6X9JPJL2lXQFJKxpNwue2vFDirZlVR0TkOnrUnHZ3OfDD1gIpZe8PgKsi4vqWa/umfwUsAbq2DMoKNvcBcyPircBXgBvbFYqIVRGxMCIWztpr95JuzaxCJjZm04uVwLslPQackN4jaaGky1KZDwHvBE5rM8V9taT1wHpgNvCv3SosZTYqIv7U9PomSV+XNDsini6jfrNBUsZsVERsAY5vc34tcEZ6/R3gOx0+f9xE6ywl2EjaB3gqDUQdSdai2lJG3WaDJKKQNTSVVEiwkXQN2Zz9bEkbgPOB6QARcSlwCvBxSaPAVmBp1HVNtlmvSllmU76iZqOWdbn+VbKpcTProq7/HfYKYrMq8VPfZlaWkh5XKJ2DjVmVBB6zMbMyeDbKzMriAWIz67u0n00dOdiYVY2DjZn1W0R4NsrMyuFgY2b95zEbMyuHu1FmVoYAxhxszKzPIoKxV7ZP9m30hYONWcW4G
2Vm/VfjbpQzYppVSr5smL3OWOVJv5vKbW/af3hN0/kDJN0taUTSdWlz9HE52JhVSWTdqDxHj/Kk3wXY2pRi9+Sm8xcBF0fEQcCzwOndKnSwMauQAGJsLNfRo8VkaXdJ/y7J+8GUvuU4oJHeJdfnPWZjViURRP7ZqNmS1ja9XxURq3J+Nm/63ZmpjlFgZUTcCOwFPBcRo6nMBmC/bhU62JhVSUxoNqqM9LtzI2KjpAOB21KuqOfz3mCznoONpP3JcnzPIWsFroqIS1rKCLgEOAl4ETgtIu7rtW6z+okiukjZNxWQfjciNqZ/H5d0B3A48H1glqTh1Lp5I7Cx2/0UMWYzCnw6IhYARwNnSVrQUuZEYH46VgDfKKBes/oJYHvkO3qTJ/3uHpJ2S69nA8cAD6c0TLeTpWjq+PlWPQebiNjUaKVExJ+BR9i5/7aYLF9wpOTksxq5gs3stUoaIM6TfvcQYK2k+8mCy8qIeDhdOwc4W9II2RjO5d0qLHTMRtI8smbW3S2X9gOebHrfGFDahJntUNZ+NjnT7/4COLTD5x8HjpxInYUFG0m7k/XlPtmc23uC37GCrJvFnP32LOrWzAZHMJHZqIFSyDobSdPJAs3VEXFDmyIbgf2b3rcdUIqIVRGxMCIWztpr9yJuzWzARFndqNL1HGzSTNPlwCMR8aUOxdYApypzNPB80xy/mTWUt4K4dEV0o44BPgqsl7Qunfs88CaAiLgUuIls2nuEbOr7YwXUa1ZDxU19V03PwSYifg6oS5kAzuq1LrPaa0x915BXEJtVSEQw9vJo94IDyMHGrEom9rjCQHGwMauSgBh1sDGzvnN2BTMrQbhlY2aliHCwMbMSBIxtq+fjCg42ZlVS0oOYk8HBxqxCPGZjZuXwmI2ZlcXdKDPrP3ejzKwMEcHYtno+G+UkdWZVksZs8hy9yJN+V9LfNKXeXSfpJUlL0rUrJD3RdO2wbnU62JhVSYXS70bE7Y3Uu2QZMF8E/qupyGebUvOu61ahg41ZlaQxm363bJh4+t1TgJ9ExIu7WqGDjVmllNONIn/63YalwDUt5y6U9ICkixv5pcbjAWKzComxCT2uMG6u74LS75JyvB0K3Nx0+lyyIDUDWEWWR+qC8W7WwcasUib0uMK4ub6LSL+bfAj4QUS80vTdjVbRNknfAj7T7WbdjTKrkvLGbLqm322yjJYuVCOjbcqusgR4sFuFbtmYVUl5jyusBL4r6XTgt2StFyQtBM6MiDPS+3lkOd/+u+XzV0vamyzZwTrgzG4V9hxsJO0PXEU2wBRk/cZLWsocSxY5n0inboiIcft3ZlNRlLQHcZ70u+n9b8hSZbeWO26idRbRshkFPh0R90n6K+BeSbc0JSBv+FlEvK+A+sxqzc9GdZAGijal13+W9AhZJGwNNmbWRUQwOubNs7pK/bvDgbvbXH67pPuB3wOfiYiH2nx+BbAivX3hHW8489Ei7y+n2cDTk1DvZJlKv3eyfuvciRTeHm7ZjEvS7sD3gU9GxJ9aLt8HzI2IFySdBNwIzG/9jrRGYFXr+TJJWjvedGLdTKXfOwi/NQjGahpsCpn6ljSdLNBcHRE3tF6PiD9FxAvp9U3AdEmzi6jbrG7GInIdg6aI2SgBlwOPRMSXOpTZB3gqrVQ8kizIbem1brM6qmvLpohu1DHAR4H1ktalc58H3gQQEZeSPcT1cUmjwFZgaURlQ/OkduMmwVT6vZX/rRH17Uapun/zZlPPQcNz4kuv+3Cusouf/cq9VR+DauYVxGaVEp6NMrP+CxjIwd88/CBmE0mLJD0qaUTSTjuX1Ymk1ZI2S+r6AN2gk7S/pNslPSzpIUn/PNn31FFkA8R5jkHjYJNIGgK+BpwILACWSVowuXfVV1cAiyb7JkrSeKRmAXA0cFZ1/7+N2gYbd6NedSQwEhGPA0i6lmzrxFo+dhERd6YV37U3SI/UBPhxhSlgP+DJpvcbgKMm6V6sT7o8UjPpwgPEZoOvy
yM11RBe1DcVbCTbJKjhjemc1UC3R2qqos6zUQ42r7oHmC/pALIgsxT4u8m9JStCnkdqqqO+K4g9G5VExCjwCbId5B8BvttuG4y6kHQN8EvgYEkb0vaQddV4pOa4pgyOJ032TbWTtWw8G1V76Yn0myb7PsoQEcsm+x7KEhE/J9srt/IigldqOhvllo1ZxZTRspH0wbTAcSxtct6pXNuFrpIOkHR3On+dpBnd6nSwMauYkvazeRB4P3BnpwJdFrpeBFwcEQcBzwJdu+EONmYVEiWtII6IRyKi27a7Oxa6RsTLwLXA4jTgfhxwfSqXJ1e4x2zMqmQDz918dtyQdxfLmeOl3y1Ap4WuewHPpUmVxvmd0r20crAxq5CIKOx5tfFyfUfEeBkw+8LBxqymxsv1nVOnha5bgFmShlPrJtcCWI/ZmFknOxa6ptmmpcCatKXv7WTb/UL3XOGAg43ZlCTpbyVtAN4O/FjSzen8GyTdBF0Xup4DnC1phGwM5/KudXoPYjMrg1s2ZlYKBxszK4WDjZmVwsHGzErhYGNmpXCwMbNSONiYWSn+H4jRLO0e+4NOAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "corr_matrix_marginal = np.corrcoef(posterior_samples.T)\n", + "fig, ax = plt.subplots(1,1, figsize=(4, 4))\n", + "im = plt.imshow(corr_matrix_marginal, clim=[-1, 1], cmap='PiYG')\n", + "_ = fig.colorbar(im)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It might be tempting to conclude that the experimental data barely constrains our parameters and that almost all parameter combinations can reproduce the experimental data. As we will show below, this is not the case." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Because our toy posterior has only three parameters, we can plot posterior samples in a 3D plot:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "rc('animation', html='html5')\n", + "\n", + "# First set up the figure, the axis, and the plot element we want to animate\n", + "fig = plt.figure(figsize=(6,6))\n", + "ax = fig.add_subplot(111, projection='3d')\n", + "\n", + "ax.set_xlim((-2, 2))\n", + "ax.set_ylim((-2, 2))\n", + "\n", + "def init():\n", + " line, = ax.plot([], [], lw=2)\n", + " line.set_data([], [])\n", + " return (line,)\n", + "\n", + "def animate(angle):\n", + " num_samples_vis = 1000\n", + " line = ax.scatter(posterior_samples[:num_samples_vis, 0], posterior_samples[:num_samples_vis, 1], posterior_samples[:num_samples_vis, 2], zdir='z', s=15, c='#2171b5', depthshade=False)\n", + " ax.view_init(20, angle)\n", + " return (line,)\n", + "\n", + "anim = animation.FuncAnimation(fig, animate, init_func=init,\n", + " frames=range(0,360,5), interval=150, blit=True)\n", + "\n", + "plt.close()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + 
"output_type": "execute_result" + } + ], + "source": [ + "HTML(anim.to_html5_video())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Clearly, the range of admissible parameters is constrained to a narrow region in parameter space, which had not been evident from the marginals." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If the posterior has more than three dimensions, inspecting all dimensions at once will not be possible anymore. One way to still reveal structures in high-dimensional posteriors is to inspect 2D-slices through the posterior. In `sbi`, this can be done with the `conditional_pairplot()` function, which computes the conditional distributions within the posterior. We can slice (i.e. condition) the posterior at any location, given by the `condition`. In the plot below, for all upper diagonal plots, we keep all but two parameters constant at values sampled from the posterior, and inspect what combinations of the remaining two parameters can reproduce experimental data. For the plots on the diagonal (the 1D conditionals), we keep all but one parameter constant." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "condition = posterior.sample((1,))\n", + "\n", + "_ = conditional_pairplot(\n", + " density=posterior,\n", + " condition=condition,\n", + " limits=torch.tensor([[-2., 2.]]*3),\n", + " figsize=(5,5)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This plot looks completely different from the marginals obtained with `pairplot()`. As it can be seen on the diagonal plots, if all parameters but one are kept constant, the remaining parameter has to be tuned to a narrow region in parameter space. In addition, the upper diagonal plots show strong correlations: deviations in one parameter can be compensated through changes in another parameter." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can summarize these correlations in a conditional correlation matrix, which computes the Pearson correlation coefficient of each of these pairwise plots. 
This matrix (below) shows strong correlations between many parameters, which can be interpreted as potential compensation mechansims:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAARsAAADxCAYAAAD7hRNxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWV0lEQVR4nO3df6xcZZ3H8ffn3tvSdVGLFAEBKYZmlxp2QRvAkCgLiIUYWhG13awUF3JXg7urqEEkkRWXWHYTu/gTG6iAEsCtiDViWOTHglFYKlsoPxapoNJaqZRSJIXC7f3uH+eZMkzv3Dm3c+bcM+d+XuakM+c8M88ZoV+eX+f5KiIwM+u1gcm+ATObGhxszKwUDjZmVgoHGzMrhYONmZXCwcbMSuFgY1ZTklZI2iTpwTbXJekrktZJekDS25quLZH0WDqWFHE/DjZm9XUlMH+c6ycDc9IxDHwTQNIbgAuBo4GjgAsl7dXtzTjYmNVURNwJPDNOkQXA1ZG5G5gpaX/gPcAtEfFMRGwBbmH8oJXLULdfYGbF+fODZsSOF0dzld3+9MsPAS82nVoeEcsnUN0BwJNN79enc+3Od8XBxqxCRl8c5ZDTZuUq+3/LN74YEfN6fEuFcTfKrEoEAwPKdRRgA3BQ0/sD07l257viYGNWMVK+owCrgDPSrNQxwNaI2AjcDJwkaa80MHxSOtcVd6PMKkTAQEFNAEnXAscBsyStJ5thmgYQEZcBNwGnAOuAbcBH0rVnJH0RuDd91UURMd5Acy4ONmZVIhgcKqbZEhGLO1wP4Jw211YAKwq5kcTBxqxiVNPBDQcbswqRYKCgAZmqcbAxqxi3bMysFEUNEFeNg41ZhUhu2ZhZSQYHPWZjZj0muRtlZqUQKuZRhMpxsDGrErdszKwsHiA2s56TPEA8prR94PXAbOA3wAfTzl6t5XYAa9Pb30XEqd3Ua1Znde1GdfuzPgvcGhFzgFvT+7G8EBFHpMOBxqwNARpQrqPfdBtsFgBXpddXAQu7/D6zqS0NEOc5+k23Yzb7ps12AP4A7Num3AxJq4ERYGlE3DhWIUnDZLu8oyG9ffrMqTOkdOCWmZN9C6V56o3PTvYtlOr5jS8/HRH75C1f0+cwOwcbST8F9hvj0gXNbyIiJEWbrzk4IjZIegtwm6S1EfHr1kJps+blADP2mR6zF+bbi7UOvrTytMm+hdIsG1412bdQqru+8ORv85bNNs+qZ7TpGGwi4sR21yQ9JWn/iNiYUkBsavMdG9Kfj0u6AzgS2CXYmE15BW6eVTXd9vxWAY1seUuAH7YWSPuY7pFezwKOBR7usl6zWhJiQPmOftNtsFkKvFvSY8CJ6T2S5km6PJU5DFgt6X7gdrIxGwcbs7EUnF1B0nxJj6YUu7vMFktaJmlNOn4l6dmmazuarnXd9+1qBDYiNgMnjHF+NXB2ev1z4PBu6jGbKoocs5E0CHwdeDdZorl7Ja1q/o99RHyyqfw/kg1xNLwQEUcUcjM4lYtZ5QxoINeRw1HAuoh4PCJeAq4jW67SzmLg2gJ+wpgcbMyqRPm6UDlbP7nT6Eo6GDgEuK3p9AxJqyXdLWnhbv6inabOQhazPiBgaDB3G2BWWr/WMNFc380WASsjYkfTuVxLVvJysDGrkGzzrNzB5ukOub4nkkZ3ES05pIpesuJulFmlFNqNuheYI+kQSdPJAsous0qS/hLYC/hF07nCl6y4ZWNWJQXmjYqIEUkfJ8vT
PQisiIiHJF0ErI6IRuBZBFyXMmQ2HAZ8S9IoWaOk6yUrDjZmFVL04woRcRNZTu/mc59vef8vY3yu8CUrDjZmVSIxODg42XfREw42ZhUypR/ENLNyOdiYWc9J5F0d3HccbMwqJf9Dlv3GwcasYvpx+4g8HGzMKkSCoSHPRplZjzU2z6ojBxuzKpFno8ysJA42ZtZzwlPfZlYGeerbzEogxNDgtMm+jZ4opL2WYwf3PSRdn67fI2l2EfWa1Y2AQQ3mOvpN18GmaQf3k4G5wGJJc1uKnQVsiYhDgWXAJd3Wa1ZLEgMDg7mOflNEyybPDu4LgKvS65XACVJNFxOYdWlAg7mOflPEmM1YO7gf3a5M2j1sK7A38HQB9ZvVhtBE9iDuK5UaIJY0DAwDDO3Zf5HbrFuSmDY4fbJvoyeKCDZ5dnBvlFkvaQh4PbC59YtSGorlADP2mR6t183qT7VdZ1PEr8qzg/sqYEl6fTpwW8vmymZGI5VLcQPEOWaKz5T0x6ac3mc3XVsi6bF0LGn97ER13bLJuYP7FcB3JK0DniELSGa2CxU2rZ0n13dyfUR8vOWzbwAuBOYBAfwyfXbL7t5PIWM2nXZwj4gXgQ8UUZdZnRX8uMLOmWIASY2Z4jwpWd4D3BIRz6TP3gLMp4tc4PXsHJr1rQmts5mVcnE3juGWL8ub6/v9kh6QtFJSY/w1d57wvCo1G2U21U1wNqpT+t08fgRcGxHbJf0D2Xq447v8zjG5ZWNWIY1uVJ4jh44zxRGxOSK2p7eXA2/P+9mJcrAxq5JiH1foOFMsaf+mt6cCj6TXNwMnpZzfewEnpXO7zd0os0pRYY8i5Jwp/idJpwIjZDPFZ6bPPiPpi2QBC+CixmDx7nKwMauQLCNmcR2OHDPF5wPnt/nsCmBFUffiYGNWKcWts6kaBxuzCpHEkJ+NMrNe8x7EZlYOicE+3BgrDwcbswrJWjYONmbWc/XdYsLBxqxChBga8ACxmfWahNyNMrMyeMzGzHpOiAEcbMysBG7ZmFnPqcAHMavGwcasUsSgPBtlZj0meZ2NmZWkrt2oQkJoN7lpzKyZnOu7nW5y05jZqwkv6htPN7lpzKyF19m0N1Z+maPHKPd+Se8EfgV8MiKebC2Q8t4MA+yjPfnSytMKuL3+cP7pN0z2LZTmNVvrOQBaBKnYZ6MkzQcuJduD+PKIWNpy/VzgbLI9iP8I/H1E/DZd2wGsTUV/FxGndnMvZf1T/xEwOyL+CriFLDfNLiJieUTMi4h5rxv4s5JuzaxKihuzaRriOBmYCyyWNLel2P8C89LfzZXAvzVdeyEijkhHV4EGigk23eSmMbNXycZs8hw57BziiIiXgMYQx04RcXtEbEtv7yb7+9sTRQSbbnLTmFkTkY3Z5DkoLv1uw1nAT5rez0jfe7ekhd3+tq7HbLrJTWNmrSa0qK+I9LtZrdLfAfOAdzWdPjgiNkh6C3CbpLUR8evdraOQRX3d5KYxs1cUPECcK4WupBOBC4B3NQ13EBEb0p+PS7oDOBLY7WDjaQGzihGDuY4c8gxxHAl8Czg1IjY1nd9L0h7p9SzgWLpczuLHFcwqpMinvnMOcfw7sCfwn5LglSnuw4BvSRola5QsHWOh7oQ42JhVSrFbTOQY4jixzed+Dhxe2I3gYGNWOarp6IaDjVnlaLJvoCccbMwqxHsQm1mJ3I0ysxLI3Sgz6z0hbwtqZuVwy8bMes4DxGZWGnejzKzHhAeIzawU8gpiMyuLWzZmVgK3bMysBMq7V03fcbAxq5BsgNgtGzMrgWejzKz3JPDjCmZWhrq2bAoJoZJWSNok6cE21yXpK5LWSXpA0tuKqNesfrJ1NnmOXN8mzZf0aPq799kxru8h6fp0/R5Js5uunZ/OPyrpPd3+sqLaa1cC88e5fjIwJx3DwDcLqtesdorKrpAz/e5ZwJaIOBRYBlySPjuXLBvDW8n+bn9DOdNwtlNIsImIO8mSz7WzALg6MncDM1uy
ZJoZrzyukOd/OXRMv5veX5VerwROUJZmYQFwXURsj4gngHXp+3ZbWSNRudKAShpupBJ9bvSFkm7NrEo0gaOQ9Ls7y0TECLAV2DvnZyekUgPEEbEcWA5w6NAbY5Jvx6x8kY58Cku/W4ayWja50oCaWaDId+SQ5+/dzjKShoDXA5tzfnZCygo2q4Az0qzUMcDWiNhYUt1m/WU059FZx/S76f2S9Pp04LaIiHR+UZqtOoRscud/uvhVxXSjJF0LHEfWh1wPXAhMA4iIy8gy8p1CNsi0DfhIEfWa1VK+VkuOr8mVfvcK4DuS1pFN8ixKn31I0vfI8nuPAOdExI5u7qeQYBMRiztcD+CcIuoyq7UAFThamSP97ovAB9p89mLg4qLupVIDxGbGRAaI+4qDjVnVFNSNqhoHG7OqqWescbAxq5Qg77R233GwMauaesYaBxuzynGwMbNSuBtlZmUocp1NlTjYmFXJxB7E7CsONmaVEjBaz2jjYGNWIaK+3ah6buNuZpXjlo1Z1Xg2ysx6zgPEZlYWeYDYzEpRz1jjYGNWKYGnvs2sDEHUdIDYU99mVVPchudtSXqDpFskPZb+3GuMMkdI+oWkh1La7A81XbtS0hOS1qTjiE51OtiYVUgExGjkOrr0WeDWiJgD3Jret9oGnBERjRS8/yFpZtP1z0TEEelY06lCd6PMKiZ2dNlsyWcBWUYUyNLv3gGc96r7iPhV0+vfS9oE7AM8uzsVFtKykbRC0iZJD7a5fpykrU1Nrs+PVc5symsMEOc5OqffHc++Tbnb/gDsO15hSUcB04FfN52+OHWvlknao1OFRbVsrgS+Blw9Tpm7IuK9BdVnVlMTGiAeN/2upJ8C+41x6YJX1RgRUvsnsiTtD3wHWBIRjWbX+WRBajpZyuzzgIvGu9mi8kbdKWl2Ed/V8NQbn2XZcGvyvvp6zdapM3y27flSugn9q6D/eyLixHbXJD0laf+I2JiCyaY25V4H/Bi4ICLubvruRqtou6RvA5/udD9l/hv+Dkn3S/qJpLeOVUDScKNJ+PI2/wtpU1NE5Dq61Jx2dwnww9YCKWXvD4CrI2Jly7X9058CFgJjDqE0KyvY3AccHBF/DXwVuHGsQhGxPCLmRcS8aa+ZOv+lN9tpYmM23VgKvFvSY8CJ6T2S5km6PJX5IPBO4MwxprivkbQWWAvMAv61U4WlzEZFxHNNr2+S9A1JsyLi6TLqN+snZcxGRcRm4IQxzq8Gzk6vvwt8t83nj59onaUEG0n7AU+lgaijyFpUm8uo26yfRBSyhqaSCgk2kq4lm7OfJWk9cCEwDSAiLgNOBz4maQR4AVgUdV2Tbdatmg5XFjUbtbjD9a+RTY2bWQd1/e+wVxCbVYmf+jazspT0uELpHGzMqiTwmI2ZlcGzUWZWFg8Qm1nPpf1s6sjBxqxqHGzMrNciwrNRZlYOBxsz6z2P2ZhZOdyNMrMyBDDqYGNmPRYRjL68Y7JvoyccbMwqxt0oM+u9GnejvNGvWaXky4bZ7YxVnvS7qdyOpv2HVzWdP0TSPZLWSbo+bY4+LgcbsyqJrBuV5+hSnvS7AC80pdg9ten8JcCyiDgU2AKc1alCBxuzCgkgRkdzHV1aQJZ2l/TnwrwfTOlbjgca6V1yfd5jNmZVEkHkn42aJWl10/vlEbE852fzpt+dkeoYAZZGxI3A3sCzETGSyqwHDuhUoYONWZXEhGajyki/e3BEbJD0FuC2lCtqa94bbNZ1sJF0EFmO733JWoHLI+LSljICLgVOAbYBZ0bEfd3WbVY/UUQXKfumAtLvRsSG9Ofjku4AjgS+D8yUNJRaNwcCGzrdTxFjNiPApyJiLnAMcI6kuS1lTgbmpGMY+GYB9ZrVTwA7It/RnTzpd/eStEd6PQs4Fng4pWG6nSxFU9vPt+o62ETExkYrJSL+BDzCrv23BWT5giMlJ5/ZyBVsZq9W0gBxnvS7hwGrJd1PFlyWRsTD6dp5wLmS1pGN4VzR
qcJCx2wkzSZrZt3TcukA4Mmm940BpY2Y2U5l7WeTM/3uz4HD23z+ceCoidRZWLCRtCdZX+4Tzbm9J/gdw2TdLPZ4/WBRt2bWP4KJzEb1laLS704jCzTXRMQNYxTZABzU9H7MAaU0bbcc4LVvml7PTT3MxlXcAHHVdD1mk2aargAeiYgvtym2CjhDmWOArU1z/GbWUN4K4tIV0bI5FvgwsFbSmnTuc8CbASLiMuAmsmnvdWRT3x8poF6zGqpvy6brYBMRPwPUoUwA53Rbl1ntNaa+a8griM0qJCIYfWmkc8E+5GBjViUTe1yhrzjYmFVJQIw42JhZzzm7gpmVINyyMbNSRDjYmFkJAka3+3EFM+u1kh7EnAwONmYV4jEbMyuHx2zMrCzuRplZ77kbZWZliAhGt9fz2SgnqTOrkjRmk+foRp70u5L+pin17hpJL0pamK5dKemJpmtHdKrTwcasSiqUfjcibm+k3iXLgLkN+K+mIp9pSs27plOFDjZmVZLGbHrdsmHi6XdPB34SEdt2t0IHG7NKKacbRf70uw2LgGtbzl0s6QFJyxr5pcbjAWKzConRCT2uMG6u74LS75JyvB0O3Nx0+nyyIDWdLEnBecBF492sg41ZpUzocYVxc30XkX43+SDwg4h4uem7G62i7ZK+DXy60826G2VWJeWN2XRMv9tkMS1dqEZG25RdZSHwYKcK3bIxq5LyHldYCnxP0lnAb8laL0iaB3w0Is5O72eT5Xz775bPXyNpH7JkB2uAj3aqsOtgI+kg4GqyAaYg6zde2lLmOLLI+UQ6dUNEjNu/M5uKoqQ9iPOk303vf0OWKru13PETrbOIls0I8KmIuE/Sa4FfSrqlKQF5w10R8d4C6jOrNT8b1UYaKNqYXv9J0iNkkbA12JhZBxHByKg3z+oo9e+OBO4Z4/I7JN0P/B74dEQ8NMbnh4Hh9Pb5u77w5KNF3l9Os4CnJ6HeyTKVfu9k/daDJ1J4R7hlMy5JewLfBz4REc+1XL4PODginpd0CnAjMKf1O9IageWt58skafV404l1M5V+bz/81iAYrWmwKWTqW9I0skBzTUTc0Ho9Ip6LiOfT65uAaZJmFVG3Wd2MRuQ6+k0Rs1ECrgAeiYgvtymzH/BUWql4FFmQ29xt3WZ1VNeWTRHdqGOBDwNrJa1J5z4HvBkgIi4je4jrY5JGgBeARRGVDc2T2o2bBFPp91b+t0bUtxul6v6dN5t6Dh3aN778ug/lKrtgy1d/WfUxqGZeQWxWKeHZKDPrvYC+HPzNww9iNpE0X9KjktZJ2mXnsjqRtELSJkkdH6Drd5IOknS7pIclPSTpnyf7ntqKbIA4z9FvHGwSSYPA14GTgbnAYklzJ/eueupKYP5k30RJGo/UzAWOAc6p7j/bqG2wcTfqFUcB6yLicQBJ15FtnVjLxy4i4s604rv2+umRmgA/rjAFHAA82fR+PXD0JN2L9UiHR2omXXiA2Kz/dXikphrCi/qmgg1kmwQ1HJjOWQ10eqSmKuo8G+Vg84p7gTmSDiELMouAv53cW7Ii5Hmkpjrqu4LYs1FJRIwAHyfbQf4R4HtjbYNRF5KuBX4B/IWk9Wl7yLpqPFJzfFMGx1Mm+6bGkrVsPBtVe+mJ9Jsm+z7KEBGLJ/seyhIRPyPbK7fyIoKXazob5ZaNWcWU0bKR9IG0wHE0bXLertyYC10lHSLpnnT+eknTO9XpYGNWMSXtZ/MgcBpwZ7sCHRa6XgIsi4hDgS1Ax264g41ZhURJK4gj4pGI6LTt7s6FrhHxEnAdsCANuB8PrEzl8uQK95iNWZWs59mbz40b8u5iOWO89LsFaLfQdW/g2TSp0ji/S7qXVg42ZhUSEYU9rzZeru+IGC8DZk842JjV1Hi5vnNqt9B1MzBT0lBq3eRaAOsxGzNrZ+dC1zTbtAhYlbb0vZ1su1/onCsccLAxm5IkvU/SeuAdwI8l3ZzOv0nSTdBxoet5wLmS1pGN4VzRsU7v
QWxmZXDLxsxK4WBjZqVwsDGzUjjYmFkpHGzMrBQONmZWCgcbMyvF/wPZqxE8c4CsDAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "cond_coeff_mat = conditional_corrcoeff(\n", + " density=posterior,\n", + " condition=condition,\n", + " limits=torch.tensor([[-2., 2.]]*3),\n", + ")\n", + "fig, ax = plt.subplots(1,1, figsize=(4,4))\n", + "im = plt.imshow(cond_coeff_mat, clim=[-1, 1], cmap='PiYG')\n", + "_ = fig.colorbar(im)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "So far, we have investigated the conditional distribution only at a specific `condition` sampled from the posterior. In many applications, it makes sense to repeat the above analyses with a different `condition` (another sample from the posterior), which can be interpreted as slicing the posterior at a different location. Note that `conditional_corrcoeff()` can directly compute the matrix for several `conditions` and then outputs the average over them. This can be done by passing a batch of $N$ conditions as the `condition` argument." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sampling conditional distributions\n", + "\n", + "So far, we have demonstrated how one can plot 2D conditional distributions with `conditional_pairplot()` and how one can compute the pairwise conditional correlation coefficient with `conditional_corrcoeff()`. In some cases, it can be useful to keep a subset of parameters fixed and to vary **more than two** parameters. This can be done by sampling the conditonal posterior $p(\\theta_i | \\theta_{j \\neq i}, x_o)$. As of `sbi` `v0.18.0`, this functionality requires using the [sampler interface](https://www.mackelab.org/sbi/tutorial/11_sampler_interface/). In this tutorial, we demonstrate this functionality on a linear gaussian simulator with four parameters. We would like to fix the forth parameter to $\\theta_4=0.2$ and sample the first three parameters given that value, i.e. 
we want to sample $p(\\theta_1, \\theta_2, \\theta_3 | \\theta_4 = 0.2, x_o)$. For an application in neuroscience, see [Deistler, Gonçalves, Macke, 2021](https://www.biorxiv.org/content/10.1101/2021.07.30.454484v4.abstract)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this tutorial, we will use SNPE, but the same also works for SNLE and SNRE. First, we define the prior and the simulator and train the deep neural density estimator:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4683b28ba0614e87b33854d207e30da2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Running 1000 simulations.: 0%| | 0/1000 [00:00" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "from sbi.analysis import pairplot\n", + "\n", + "_ = pairplot(cond_samples, limits=[[-2, 2], [-2, 2], [-2, 2], [-2, 2]], figsize=(4, 4))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 713398b63a24c5033761fed3fae90dde61a1f1e9 Mon Sep 17 00:00:00 2001 From: tbmiller-astro Date: Thu, 21 Apr 2022 15:09:29 -0400 Subject: [PATCH 07/15] Trying to fix whitespace issue --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 9d9fc26e5..bff20f8dc 100644 --- a/README.md +++ b/README.md @@ -117,3 +117,4 @@ If you use `sbi` consider citing the [sbi software paper](https://doi.org/10.211 journal = {Journal of Open Source Software} } ``` + From 
0d167662bd7803074f76969b145f08fce8bb04c1 Mon Sep 17 00:00:00 2001 From: tbmiller-astro Date: Thu, 21 Apr 2022 15:12:07 -0400 Subject: [PATCH 08/15] new_branch --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index bff20f8dc..9d9fc26e5 100644 --- a/README.md +++ b/README.md @@ -117,4 +117,3 @@ If you use `sbi` consider citing the [sbi software paper](https://doi.org/10.211 journal = {Journal of Open Source Software} } ``` - From 1767ec89a0684a22a26a2a0eb46a429530e792bf Mon Sep 17 00:00:00 2001 From: tbmiller-astro Date: Thu, 21 Apr 2022 15:18:54 -0400 Subject: [PATCH 09/15] Renormalize - fixes line endings --- README.md | 238 +- sbi/inference/snpe/snpe_a.py | 1718 +- sbi/inference/snpe/snpe_base.py | 1194 +- sbi/inference/snpe/snpe_c.py | 1250 +- tests/linearGaussian_snpe_test.py | 1166 +- tutorials/07_conditional_distributions.ipynb | 33176 ++++++++--------- 6 files changed, 19371 insertions(+), 19371 deletions(-) diff --git a/README.md b/README.md index 9d9fc26e5..96012b76c 100644 --- a/README.md +++ b/README.md @@ -1,119 +1,119 @@ -[![PyPI version](https://badge.fury.io/py/sbi.svg)](https://badge.fury.io/py/sbi) -[![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/mackelab/sbi/blob/master/CONTRIBUTING.md) -[![Tests](https://github.com/mackelab/sbi/workflows/Tests/badge.svg?branch=main)](https://github.com/mackelab/sbi/actions) -[![codecov](https://codecov.io/gh/mackelab/sbi/branch/main/graph/badge.svg)](https://codecov.io/gh/mackelab/sbi) -[![GitHub license](https://img.shields.io/github/license/mackelab/sbi)](https://github.com/mackelab/sbi/blob/master/LICENSE.txt) -[![DOI](https://joss.theoj.org/papers/10.21105/joss.02505/status.svg)](https://doi.org/10.21105/joss.02505) - -## sbi: simulation-based inference -[Getting Started](https://www.mackelab.org/sbi/tutorial/00_getting_started/) | [Documentation](https://www.mackelab.org/sbi/) - -`sbi` is a PyTorch 
package for simulation-based inference. Simulation-based inference is -the process of finding parameters of a simulator from observations. - -`sbi` takes a Bayesian approach and returns a full posterior distribution -over the parameters, conditional on the observations. This posterior can be amortized (i.e. -useful for any observation) or focused (i.e. tailored to a particular observation), with different -computational trade-offs. - -`sbi` offers a simple interface for one-line posterior inference. - -```python -from sbi.inference import infer -# import your simulator, define your prior over the parameters -parameter_posterior = infer(simulator, prior, method='SNPE', num_simulations=100) -``` -See below for the available methods of inference, `SNPE`, `SNRE` and `SNLE`. - - -## Installation - -`sbi` requires Python 3.6 or higher. We recommend to use a [`conda`](https://docs.conda.io/en/latest/miniconda.html) virtual -environment ([Miniconda installation instructions](https://docs.conda.io/en/latest/miniconda.html])). 
If `conda` is installed on the system, an environment for -installing `sbi` can be created as follows: -```commandline -# Create an environment for sbi (indicate Python 3.6 or higher); activate it -$ conda create -n sbi_env python=3.7 && conda activate sbi_env -``` - -Independent of whether you are using `conda` or not, `sbi` can be installed using `pip`: -```commandline -$ pip install sbi -``` - -To test the installation, drop into a python prompt and run -```python -from sbi.examples.minimal import simple -posterior = simple() -print(posterior) -``` - -## Inference Algorithms - -The following algorithms are currently available: - -#### Sequential Neural Posterior Estimation (SNPE) - -* [`SNPE_A`](https://www.mackelab.org/sbi/reference/#sbi.inference.snpe.snpe_a.SNPE_A) from Papamakarios G and Murray I [_Fast ε-free Inference of Simulation Models with Bayesian Conditional Density Estimation_](https://proceedings.neurips.cc/paper/2016/hash/6aca97005c68f1206823815f66102863-Abstract.html) (NeurIPS 2016). - -* [`SNPE_C`](https://www.mackelab.org/sbi/reference/#sbi.inference.snpe.snpe_c.SNPE_C) or `APT` from Greenberg D, Nonnenmacher M, and Macke J [_Automatic - Posterior Transformation for likelihood-free - inference_](https://arxiv.org/abs/1905.07488) (ICML 2019). - - -#### Sequential Neural Likelihood Estimation (SNLE) -* [`SNLE_A`](https://www.mackelab.org/sbi/reference/#sbi.inference.snle.snle_a.SNLE_A) or just `SNL` from Papamakarios G, Sterrat DC and Murray I [_Sequential - Neural Likelihood_](https://arxiv.org/abs/1805.07226) (AISTATS 2019). - - -#### Sequential Neural Ratio Estimation (SNRE) - -* [`SNRE_A`](https://www.mackelab.org/sbi/reference/#sbi.inference.snre.snre_a.SNRE_A) or `AALR` from Hermans J, Begy V, and Louppe G. [_Likelihood-free Inference with Amortized Approximate Likelihood Ratios_](https://arxiv.org/abs/1903.04057) (ICML 2020). 
- -* [`SNRE_B`](https://www.mackelab.org/sbi/reference/#sbi.inference.snre.snre_b.SNRE_B) or `SRE` from Durkan C, Murray I, and Papamakarios G. [_On Contrastive Learning for Likelihood-free Inference_](https://arxiv.org/abs/2002.03712) (ICML 2020). - -#### Sequential Neural Variational Inference (SNVI) - -* [`SNVI`](https://www.mackelab.org/sbi/reference/#sbi.inference.posteriors.vi_posterior) from Glöckler M, Deistler M, Macke J, [_Variational methods for simulation-based inference_](https://openreview.net/forum?id=kZ0UYdhqkNY) (ICLR 2022). - -## Feedback and Contributions - -We would like to hear how `sbi` is working for your inference problems as well as receive bug reports, pull requests and other feedback (see -[contribute](http://www.mackelab.org/sbi/contribute/)). - - -## Acknowledgements - -`sbi` is the successor (using PyTorch) of the -[`delfi`](https://github.com/mackelab/delfi) package. It was started as a fork of Conor -M. Durkan's `lfi`. `sbi` runs as a community project; development is coordinated at the -[mackelab](https://uni-tuebingen.de/en/research/core-research/cluster-of-excellence-machine-learning/research/research/cluster-research-groups/professorships/machine-learning-in-science/). See also [credits](https://github.com/mackelab/sbi/blob/master/docs/docs/credits.md). - - -## Support - -`sbi` has been supported by the German Federal Ministry of Education and Research (BMBF) through the project ADIMEM, FKZ 01IS18052 A-D). [ADIMEM](https://fit.uni-tuebingen.de/Project/Details?id=9199) is a collaborative project between the groups of Jakob Macke (Uni Tübingen), Philipp Berens (Uni Tübingen), Philipp Hennig (Uni Tübingen) and Marcel Oberlaender (caesar Bonn) which aims to develop inference methods for mechanistic models. 
- - -## License - -[Affero General Public License v3 (AGPLv3)](https://www.gnu.org/licenses/) - - -## Citation -If you use `sbi` consider citing the [sbi software paper](https://doi.org/10.21105/joss.02505), in addition to the original research articles describing the specifc sbi-algorithm(s) you are using: - -``` -@article{tejero-cantero2020sbi, - doi = {10.21105/joss.02505}, - url = {https://doi.org/10.21105/joss.02505}, - year = {2020}, - publisher = {The Open Journal}, - volume = {5}, - number = {52}, - pages = {2505}, - author = {Alvaro Tejero-Cantero and Jan Boelts and Michael Deistler and Jan-Matthis Lueckmann and Conor Durkan and Pedro J. Gonçalves and David S. Greenberg and Jakob H. Macke}, - title = {sbi: A toolkit for simulation-based inference}, - journal = {Journal of Open Source Software} -} -``` +[![PyPI version](https://badge.fury.io/py/sbi.svg)](https://badge.fury.io/py/sbi) +[![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/mackelab/sbi/blob/master/CONTRIBUTING.md) +[![Tests](https://github.com/mackelab/sbi/workflows/Tests/badge.svg?branch=main)](https://github.com/mackelab/sbi/actions) +[![codecov](https://codecov.io/gh/mackelab/sbi/branch/main/graph/badge.svg)](https://codecov.io/gh/mackelab/sbi) +[![GitHub license](https://img.shields.io/github/license/mackelab/sbi)](https://github.com/mackelab/sbi/blob/master/LICENSE.txt) +[![DOI](https://joss.theoj.org/papers/10.21105/joss.02505/status.svg)](https://doi.org/10.21105/joss.02505) + +## sbi: simulation-based inference +[Getting Started](https://www.mackelab.org/sbi/tutorial/00_getting_started/) | [Documentation](https://www.mackelab.org/sbi/) + +`sbi` is a PyTorch package for simulation-based inference. Simulation-based inference is +the process of finding parameters of a simulator from observations. 
+ +`sbi` takes a Bayesian approach and returns a full posterior distribution +over the parameters, conditional on the observations. This posterior can be amortized (i.e. +useful for any observation) or focused (i.e. tailored to a particular observation), with different +computational trade-offs. + +`sbi` offers a simple interface for one-line posterior inference. + +```python +from sbi.inference import infer +# import your simulator, define your prior over the parameters +parameter_posterior = infer(simulator, prior, method='SNPE', num_simulations=100) +``` +See below for the available methods of inference, `SNPE`, `SNRE` and `SNLE`. + + +## Installation + +`sbi` requires Python 3.6 or higher. We recommend to use a [`conda`](https://docs.conda.io/en/latest/miniconda.html) virtual +environment ([Miniconda installation instructions](https://docs.conda.io/en/latest/miniconda.html])). If `conda` is installed on the system, an environment for +installing `sbi` can be created as follows: +```commandline +# Create an environment for sbi (indicate Python 3.6 or higher); activate it +$ conda create -n sbi_env python=3.7 && conda activate sbi_env +``` + +Independent of whether you are using `conda` or not, `sbi` can be installed using `pip`: +```commandline +$ pip install sbi +``` + +To test the installation, drop into a python prompt and run +```python +from sbi.examples.minimal import simple +posterior = simple() +print(posterior) +``` + +## Inference Algorithms + +The following algorithms are currently available: + +#### Sequential Neural Posterior Estimation (SNPE) + +* [`SNPE_A`](https://www.mackelab.org/sbi/reference/#sbi.inference.snpe.snpe_a.SNPE_A) from Papamakarios G and Murray I [_Fast ε-free Inference of Simulation Models with Bayesian Conditional Density Estimation_](https://proceedings.neurips.cc/paper/2016/hash/6aca97005c68f1206823815f66102863-Abstract.html) (NeurIPS 2016). 
+ +* [`SNPE_C`](https://www.mackelab.org/sbi/reference/#sbi.inference.snpe.snpe_c.SNPE_C) or `APT` from Greenberg D, Nonnenmacher M, and Macke J [_Automatic + Posterior Transformation for likelihood-free + inference_](https://arxiv.org/abs/1905.07488) (ICML 2019). + + +#### Sequential Neural Likelihood Estimation (SNLE) +* [`SNLE_A`](https://www.mackelab.org/sbi/reference/#sbi.inference.snle.snle_a.SNLE_A) or just `SNL` from Papamakarios G, Sterrat DC and Murray I [_Sequential + Neural Likelihood_](https://arxiv.org/abs/1805.07226) (AISTATS 2019). + + +#### Sequential Neural Ratio Estimation (SNRE) + +* [`SNRE_A`](https://www.mackelab.org/sbi/reference/#sbi.inference.snre.snre_a.SNRE_A) or `AALR` from Hermans J, Begy V, and Louppe G. [_Likelihood-free Inference with Amortized Approximate Likelihood Ratios_](https://arxiv.org/abs/1903.04057) (ICML 2020). + +* [`SNRE_B`](https://www.mackelab.org/sbi/reference/#sbi.inference.snre.snre_b.SNRE_B) or `SRE` from Durkan C, Murray I, and Papamakarios G. [_On Contrastive Learning for Likelihood-free Inference_](https://arxiv.org/abs/2002.03712) (ICML 2020). + +#### Sequential Neural Variational Inference (SNVI) + +* [`SNVI`](https://www.mackelab.org/sbi/reference/#sbi.inference.posteriors.vi_posterior) from Glöckler M, Deistler M, Macke J, [_Variational methods for simulation-based inference_](https://openreview.net/forum?id=kZ0UYdhqkNY) (ICLR 2022). + +## Feedback and Contributions + +We would like to hear how `sbi` is working for your inference problems as well as receive bug reports, pull requests and other feedback (see +[contribute](http://www.mackelab.org/sbi/contribute/)). + + +## Acknowledgements + +`sbi` is the successor (using PyTorch) of the +[`delfi`](https://github.com/mackelab/delfi) package. It was started as a fork of Conor +M. Durkan's `lfi`. 
`sbi` runs as a community project; development is coordinated at the +[mackelab](https://uni-tuebingen.de/en/research/core-research/cluster-of-excellence-machine-learning/research/research/cluster-research-groups/professorships/machine-learning-in-science/). See also [credits](https://github.com/mackelab/sbi/blob/master/docs/docs/credits.md). + + +## Support + +`sbi` has been supported by the German Federal Ministry of Education and Research (BMBF) through the project ADIMEM, FKZ 01IS18052 A-D). [ADIMEM](https://fit.uni-tuebingen.de/Project/Details?id=9199) is a collaborative project between the groups of Jakob Macke (Uni Tübingen), Philipp Berens (Uni Tübingen), Philipp Hennig (Uni Tübingen) and Marcel Oberlaender (caesar Bonn) which aims to develop inference methods for mechanistic models. + + +## License + +[Affero General Public License v3 (AGPLv3)](https://www.gnu.org/licenses/) + + +## Citation +If you use `sbi` consider citing the [sbi software paper](https://doi.org/10.21105/joss.02505), in addition to the original research articles describing the specifc sbi-algorithm(s) you are using: + +``` +@article{tejero-cantero2020sbi, + doi = {10.21105/joss.02505}, + url = {https://doi.org/10.21105/joss.02505}, + year = {2020}, + publisher = {The Open Journal}, + volume = {5}, + number = {52}, + pages = {2505}, + author = {Alvaro Tejero-Cantero and Jan Boelts and Michael Deistler and Jan-Matthis Lueckmann and Conor Durkan and Pedro J. Gonçalves and David S. Greenberg and Jakob H. Macke}, + title = {sbi: A toolkit for simulation-based inference}, + journal = {Journal of Open Source Software} +} +``` diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py index 664e62fff..d480e6997 100644 --- a/sbi/inference/snpe/snpe_a.py +++ b/sbi/inference/snpe/snpe_a.py @@ -1,859 +1,859 @@ -# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed -# under the Affero General Public License v3, see . 
- -import warnings -from copy import deepcopy -from functools import partial -from typing import Any, Callable, Dict, Optional, Union - -import torch -import torch.nn as nn -from pyknos.mdn.mdn import MultivariateGaussianMDN -from pyknos.nflows import flows -from pyknos.nflows.transforms import CompositeTransform -from torch import Tensor -from torch.distributions import Distribution, MultivariateNormal - -import sbi.utils as utils -from sbi.inference.posteriors.direct_posterior import DirectPosterior -from sbi.inference.snpe.snpe_base import PosteriorEstimator -from sbi.types import TensorboardSummaryWriter, TorchModule -from sbi.utils import torchutils - - -class SNPE_A(PosteriorEstimator): - def __init__( - self, - prior: Optional[Distribution] = None, - density_estimator: Union[str, Callable] = "mdn_snpe_a", - num_components: int = 10, - device: str = "cpu", - logging_level: Union[int, str] = "WARNING", - summary_writer: Optional[TensorboardSummaryWriter] = None, - show_progress_bars: bool = True, - ): - r"""SNPE-A [1]. - - [1] _Fast epsilon-free Inference of Simulation Models with Bayesian Conditional - Density Estimation_, Papamakarios et al., NeurIPS 2016, - https://arxiv.org/abs/1605.06376. - - This class implements SNPE-A. SNPE-A trains across multiple rounds with a - maximum-likelihood-loss. This will make training converge to the proposal - posterior instead of the true posterior. To correct for this, SNPE-A applies a - post-hoc correction after training. This correction has to be performed - analytically. Thus, SNPE-A is limited to Gaussian distributions for all but the - last round. In the last round, SNPE-A can use a Mixture of Gaussians. - - Args: - prior: A probability distribution that expresses prior knowledge about the - parameters, e.g. which ranges are meaningful for them. Any - object with `.log_prob()`and `.sample()` (for example, a PyTorch - distribution) can be used. 
- density_estimator: If it is a string (only "mdn_snpe_a" is valid), use a - pre-configured mixture of densities network. Alternatively, a function - that builds a custom neural network can be provided. The function will - be called with the first batch of simulations (theta, x), which can - thus be used for shape inference and potentially for z-scoring. It - needs to return a PyTorch `nn.Module` implementing the density - estimator. The density estimator needs to provide the methods - `.log_prob` and `.sample()`. Note that until the last round only a - single (multivariate) Gaussian component is used for training (see - Algorithm 1 in [1]). In the last round, this component is replicated - `num_components` times, its parameters are perturbed with a very small - noise, and then the last training round is done with the expanded - Gaussian mixture as estimator for the proposal posterior. - num_components: Number of components of the mixture of Gaussians in the - last round. This overrides the `num_components` value passed to - `posterior_nn()`. - device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}". - logging_level: Minimum severity of messages to log. One of the strings - INFO, WARNING, DEBUG, ERROR and CRITICAL. - summary_writer: A tensorboard `SummaryWriter` to control, among others, log - file location (default is `/logs`.) - show_progress_bars: Whether to show a progressbar during training. - """ - - # Catch invalid inputs. - if not ((density_estimator == "mdn_snpe_a") or callable(density_estimator)): - raise TypeError( - "The `density_estimator` passed to SNPE_A needs to be a " - "callable or the string 'mdn_snpe_a'!" - ) - - # `num_components` will be used to replicate the Gaussian in the last round. - self._num_components = num_components - self._ran_final_round = False - - # WARNING: sneaky trick ahead. We proxy the parent's `train` here, - # requiring the signature to have `num_atoms`, save it for use below, and - # continue. 
It's sneaky because we are using the object (self) as a namespace - # to pass arguments between functions, and that's implicit state management. - kwargs = utils.del_entries( - locals(), - entries=("self", "__class__", "num_components"), - ) - super().__init__(**kwargs) - - def train( - self, - final_round: bool = False, - training_batch_size: int = 50, - learning_rate: float = 5e-4, - validation_fraction: float = 0.1, - stop_after_epochs: int = 20, - max_num_epochs: int = 2**31 - 1, - clip_max_norm: Optional[float] = 5.0, - calibration_kernel: Optional[Callable] = None, - resume_training: bool = False, - force_first_round_loss: bool = False, - retrain_from_scratch: bool = False, - show_train_summary: bool = False, - dataloader_kwargs: Optional[Dict] = None, - component_perturbation: float = 5e-3, - ) -> nn.Module: - r"""Return density estimator that approximates the proposal posterior. - - [1] _Fast epsilon-free Inference of Simulation Models with Bayesian Conditional - Density Estimation_, Papamakarios et al., NeurIPS 2016, - https://arxiv.org/abs/1605.06376. - - Training is performed with maximum likelihood on samples from the latest round, - which leads the algorithm to converge to the proposal posterior. - - Args: - final_round: Whether we are in the last round of training or not. For all - but the last round, Algorithm 1 from [1] is executed. In last the - round, Algorithm 2 from [1] is executed once. - training_batch_size: Training batch size. - learning_rate: Learning rate for Adam optimizer. - validation_fraction: The fraction of data to use for validation. - stop_after_epochs: The number of epochs to wait for improvement on the - validation set before terminating training. - max_num_epochs: Maximum number of epochs to run. If reached, we stop - training even when the validation loss is still decreasing. Otherwise, - we train until validation loss increases (see also `stop_after_epochs`). 
- clip_max_norm: Value at which to clip the total gradient norm in order to - prevent exploding gradients. Use None for no clipping. - calibration_kernel: A function to calibrate the loss with respect to the - simulations `x`. See Lueckmann, Gonçalves et al., NeurIPS 2017. - resume_training: Can be used in case training time is limited, e.g. on a - cluster. If `True`, the split between train and validation set, the - optimizer, the number of epochs, and the best validation log-prob will - be restored from the last time `.train()` was called. - force_first_round_loss: If `True`, train with maximum likelihood, - i.e., potentially ignoring the correction for using a proposal - distribution different from the prior. - force_first_round_loss: If `True`, train with maximum likelihood, - regardless of the proposal distribution. - retrain_from_scratch: Whether to retrain the conditional density - estimator for the posterior from scratch each round. Not supported for - SNPE-A. - show_train_summary: Whether to print the number of epochs and validation - loss and leakage after the training. - dataloader_kwargs: Additional or updated kwargs to be passed to the training - and validation dataloaders (like, e.g., a collate_fn) - component_perturbation: The standard deviation applied to all weights and - biases when, in the last round, the Mixture of Gaussians is build from - a single Gaussian. This value can be problem-specific and also depends - on the number of mixture components. - - Returns: - Density estimator that approximates the distribution $p(\theta|x)$. - """ - - assert not retrain_from_scratch, """Retraining from scratch is not supported in SNPE-A yet. The reason for - this is that, if we reininitialized the density estimator, the z-scoring would - change, which would break the posthoc correction. 
This is a pure implementation - issue.""" - - kwargs = utils.del_entries( - locals(), - entries=("self", "__class__", "final_round", "component_perturbation"), - ) - - # SNPE-A always discards the prior samples. - kwargs["discard_prior_samples"] = True - - self._round = max(self._data_round_index) - - if final_round: - # If there is (will be) only one round, train with Algorithm 2 from [1]. - if self._round == 0: - self._build_neural_net = partial( - self._build_neural_net, num_components=self._num_components - ) - # Run Algorithm 2 from [1]. - elif not self._ran_final_round: - # Now switch to the specified number of components. This method will - # only be used if `retrain_from_scratch=True`. Otherwise, - # the MDN will be built from replicating the single-component net for - # `num_component` times (via `_expand_mog()`). - self._build_neural_net = partial( - self._build_neural_net, num_components=self._num_components - ) - - # Extend the MDN to the originally desired number of components. - self._expand_mog(eps=component_perturbation) - else: - warnings.warn( - "You have already run SNPE-A with `final_round=True`. Running it" - "again with this setting will not allow computing the posthoc" - "correction applied in SNPE-A. Thus, you will get an error when " - "calling `.build_posterior()` after training.", - UserWarning, - ) - else: - # Run Algorithm 1 from [1]. - # Wrap the function that builds the MDN such that we can make - # sure that there is only one component when running. - self._build_neural_net = partial(self._build_neural_net, num_components=1) - - if final_round: - self._ran_final_round = True - - return super().train(**kwargs) - - def correct_for_proposal( - self, - density_estimator: Optional[TorchModule] = None, - ) -> "SNPE_A_MDN": - r"""Build mixture of Gaussians that approximates the posterior. - - Returns a `SNPE_A_MDN` object, which applies the posthoc-correction required in - SNPE-A. 
- - Args: - density_estimator: The density estimator that the posterior is based on. - If `None`, use the latest neural density estimator that was trained. - - Returns: - Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods. - """ - if density_estimator is None: - density_estimator = deepcopy( - self._neural_net - ) # PosteriorEstimator.train() also returns a deepcopy, mimic this here - # If internal net is used device is defined. - device = self._device - else: - # Otherwise, infer it from the device of the net parameters. - device = str(next(density_estimator.parameters()).device) - - # Set proposal of the density estimator. - # This also evokes the z-scoring correction if necessary. - if ( - self._proposal_roundwise[-1] is self._prior - or self._proposal_roundwise[-1] is None - ): - proposal = self._prior - assert isinstance( - proposal, (MultivariateNormal, utils.BoxUniform) - ), """Prior must be `torch.distributions.MultivariateNormal` or `sbi.utils. - BoxUniform`""" - else: - assert isinstance( - self._proposal_roundwise[-1], DirectPosterior - ), """The proposal you passed to `append_simulations` is neither the prior - nor a `DirectPosterior`. SNPE-A currently only supports these scenarios. - """ - proposal = self._proposal_roundwise[-1] - - # Create the SNPE_A_MDN - wrapped_density_estimator = SNPE_A_MDN( - flow=density_estimator, proposal=proposal, prior=self._prior, device=device - ) - return wrapped_density_estimator - - def build_posterior( - self, - density_estimator: Optional[TorchModule] = None, - prior: Optional[Distribution] = None, - ) -> "DirectPosterior": - r"""Build posterior from the neural density estimator. - - This method first corrects the estimated density with `correct_for_proposal` - and then returns a `DirectPosterior`. - - Args: - density_estimator: The density estimator that the posterior is based on. - If `None`, use the latest neural density estimator that was trained. - prior: Prior distribution. 
- - Returns: - Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods. - """ - if prior is None: - assert ( - self._prior is not None - ), """You did not pass a prior. You have to pass the prior either at - initialization `inference = SNPE_A(prior)` or to `.build_posterior - (prior=prior)`.""" - prior = self._prior - - wrapped_density_estimator = self.correct_for_proposal( - density_estimator=density_estimator - ) - self._posterior = DirectPosterior( - posterior_estimator=wrapped_density_estimator, - prior=prior, - ) - return deepcopy(self._posterior) - - def _log_prob_proposal_posterior( - self, - theta: Tensor, - x: Tensor, - masks: Tensor, - proposal: Optional[Any], - ) -> Tensor: - """Return the log-probability of the proposal posterior. - - For SNPE-A this is the same as `self._neural_net.log_prob(theta, x)` in - `_loss()` to be found in `snpe_base.py`. - - Args: - theta: Batch of parameters θ. - x: Batch of data. - masks: Mask that is True for prior samples in the batch in order to train - them with prior loss. - proposal: Proposal distribution. - - Returns: Log-probability of the proposal posterior. - """ - return self._neural_net.log_prob(theta, x) - - def _expand_mog(self, eps: float = 1e-5): - """ - Replicate a singe Gaussian trained with Algorithm 1 before continuing - with Algorithm 2. The weights and biases of the associated MDN layers - are repeated `num_components` times, slightly perturbed to break the - symmetry such that the gradients in the subsequent training are not - all identical. - - Args: - eps: Standard deviation for the random perturbation. - """ - assert isinstance(self._neural_net._distribution, MultivariateGaussianMDN) - - # Increase the number of components - self._neural_net._distribution._num_components = self._num_components - - # Expand the 1-dim Gaussian. 
- for name, param in self._neural_net.named_parameters(): - if any( - key in name for key in ["logits", "means", "unconstrained", "upper"] - ): - if "bias" in name: - param.data = param.data.repeat(self._num_components) - param.data.add_(torch.randn_like(param.data) * eps) - param.grad = None # let autograd construct a new gradient - elif "weight" in name: - param.data = param.data.repeat(self._num_components, 1) - param.data.add_(torch.randn_like(param.data) * eps) - param.grad = None # let autograd construct a new gradient - - -class SNPE_A_MDN(nn.Module): - """Generates a posthoc-corrected MDN which approximates the posterior. - - This class takes as input the density estimator (abbreviated with `_d` suffix, aka - the proposal posterior) and the proposal prior (abbreviated with `_pp` suffix) from - which the simulations were drawn. It uses the algorithm presented in SNPE-A [1] to - compute the approximate posterior (abbreviated with `_p` suffix) from the two. The - approximate posterior is a MoG. This class also implements log-prob calculation - sampling from the approximate posterior. It inherits from `nn.Module` since the - constructor of `DirectPosterior` expects the argument `neural_net` to be a - `nn.Module`. - - [1] _Fast epsilon-free Inference of Simulation Models with Bayesian Conditional - Density Estimation_, Papamakarios et al., NeurIPS 2016, - https://arxiv.org/abs/1605.06376. - """ - - def __init__( - self, - flow: flows.Flow, - proposal: Union["utils.BoxUniform", "MultivariateNormal", "DirectPosterior"], - prior: Distribution, - device: str, - ): - """Constructor. - - Args: - flow: The trained normalizing flow, passed when building the posterior. - proposal: The proposal distribution. - prior: The prior distribution. - """ - # Call nn.Module's constructor. - super().__init__() - - self._neural_net = flow - self._prior = prior - self._device = device - - # Set the proposal using the `default_x`. 
- if isinstance(proposal, (utils.BoxUniform, MultivariateNormal)): - self._apply_correction = False - else: - self._apply_correction = True - logits_pp, m_pp, prec_pp = proposal.posterior_estimator._posthoc_correction( - proposal.default_x - ) - self._logits_pp, self._m_pp, self._prec_pp = ( - logits_pp.detach(), - m_pp.detach(), - prec_pp.detach(), - ) - - # Take care of z-scoring, pre-compute and store prior terms. - self._set_state_for_mog_proposal() - - def log_prob(self, inputs: Tensor, context: Tensor) -> Tensor: - inputs, context = inputs.to(self._device), context.to(self._device) - - if not self._apply_correction: - return self._neural_net.log_prob(inputs, context) - else: - # When we want to compute the approx. posterior, a proposal prior \tilde{p} - # has already been observed. To analytically calculate the log-prob of the - # Gaussian, we first need to compute the mixture components. - - # Compute the mixture components of the proposal posterior. - logits_pp, m_pp, prec_pp = self._posthoc_correction(context) - - # z-score theta if it z-scoring had been requested. - theta = self._maybe_z_score_theta(inputs) - - # Compute the log_prob of theta under the product. - log_prob_proposal_posterior = utils.mog_log_prob( - theta, logits_pp, m_pp, prec_pp - ) - utils.assert_all_finite( - log_prob_proposal_posterior, "proposal posterior eval" - ) - return log_prob_proposal_posterior # \hat{p} from eq (3) in [1] - - def sample(self, num_samples: int, context: Tensor, batch_size: int = 1) -> Tensor: - context = context.to(self._device) - - if not self._apply_correction: - return self._neural_net.sample(num_samples, context, batch_size) - else: - # When we want to sample from the approx. posterior, a proposal prior - # \tilde{p} has already been observed. To analytically calculate the - # log-prob of the Gaussian, we first need to compute the mixture components. 
- return self._sample_approx_posterior_mog(num_samples, context, batch_size) - - def _sample_approx_posterior_mog( - self, num_samples, x: Tensor, batch_size: int - ) -> Tensor: - r"""Sample from the approximate posterior. - - Args: - num_samples: Desired number of samples. - x: Conditioning context for posterior $p(\theta|x)$. - batch_size: Batch size for sampling. - - Returns: - Samples from the approximate mixture of Gaussians posterior. - """ - - # Compute the mixture components of the posterior. - logits_p, m_p, prec_p = self._posthoc_correction(x) - - # Compute the precision factors which represent the upper triangular matrix - # of the cholesky decomposition of the prec_p. - prec_factors_p = torch.linalg.cholesky(prec_p, upper=True) - - assert logits_p.ndim == 2 - assert m_p.ndim == 3 - assert prec_p.ndim == 4 - assert prec_factors_p.ndim == 4 - - # Replicate to use batched sampling from pyknos. - if batch_size is not None and batch_size > 1: - logits_p = logits_p.repeat(batch_size, 1) - m_p = m_p.repeat(batch_size, 1, 1) - prec_factors_p = prec_factors_p.repeat(batch_size, 1, 1, 1) - - # Get (optionally z-scored) MoG samples. - theta = MultivariateGaussianMDN.sample_mog( - num_samples, logits_p, m_p, prec_factors_p - ) - - embedded_context = self._neural_net._embedding_net(x) - if embedded_context is not None: - # Merge the context dimension with sample dimension in order to - # apply the transform. - theta = torchutils.merge_leading_dims(theta, num_dims=2) - embedded_context = torchutils.repeat_rows( - embedded_context, num_reps=num_samples - ) - - theta, _ = self._neural_net._transform.inverse(theta, context=embedded_context) - - if embedded_context is not None: - # Split the context dimension from sample dimension. 
- theta = torchutils.split_leading_dim(theta, shape=[-1, num_samples]) - - return theta - - def _posthoc_correction(self, x: Tensor): - """ - Compute the mixture components of the posterior given the current density - estimator and the proposal. - - Args: - x: Conditioning context for posterior. - - Returns: - Mixture components of the posterior. - """ - - # Evaluate the density estimator. - encoded_x = self._neural_net._embedding_net(x) - dist = self._neural_net._distribution # defined to avoid black formatting. - logits_d, m_d, prec_d, _, _ = dist.get_mixture_components(encoded_x) - norm_logits_d = logits_d - torch.logsumexp(logits_d, dim=-1, keepdim=True) - - # The following if case is needed because, in the constructor, we call - # `_posthoc_correction` regardless of whether the `proposal` itself had a - # `proposal` or not. - if not self._apply_correction: - return norm_logits_d, m_d, prec_d - else: - logits_pp, m_pp, prec_pp = self._logits_pp, self._m_pp, self._prec_pp - - # Compute the MoG parameters of the posterior. - logits_p, m_p, prec_p, cov_p = self._proposal_posterior_transformation( - logits_pp, m_pp, prec_pp, norm_logits_d, m_d, prec_d - ) - return logits_p, m_p, prec_p - - def _proposal_posterior_transformation( - self, - logits_pp: Tensor, - means_pp: Tensor, - precisions_pp: Tensor, - logits_d: Tensor, - means_d: Tensor, - precisions_d: Tensor, - ): - r"""Transforms the proposal posterior (the MDN) into the posterior. - - The approximate posterior is: - $p(\theta|x) = 1/Z * q(\theta|x) * p(\theta) / prop(\theta)$ - In words: posterior = proposal posterior estimate * prior / proposal. - - Since the proposal posterior estimate and the proposal are MoG, and the - prior is either Gaussian or uniform, we can solve this in closed-form. - - This function implements Appendix C from [1], and is highly similar to - `SNPE_C._automatic_posterior_transformation()`. - - Args: - logits_pp: Component weight of each Gaussian of the proposal prior. 
- means_pp: Mean of each Gaussian of the proposal prior. - precisions_pp: Precision matrix of each Gaussian of the proposal prior. - logits_d: Component weight for each Gaussian of the density estimator. - means_d: Mean of each Gaussian of the density estimator. - precisions_d: Precision matrix of each Gaussian of the density estimator. - - Returns: (Component weight, mean, precision matrix, covariance matrix) of each - Gaussian of the approximate posterior. - """ - - precisions_post, covariances_post = self._precisions_posterior( - precisions_pp, precisions_d - ) - - means_post = self._means_posterior( - covariances_post, means_pp, precisions_pp, means_d, precisions_d - ) - - logits_post = SNPE_A_MDN._logits_posterior( - means_post, - precisions_post, - covariances_post, - logits_pp, - means_pp, - precisions_pp, - logits_d, - means_d, - precisions_d, - ) - - return logits_post, means_post, precisions_post, covariances_post - - def _set_state_for_mog_proposal(self) -> None: - """ - Set state variables of the SNPE_A_MDN instance every time `set_proposal()` - is called, i.e. every time a posterior is build using - `SNPE_A.build_posterior()`. - - This function is almost identical to `SNPE_C._set_state_for_mog_proposal()`. - - Three things are computed: - 1) Check if z-scoring was requested. To do so, we check if the `_transform` - argument of the net had been a `CompositeTransform`. See pyknos mdn.py. - 2) Define a (potentially standardized) prior. It's standardized if z-scoring - had been requested. - 3) Compute (Precision * mean) for the prior. This quantity is used at every - training step if the prior is Gaussian. 
- """ - - self.z_score_theta = isinstance(self._neural_net._transform, CompositeTransform) - - self._set_maybe_z_scored_prior() - - if isinstance(self._maybe_z_scored_prior, MultivariateNormal): - self.prec_m_prod_prior = torch.mv( - self._maybe_z_scored_prior.precision_matrix, # type: ignore - self._maybe_z_scored_prior.loc, # type: ignore - ) - - def _set_maybe_z_scored_prior(self) -> None: - r""" - Compute and store potentially standardized prior (if z-scoring was requested). - - This function is highly similar to `SNPE_C._set_maybe_z_scored_prior()`. - - The proposal posterior is: - $p(\theta|x) = 1/Z * q(\theta|x) * prop(\theta) / p(\theta)$ - - Let's denote z-scored theta by `a`: a = (theta - mean) / std - Then $p'(a|x) = 1/Z_2 * q'(a|x) * prop'(a) / p'(a)$ - - The ' indicates that the evaluation occurs in standardized space. The constant - scaling factor has been absorbed into $Z_2$. - From the above equation, we see that we need to evaluate the prior **in - standardized space**. We build the standardized prior in this function. - - The standardize transform that is applied to the samples theta does not use - the exact prior mean and std (due to implementation issues). Hence, the z-scored - prior will not be exactly have mean=0 and std=1. - """ - if self.z_score_theta: - scale = self._neural_net._transform._transforms[0]._scale - shift = self._neural_net._transform._transforms[0]._shift - - # Following the definition of the linear transform in - # `standardizing_transform` in `sbiutils.py`: - # shift=-mean / std - # scale=1 / std - # Solving these equations for mean and std: - estim_prior_std = 1 / scale - estim_prior_mean = -shift * estim_prior_std - - # Compute the discrepancy of the true prior mean and std and the mean and - # std that was empirically estimated from samples. - # N(theta|m,s) = N((theta-m_e)/s_e|(m-m_e)/s_e, s/s_e) - # Above: m,s are true prior mean and std. 
m_e,s_e are estimated prior mean - # and std (estimated from samples and used to build standardize transform). - almost_zero_mean = (self._prior.mean - estim_prior_mean) / estim_prior_std - almost_one_std = torch.sqrt(self._prior.variance) / estim_prior_std - - if isinstance(self._prior, MultivariateNormal): - self._maybe_z_scored_prior = MultivariateNormal( - almost_zero_mean, torch.diag(almost_one_std) - ) - else: - range_ = torch.sqrt(almost_one_std * 3.0) - self._maybe_z_scored_prior = utils.BoxUniform( - almost_zero_mean - range_, almost_zero_mean + range_ - ) - else: - self._maybe_z_scored_prior = self._prior - - def _maybe_z_score_theta(self, theta: Tensor) -> Tensor: - """Return potentially standardized theta if z-scoring was requested.""" - - if self.z_score_theta: - theta, _ = self._neural_net._transform(theta) - - return theta - - def _precisions_posterior(self, precisions_pp: Tensor, precisions_d: Tensor): - r"""Return the precisions and covariances of the MoG posterior. - - As described at the end of Appendix C in [1], it can happen that the - proposal's precision matrix is not positive definite. - - $S_k^\prime = ( S_k^{-1} - S_0^{-1} )^{-1}$ - (see eq (23) in Appendix C of [1]) - - Args: - precisions_pp: Precision matrices of the proposal prior. - precisions_d: Precision matrices of the density estimator. - - Returns: (Precisions, Covariances) of the MoG posterior. - """ - - num_comps_p = precisions_pp.shape[1] - num_comps_d = precisions_d.shape[1] - - # Check if precision matrices are positive definite. - for batches in precisions_pp: - for pprior in batches: - eig_pprior = torch.linalg.eigvalsh(pprior, UPLO="U") - if not (eig_pprior > 0).all(): - raise AssertionError( - "The precision matrix of the proposal is not positive definite!" 
- ) - for batches in precisions_d: - for d in batches: - eig_d = torch.linalg.eigvalsh(d, UPLO="U") - if not (eig_d > 0).all(): - raise AssertionError( - "The precision matrix of the density estimator is not " - "positive definite!" - ) - - precisions_pp_rep = precisions_pp.repeat_interleave(num_comps_d, dim=1) - precisions_d_rep = precisions_d.repeat(1, num_comps_p, 1, 1) - - precisions_p = precisions_d_rep - precisions_pp_rep - if isinstance(self._maybe_z_scored_prior, MultivariateNormal): - precisions_p += self._maybe_z_scored_prior.precision_matrix - - # Check if precision matrix is positive definite. - for idx_batch, batches in enumerate(precisions_p): - for idx_comp, pp in enumerate(batches): - eig_pp = torch.symeig(pp, eigenvectors=False).eigenvalues - if not (eig_pp > 0).all(): - raise AssertionError( - "The precision matrix of a posterior is not positive " - "definite! This is a known issue for SNPE-A. Either try a " - "different parameter setting, e.g. a different number of " - "mixture components (when contracting SNPE-A), or a different " - "value for the parameter perturbation (when building the " - "posterior)." - ) - - covariances_p = torch.inverse(precisions_p) - return precisions_p, covariances_p - - def _means_posterior( - self, - covariances_p: Tensor, - means_pp: Tensor, - precisions_pp: Tensor, - means_d: Tensor, - precisions_d: Tensor, - ): - r"""Return the means of the MoG posterior. - - $m_k^\prime = S_k^\prime ( S_k^{-1} m_k - S_0^{-1} m_0 )$ - (see eq (24) in Appendix C of [1]) - - Args: - covariances_post: Covariance matrices of the MoG posterior. - means_pp: Means of the proposal prior. - precisions_pp: Precision matrices of the proposal prior. - means_d: Means of the density estimator. - precisions_d: Precision matrices of the density estimator. - - Returns: Means of the MoG posterior. - """ - - num_comps_pp = precisions_pp.shape[1] - num_comps_d = precisions_d.shape[1] - - # Compute the products P_k * m_k and P_0 * m_0. 
- prec_m_prod_pp = utils.batched_mixture_mv(precisions_pp, means_pp) - prec_m_prod_d = utils.batched_mixture_mv(precisions_d, means_d) - - # Repeat them to allow for matrix operations: same trick as for the precisions. - prec_m_prod_pp_rep = prec_m_prod_pp.repeat_interleave(num_comps_d, dim=1) - prec_m_prod_d_rep = prec_m_prod_d.repeat(1, num_comps_pp, 1) - - # Compute the means P_k^prime * (P_k * m_k - P_0 * m_0). - summed_cov_m_prod_rep = prec_m_prod_d_rep - prec_m_prod_pp_rep - if isinstance(self._maybe_z_scored_prior, MultivariateNormal): - summed_cov_m_prod_rep += self.prec_m_prod_prior - - means_p = utils.batched_mixture_mv(covariances_p, summed_cov_m_prod_rep) - return means_p - - @staticmethod - def _logits_posterior( - means_post: Tensor, - precisions_post: Tensor, - covariances_post: Tensor, - logits_pp: Tensor, - means_pp: Tensor, - precisions_pp: Tensor, - logits_d: Tensor, - means_d: Tensor, - precisions_d: Tensor, - ): - r"""Return the component weights (i.e. logits) of the MoG posterior. - - $\alpha_k^\prime = \frac{ \alpha_k exp(-0.5 c_k) }{ \sum{j} \alpha_j exp(-0.5 - c_j) } $ - with - $c_k = logdet(S_k) - logdet(S_0) - logdet(S_k^\prime) + - + m_k^T P_k m_k - m_0^T P_0 m_0 - m_k^\prime^T P_k^\prime m_k^\prime$ - (see eqs. (25, 26) in Appendix C of [1]) - - Args: - means_post: Means of the posterior. - precisions_post: Precision matrices of the posterior. - covariances_post: Covariance matrices of the posterior. - logits_pp: Component weights (i.e. logits) of the proposal prior. - means_pp: Means of the proposal prior. - precisions_pp: Precision matrices of the proposal prior. - logits_d: Component weights (i.e. logits) of the density estimator. - means_d: Means of the density estimator. - precisions_d: Precision matrices of the density estimator. - - Returns: Component weights of the proposal posterior. 
- """ - - num_comps_pp = precisions_pp.shape[1] - num_comps_d = precisions_d.shape[1] - - # Compute the ratio of the logits similar to eq (10) in Appendix A.1 of [2] - logits_pp_rep = logits_pp.repeat_interleave(num_comps_d, dim=1) - logits_d_rep = logits_d.repeat(1, num_comps_pp) - logit_factors = logits_d_rep - logits_pp_rep - - # Compute the log-determinants - logdet_covariances_post = torch.logdet(covariances_post) - logdet_covariances_pp = -torch.logdet(precisions_pp) - logdet_covariances_d = -torch.logdet(precisions_d) - - # Repeat the proposal and density estimator terms such that there are LK terms. - # Same trick as has been used above. - logdet_covariances_pp_rep = logdet_covariances_pp.repeat_interleave( - num_comps_d, dim=1 - ) - logdet_covariances_d_rep = logdet_covariances_d.repeat(1, num_comps_pp) - - log_sqrt_det_ratio = 0.5 * ( # similar to eq (14) in Appendix A.1 of [2] - logdet_covariances_post - + logdet_covariances_pp_rep - - logdet_covariances_d_rep - ) - - # Compute for proposal, density estimator, and proposal posterior: - exponent_pp = utils.batched_mixture_vmv( - precisions_pp, means_pp # m_0 in eq (26) in Appendix C of [1] - ) - exponent_d = utils.batched_mixture_vmv( - precisions_d, means_d # m_k in eq (26) in Appendix C of [1] - ) - exponent_post = utils.batched_mixture_vmv( - precisions_post, means_post # m_k^\prime in eq (26) in Appendix C of [1] - ) - - # Extend proposal and density estimator exponents to get LK terms. - exponent_pp_rep = exponent_pp.repeat_interleave(num_comps_d, dim=1) - exponent_d_rep = exponent_d.repeat(1, num_comps_pp) - exponent = -0.5 * ( - exponent_d_rep - exponent_pp_rep - exponent_post # eq (26) in [1] - ) - - logits_post = logit_factors + log_sqrt_det_ratio + exponent - return logits_post +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Affero General Public License v3, see . 
+ +import warnings +from copy import deepcopy +from functools import partial +from typing import Any, Callable, Dict, Optional, Union + +import torch +import torch.nn as nn +from pyknos.mdn.mdn import MultivariateGaussianMDN +from pyknos.nflows import flows +from pyknos.nflows.transforms import CompositeTransform +from torch import Tensor +from torch.distributions import Distribution, MultivariateNormal + +import sbi.utils as utils +from sbi.inference.posteriors.direct_posterior import DirectPosterior +from sbi.inference.snpe.snpe_base import PosteriorEstimator +from sbi.types import TensorboardSummaryWriter, TorchModule +from sbi.utils import torchutils + + +class SNPE_A(PosteriorEstimator): + def __init__( + self, + prior: Optional[Distribution] = None, + density_estimator: Union[str, Callable] = "mdn_snpe_a", + num_components: int = 10, + device: str = "cpu", + logging_level: Union[int, str] = "WARNING", + summary_writer: Optional[TensorboardSummaryWriter] = None, + show_progress_bars: bool = True, + ): + r"""SNPE-A [1]. + + [1] _Fast epsilon-free Inference of Simulation Models with Bayesian Conditional + Density Estimation_, Papamakarios et al., NeurIPS 2016, + https://arxiv.org/abs/1605.06376. + + This class implements SNPE-A. SNPE-A trains across multiple rounds with a + maximum-likelihood-loss. This will make training converge to the proposal + posterior instead of the true posterior. To correct for this, SNPE-A applies a + post-hoc correction after training. This correction has to be performed + analytically. Thus, SNPE-A is limited to Gaussian distributions for all but the + last round. In the last round, SNPE-A can use a Mixture of Gaussians. + + Args: + prior: A probability distribution that expresses prior knowledge about the + parameters, e.g. which ranges are meaningful for them. Any + object with `.log_prob()`and `.sample()` (for example, a PyTorch + distribution) can be used. 
+ density_estimator: If it is a string (only "mdn_snpe_a" is valid), use a + pre-configured mixture of densities network. Alternatively, a function + that builds a custom neural network can be provided. The function will + be called with the first batch of simulations (theta, x), which can + thus be used for shape inference and potentially for z-scoring. It + needs to return a PyTorch `nn.Module` implementing the density + estimator. The density estimator needs to provide the methods + `.log_prob` and `.sample()`. Note that until the last round only a + single (multivariate) Gaussian component is used for training (see + Algorithm 1 in [1]). In the last round, this component is replicated + `num_components` times, its parameters are perturbed with a very small + noise, and then the last training round is done with the expanded + Gaussian mixture as estimator for the proposal posterior. + num_components: Number of components of the mixture of Gaussians in the + last round. This overrides the `num_components` value passed to + `posterior_nn()`. + device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}". + logging_level: Minimum severity of messages to log. One of the strings + INFO, WARNING, DEBUG, ERROR and CRITICAL. + summary_writer: A tensorboard `SummaryWriter` to control, among others, log + file location (default is `/logs`.) + show_progress_bars: Whether to show a progressbar during training. + """ + + # Catch invalid inputs. + if not ((density_estimator == "mdn_snpe_a") or callable(density_estimator)): + raise TypeError( + "The `density_estimator` passed to SNPE_A needs to be a " + "callable or the string 'mdn_snpe_a'!" + ) + + # `num_components` will be used to replicate the Gaussian in the last round. + self._num_components = num_components + self._ran_final_round = False + + # WARNING: sneaky trick ahead. We proxy the parent's `train` here, + # requiring the signature to have `num_atoms`, save it for use below, and + # continue. 
It's sneaky because we are using the object (self) as a namespace + # to pass arguments between functions, and that's implicit state management. + kwargs = utils.del_entries( + locals(), + entries=("self", "__class__", "num_components"), + ) + super().__init__(**kwargs) + + def train( + self, + final_round: bool = False, + training_batch_size: int = 50, + learning_rate: float = 5e-4, + validation_fraction: float = 0.1, + stop_after_epochs: int = 20, + max_num_epochs: int = 2**31 - 1, + clip_max_norm: Optional[float] = 5.0, + calibration_kernel: Optional[Callable] = None, + resume_training: bool = False, + force_first_round_loss: bool = False, + retrain_from_scratch: bool = False, + show_train_summary: bool = False, + dataloader_kwargs: Optional[Dict] = None, + component_perturbation: float = 5e-3, + ) -> nn.Module: + r"""Return density estimator that approximates the proposal posterior. + + [1] _Fast epsilon-free Inference of Simulation Models with Bayesian Conditional + Density Estimation_, Papamakarios et al., NeurIPS 2016, + https://arxiv.org/abs/1605.06376. + + Training is performed with maximum likelihood on samples from the latest round, + which leads the algorithm to converge to the proposal posterior. + + Args: + final_round: Whether we are in the last round of training or not. For all + but the last round, Algorithm 1 from [1] is executed. In last the + round, Algorithm 2 from [1] is executed once. + training_batch_size: Training batch size. + learning_rate: Learning rate for Adam optimizer. + validation_fraction: The fraction of data to use for validation. + stop_after_epochs: The number of epochs to wait for improvement on the + validation set before terminating training. + max_num_epochs: Maximum number of epochs to run. If reached, we stop + training even when the validation loss is still decreasing. Otherwise, + we train until validation loss increases (see also `stop_after_epochs`). 
+ clip_max_norm: Value at which to clip the total gradient norm in order to + prevent exploding gradients. Use None for no clipping. + calibration_kernel: A function to calibrate the loss with respect to the + simulations `x`. See Lueckmann, Gonçalves et al., NeurIPS 2017. + resume_training: Can be used in case training time is limited, e.g. on a + cluster. If `True`, the split between train and validation set, the + optimizer, the number of epochs, and the best validation log-prob will + be restored from the last time `.train()` was called. + force_first_round_loss: If `True`, train with maximum likelihood, + i.e., potentially ignoring the correction for using a proposal + distribution different from the prior. + force_first_round_loss: If `True`, train with maximum likelihood, + regardless of the proposal distribution. + retrain_from_scratch: Whether to retrain the conditional density + estimator for the posterior from scratch each round. Not supported for + SNPE-A. + show_train_summary: Whether to print the number of epochs and validation + loss and leakage after the training. + dataloader_kwargs: Additional or updated kwargs to be passed to the training + and validation dataloaders (like, e.g., a collate_fn) + component_perturbation: The standard deviation applied to all weights and + biases when, in the last round, the Mixture of Gaussians is build from + a single Gaussian. This value can be problem-specific and also depends + on the number of mixture components. + + Returns: + Density estimator that approximates the distribution $p(\theta|x)$. + """ + + assert not retrain_from_scratch, """Retraining from scratch is not supported in SNPE-A yet. The reason for + this is that, if we reininitialized the density estimator, the z-scoring would + change, which would break the posthoc correction. 
This is a pure implementation + issue.""" + + kwargs = utils.del_entries( + locals(), + entries=("self", "__class__", "final_round", "component_perturbation"), + ) + + # SNPE-A always discards the prior samples. + kwargs["discard_prior_samples"] = True + + self._round = max(self._data_round_index) + + if final_round: + # If there is (will be) only one round, train with Algorithm 2 from [1]. + if self._round == 0: + self._build_neural_net = partial( + self._build_neural_net, num_components=self._num_components + ) + # Run Algorithm 2 from [1]. + elif not self._ran_final_round: + # Now switch to the specified number of components. This method will + # only be used if `retrain_from_scratch=True`. Otherwise, + # the MDN will be built from replicating the single-component net for + # `num_component` times (via `_expand_mog()`). + self._build_neural_net = partial( + self._build_neural_net, num_components=self._num_components + ) + + # Extend the MDN to the originally desired number of components. + self._expand_mog(eps=component_perturbation) + else: + warnings.warn( + "You have already run SNPE-A with `final_round=True`. Running it" + "again with this setting will not allow computing the posthoc" + "correction applied in SNPE-A. Thus, you will get an error when " + "calling `.build_posterior()` after training.", + UserWarning, + ) + else: + # Run Algorithm 1 from [1]. + # Wrap the function that builds the MDN such that we can make + # sure that there is only one component when running. + self._build_neural_net = partial(self._build_neural_net, num_components=1) + + if final_round: + self._ran_final_round = True + + return super().train(**kwargs) + + def correct_for_proposal( + self, + density_estimator: Optional[TorchModule] = None, + ) -> "SNPE_A_MDN": + r"""Build mixture of Gaussians that approximates the posterior. + + Returns a `SNPE_A_MDN` object, which applies the posthoc-correction required in + SNPE-A. 
+ + Args: + density_estimator: The density estimator that the posterior is based on. + If `None`, use the latest neural density estimator that was trained. + + Returns: + Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods. + """ + if density_estimator is None: + density_estimator = deepcopy( + self._neural_net + ) # PosteriorEstimator.train() also returns a deepcopy, mimic this here + # If internal net is used device is defined. + device = self._device + else: + # Otherwise, infer it from the device of the net parameters. + device = str(next(density_estimator.parameters()).device) + + # Set proposal of the density estimator. + # This also evokes the z-scoring correction if necessary. + if ( + self._proposal_roundwise[-1] is self._prior + or self._proposal_roundwise[-1] is None + ): + proposal = self._prior + assert isinstance( + proposal, (MultivariateNormal, utils.BoxUniform) + ), """Prior must be `torch.distributions.MultivariateNormal` or `sbi.utils. + BoxUniform`""" + else: + assert isinstance( + self._proposal_roundwise[-1], DirectPosterior + ), """The proposal you passed to `append_simulations` is neither the prior + nor a `DirectPosterior`. SNPE-A currently only supports these scenarios. + """ + proposal = self._proposal_roundwise[-1] + + # Create the SNPE_A_MDN + wrapped_density_estimator = SNPE_A_MDN( + flow=density_estimator, proposal=proposal, prior=self._prior, device=device + ) + return wrapped_density_estimator + + def build_posterior( + self, + density_estimator: Optional[TorchModule] = None, + prior: Optional[Distribution] = None, + ) -> "DirectPosterior": + r"""Build posterior from the neural density estimator. + + This method first corrects the estimated density with `correct_for_proposal` + and then returns a `DirectPosterior`. + + Args: + density_estimator: The density estimator that the posterior is based on. + If `None`, use the latest neural density estimator that was trained. + prior: Prior distribution. 
+ + Returns: + Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods. + """ + if prior is None: + assert ( + self._prior is not None + ), """You did not pass a prior. You have to pass the prior either at + initialization `inference = SNPE_A(prior)` or to `.build_posterior + (prior=prior)`.""" + prior = self._prior + + wrapped_density_estimator = self.correct_for_proposal( + density_estimator=density_estimator + ) + self._posterior = DirectPosterior( + posterior_estimator=wrapped_density_estimator, + prior=prior, + ) + return deepcopy(self._posterior) + + def _log_prob_proposal_posterior( + self, + theta: Tensor, + x: Tensor, + masks: Tensor, + proposal: Optional[Any], + ) -> Tensor: + """Return the log-probability of the proposal posterior. + + For SNPE-A this is the same as `self._neural_net.log_prob(theta, x)` in + `_loss()` to be found in `snpe_base.py`. + + Args: + theta: Batch of parameters θ. + x: Batch of data. + masks: Mask that is True for prior samples in the batch in order to train + them with prior loss. + proposal: Proposal distribution. + + Returns: Log-probability of the proposal posterior. + """ + return self._neural_net.log_prob(theta, x) + + def _expand_mog(self, eps: float = 1e-5): + """ + Replicate a singe Gaussian trained with Algorithm 1 before continuing + with Algorithm 2. The weights and biases of the associated MDN layers + are repeated `num_components` times, slightly perturbed to break the + symmetry such that the gradients in the subsequent training are not + all identical. + + Args: + eps: Standard deviation for the random perturbation. + """ + assert isinstance(self._neural_net._distribution, MultivariateGaussianMDN) + + # Increase the number of components + self._neural_net._distribution._num_components = self._num_components + + # Expand the 1-dim Gaussian. 
+ for name, param in self._neural_net.named_parameters(): + if any( + key in name for key in ["logits", "means", "unconstrained", "upper"] + ): + if "bias" in name: + param.data = param.data.repeat(self._num_components) + param.data.add_(torch.randn_like(param.data) * eps) + param.grad = None # let autograd construct a new gradient + elif "weight" in name: + param.data = param.data.repeat(self._num_components, 1) + param.data.add_(torch.randn_like(param.data) * eps) + param.grad = None # let autograd construct a new gradient + + +class SNPE_A_MDN(nn.Module): + """Generates a posthoc-corrected MDN which approximates the posterior. + + This class takes as input the density estimator (abbreviated with `_d` suffix, aka + the proposal posterior) and the proposal prior (abbreviated with `_pp` suffix) from + which the simulations were drawn. It uses the algorithm presented in SNPE-A [1] to + compute the approximate posterior (abbreviated with `_p` suffix) from the two. The + approximate posterior is a MoG. This class also implements log-prob calculation + sampling from the approximate posterior. It inherits from `nn.Module` since the + constructor of `DirectPosterior` expects the argument `neural_net` to be a + `nn.Module`. + + [1] _Fast epsilon-free Inference of Simulation Models with Bayesian Conditional + Density Estimation_, Papamakarios et al., NeurIPS 2016, + https://arxiv.org/abs/1605.06376. + """ + + def __init__( + self, + flow: flows.Flow, + proposal: Union["utils.BoxUniform", "MultivariateNormal", "DirectPosterior"], + prior: Distribution, + device: str, + ): + """Constructor. + + Args: + flow: The trained normalizing flow, passed when building the posterior. + proposal: The proposal distribution. + prior: The prior distribution. + """ + # Call nn.Module's constructor. + super().__init__() + + self._neural_net = flow + self._prior = prior + self._device = device + + # Set the proposal using the `default_x`. 
+ if isinstance(proposal, (utils.BoxUniform, MultivariateNormal)): + self._apply_correction = False + else: + self._apply_correction = True + logits_pp, m_pp, prec_pp = proposal.posterior_estimator._posthoc_correction( + proposal.default_x + ) + self._logits_pp, self._m_pp, self._prec_pp = ( + logits_pp.detach(), + m_pp.detach(), + prec_pp.detach(), + ) + + # Take care of z-scoring, pre-compute and store prior terms. + self._set_state_for_mog_proposal() + + def log_prob(self, inputs: Tensor, context: Tensor) -> Tensor: + inputs, context = inputs.to(self._device), context.to(self._device) + + if not self._apply_correction: + return self._neural_net.log_prob(inputs, context) + else: + # When we want to compute the approx. posterior, a proposal prior \tilde{p} + # has already been observed. To analytically calculate the log-prob of the + # Gaussian, we first need to compute the mixture components. + + # Compute the mixture components of the proposal posterior. + logits_pp, m_pp, prec_pp = self._posthoc_correction(context) + + # z-score theta if it z-scoring had been requested. + theta = self._maybe_z_score_theta(inputs) + + # Compute the log_prob of theta under the product. + log_prob_proposal_posterior = utils.mog_log_prob( + theta, logits_pp, m_pp, prec_pp + ) + utils.assert_all_finite( + log_prob_proposal_posterior, "proposal posterior eval" + ) + return log_prob_proposal_posterior # \hat{p} from eq (3) in [1] + + def sample(self, num_samples: int, context: Tensor, batch_size: int = 1) -> Tensor: + context = context.to(self._device) + + if not self._apply_correction: + return self._neural_net.sample(num_samples, context, batch_size) + else: + # When we want to sample from the approx. posterior, a proposal prior + # \tilde{p} has already been observed. To analytically calculate the + # log-prob of the Gaussian, we first need to compute the mixture components. 
+ return self._sample_approx_posterior_mog(num_samples, context, batch_size) + + def _sample_approx_posterior_mog( + self, num_samples, x: Tensor, batch_size: int + ) -> Tensor: + r"""Sample from the approximate posterior. + + Args: + num_samples: Desired number of samples. + x: Conditioning context for posterior $p(\theta|x)$. + batch_size: Batch size for sampling. + + Returns: + Samples from the approximate mixture of Gaussians posterior. + """ + + # Compute the mixture components of the posterior. + logits_p, m_p, prec_p = self._posthoc_correction(x) + + # Compute the precision factors which represent the upper triangular matrix + # of the cholesky decomposition of the prec_p. + prec_factors_p = torch.linalg.cholesky(prec_p, upper=True) + + assert logits_p.ndim == 2 + assert m_p.ndim == 3 + assert prec_p.ndim == 4 + assert prec_factors_p.ndim == 4 + + # Replicate to use batched sampling from pyknos. + if batch_size is not None and batch_size > 1: + logits_p = logits_p.repeat(batch_size, 1) + m_p = m_p.repeat(batch_size, 1, 1) + prec_factors_p = prec_factors_p.repeat(batch_size, 1, 1, 1) + + # Get (optionally z-scored) MoG samples. + theta = MultivariateGaussianMDN.sample_mog( + num_samples, logits_p, m_p, prec_factors_p + ) + + embedded_context = self._neural_net._embedding_net(x) + if embedded_context is not None: + # Merge the context dimension with sample dimension in order to + # apply the transform. + theta = torchutils.merge_leading_dims(theta, num_dims=2) + embedded_context = torchutils.repeat_rows( + embedded_context, num_reps=num_samples + ) + + theta, _ = self._neural_net._transform.inverse(theta, context=embedded_context) + + if embedded_context is not None: + # Split the context dimension from sample dimension. 
+ theta = torchutils.split_leading_dim(theta, shape=[-1, num_samples]) + + return theta + + def _posthoc_correction(self, x: Tensor): + """ + Compute the mixture components of the posterior given the current density + estimator and the proposal. + + Args: + x: Conditioning context for posterior. + + Returns: + Mixture components of the posterior. + """ + + # Evaluate the density estimator. + encoded_x = self._neural_net._embedding_net(x) + dist = self._neural_net._distribution # defined to avoid black formatting. + logits_d, m_d, prec_d, _, _ = dist.get_mixture_components(encoded_x) + norm_logits_d = logits_d - torch.logsumexp(logits_d, dim=-1, keepdim=True) + + # The following if case is needed because, in the constructor, we call + # `_posthoc_correction` regardless of whether the `proposal` itself had a + # `proposal` or not. + if not self._apply_correction: + return norm_logits_d, m_d, prec_d + else: + logits_pp, m_pp, prec_pp = self._logits_pp, self._m_pp, self._prec_pp + + # Compute the MoG parameters of the posterior. + logits_p, m_p, prec_p, cov_p = self._proposal_posterior_transformation( + logits_pp, m_pp, prec_pp, norm_logits_d, m_d, prec_d + ) + return logits_p, m_p, prec_p + + def _proposal_posterior_transformation( + self, + logits_pp: Tensor, + means_pp: Tensor, + precisions_pp: Tensor, + logits_d: Tensor, + means_d: Tensor, + precisions_d: Tensor, + ): + r"""Transforms the proposal posterior (the MDN) into the posterior. + + The approximate posterior is: + $p(\theta|x) = 1/Z * q(\theta|x) * p(\theta) / prop(\theta)$ + In words: posterior = proposal posterior estimate * prior / proposal. + + Since the proposal posterior estimate and the proposal are MoG, and the + prior is either Gaussian or uniform, we can solve this in closed-form. + + This function implements Appendix C from [1], and is highly similar to + `SNPE_C._automatic_posterior_transformation()`. + + Args: + logits_pp: Component weight of each Gaussian of the proposal prior. 
+ means_pp: Mean of each Gaussian of the proposal prior. + precisions_pp: Precision matrix of each Gaussian of the proposal prior. + logits_d: Component weight for each Gaussian of the density estimator. + means_d: Mean of each Gaussian of the density estimator. + precisions_d: Precision matrix of each Gaussian of the density estimator. + + Returns: (Component weight, mean, precision matrix, covariance matrix) of each + Gaussian of the approximate posterior. + """ + + precisions_post, covariances_post = self._precisions_posterior( + precisions_pp, precisions_d + ) + + means_post = self._means_posterior( + covariances_post, means_pp, precisions_pp, means_d, precisions_d + ) + + logits_post = SNPE_A_MDN._logits_posterior( + means_post, + precisions_post, + covariances_post, + logits_pp, + means_pp, + precisions_pp, + logits_d, + means_d, + precisions_d, + ) + + return logits_post, means_post, precisions_post, covariances_post + + def _set_state_for_mog_proposal(self) -> None: + """ + Set state variables of the SNPE_A_MDN instance every time `set_proposal()` + is called, i.e. every time a posterior is build using + `SNPE_A.build_posterior()`. + + This function is almost identical to `SNPE_C._set_state_for_mog_proposal()`. + + Three things are computed: + 1) Check if z-scoring was requested. To do so, we check if the `_transform` + argument of the net had been a `CompositeTransform`. See pyknos mdn.py. + 2) Define a (potentially standardized) prior. It's standardized if z-scoring + had been requested. + 3) Compute (Precision * mean) for the prior. This quantity is used at every + training step if the prior is Gaussian. 
+ """ + + self.z_score_theta = isinstance(self._neural_net._transform, CompositeTransform) + + self._set_maybe_z_scored_prior() + + if isinstance(self._maybe_z_scored_prior, MultivariateNormal): + self.prec_m_prod_prior = torch.mv( + self._maybe_z_scored_prior.precision_matrix, # type: ignore + self._maybe_z_scored_prior.loc, # type: ignore + ) + + def _set_maybe_z_scored_prior(self) -> None: + r""" + Compute and store potentially standardized prior (if z-scoring was requested). + + This function is highly similar to `SNPE_C._set_maybe_z_scored_prior()`. + + The proposal posterior is: + $p(\theta|x) = 1/Z * q(\theta|x) * prop(\theta) / p(\theta)$ + + Let's denote z-scored theta by `a`: a = (theta - mean) / std + Then $p'(a|x) = 1/Z_2 * q'(a|x) * prop'(a) / p'(a)$ + + The ' indicates that the evaluation occurs in standardized space. The constant + scaling factor has been absorbed into $Z_2$. + From the above equation, we see that we need to evaluate the prior **in + standardized space**. We build the standardized prior in this function. + + The standardize transform that is applied to the samples theta does not use + the exact prior mean and std (due to implementation issues). Hence, the z-scored + prior will not be exactly have mean=0 and std=1. + """ + if self.z_score_theta: + scale = self._neural_net._transform._transforms[0]._scale + shift = self._neural_net._transform._transforms[0]._shift + + # Following the definition of the linear transform in + # `standardizing_transform` in `sbiutils.py`: + # shift=-mean / std + # scale=1 / std + # Solving these equations for mean and std: + estim_prior_std = 1 / scale + estim_prior_mean = -shift * estim_prior_std + + # Compute the discrepancy of the true prior mean and std and the mean and + # std that was empirically estimated from samples. + # N(theta|m,s) = N((theta-m_e)/s_e|(m-m_e)/s_e, s/s_e) + # Above: m,s are true prior mean and std. 
m_e,s_e are estimated prior mean + # and std (estimated from samples and used to build standardize transform). + almost_zero_mean = (self._prior.mean - estim_prior_mean) / estim_prior_std + almost_one_std = torch.sqrt(self._prior.variance) / estim_prior_std + + if isinstance(self._prior, MultivariateNormal): + self._maybe_z_scored_prior = MultivariateNormal( + almost_zero_mean, torch.diag(almost_one_std) + ) + else: + range_ = torch.sqrt(almost_one_std * 3.0) + self._maybe_z_scored_prior = utils.BoxUniform( + almost_zero_mean - range_, almost_zero_mean + range_ + ) + else: + self._maybe_z_scored_prior = self._prior + + def _maybe_z_score_theta(self, theta: Tensor) -> Tensor: + """Return potentially standardized theta if z-scoring was requested.""" + + if self.z_score_theta: + theta, _ = self._neural_net._transform(theta) + + return theta + + def _precisions_posterior(self, precisions_pp: Tensor, precisions_d: Tensor): + r"""Return the precisions and covariances of the MoG posterior. + + As described at the end of Appendix C in [1], it can happen that the + proposal's precision matrix is not positive definite. + + $S_k^\prime = ( S_k^{-1} - S_0^{-1} )^{-1}$ + (see eq (23) in Appendix C of [1]) + + Args: + precisions_pp: Precision matrices of the proposal prior. + precisions_d: Precision matrices of the density estimator. + + Returns: (Precisions, Covariances) of the MoG posterior. + """ + + num_comps_p = precisions_pp.shape[1] + num_comps_d = precisions_d.shape[1] + + # Check if precision matrices are positive definite. + for batches in precisions_pp: + for pprior in batches: + eig_pprior = torch.linalg.eigvalsh(pprior, UPLO="U") + if not (eig_pprior > 0).all(): + raise AssertionError( + "The precision matrix of the proposal is not positive definite!" 
+ ) + for batches in precisions_d: + for d in batches: + eig_d = torch.linalg.eigvalsh(d, UPLO="U") + if not (eig_d > 0).all(): + raise AssertionError( + "The precision matrix of the density estimator is not " + "positive definite!" + ) + + precisions_pp_rep = precisions_pp.repeat_interleave(num_comps_d, dim=1) + precisions_d_rep = precisions_d.repeat(1, num_comps_p, 1, 1) + + precisions_p = precisions_d_rep - precisions_pp_rep + if isinstance(self._maybe_z_scored_prior, MultivariateNormal): + precisions_p += self._maybe_z_scored_prior.precision_matrix + + # Check if precision matrix is positive definite. + for idx_batch, batches in enumerate(precisions_p): + for idx_comp, pp in enumerate(batches): + eig_pp = torch.symeig(pp, eigenvectors=False).eigenvalues + if not (eig_pp > 0).all(): + raise AssertionError( + "The precision matrix of a posterior is not positive " + "definite! This is a known issue for SNPE-A. Either try a " + "different parameter setting, e.g. a different number of " + "mixture components (when contracting SNPE-A), or a different " + "value for the parameter perturbation (when building the " + "posterior)." + ) + + covariances_p = torch.inverse(precisions_p) + return precisions_p, covariances_p + + def _means_posterior( + self, + covariances_p: Tensor, + means_pp: Tensor, + precisions_pp: Tensor, + means_d: Tensor, + precisions_d: Tensor, + ): + r"""Return the means of the MoG posterior. + + $m_k^\prime = S_k^\prime ( S_k^{-1} m_k - S_0^{-1} m_0 )$ + (see eq (24) in Appendix C of [1]) + + Args: + covariances_post: Covariance matrices of the MoG posterior. + means_pp: Means of the proposal prior. + precisions_pp: Precision matrices of the proposal prior. + means_d: Means of the density estimator. + precisions_d: Precision matrices of the density estimator. + + Returns: Means of the MoG posterior. + """ + + num_comps_pp = precisions_pp.shape[1] + num_comps_d = precisions_d.shape[1] + + # Compute the products P_k * m_k and P_0 * m_0. 
+ prec_m_prod_pp = utils.batched_mixture_mv(precisions_pp, means_pp) + prec_m_prod_d = utils.batched_mixture_mv(precisions_d, means_d) + + # Repeat them to allow for matrix operations: same trick as for the precisions. + prec_m_prod_pp_rep = prec_m_prod_pp.repeat_interleave(num_comps_d, dim=1) + prec_m_prod_d_rep = prec_m_prod_d.repeat(1, num_comps_pp, 1) + + # Compute the means P_k^prime * (P_k * m_k - P_0 * m_0). + summed_cov_m_prod_rep = prec_m_prod_d_rep - prec_m_prod_pp_rep + if isinstance(self._maybe_z_scored_prior, MultivariateNormal): + summed_cov_m_prod_rep += self.prec_m_prod_prior + + means_p = utils.batched_mixture_mv(covariances_p, summed_cov_m_prod_rep) + return means_p + + @staticmethod + def _logits_posterior( + means_post: Tensor, + precisions_post: Tensor, + covariances_post: Tensor, + logits_pp: Tensor, + means_pp: Tensor, + precisions_pp: Tensor, + logits_d: Tensor, + means_d: Tensor, + precisions_d: Tensor, + ): + r"""Return the component weights (i.e. logits) of the MoG posterior. + + $\alpha_k^\prime = \frac{ \alpha_k exp(-0.5 c_k) }{ \sum{j} \alpha_j exp(-0.5 + c_j) } $ + with + $c_k = logdet(S_k) - logdet(S_0) - logdet(S_k^\prime) + + + m_k^T P_k m_k - m_0^T P_0 m_0 - m_k^\prime^T P_k^\prime m_k^\prime$ + (see eqs. (25, 26) in Appendix C of [1]) + + Args: + means_post: Means of the posterior. + precisions_post: Precision matrices of the posterior. + covariances_post: Covariance matrices of the posterior. + logits_pp: Component weights (i.e. logits) of the proposal prior. + means_pp: Means of the proposal prior. + precisions_pp: Precision matrices of the proposal prior. + logits_d: Component weights (i.e. logits) of the density estimator. + means_d: Means of the density estimator. + precisions_d: Precision matrices of the density estimator. + + Returns: Component weights of the proposal posterior. 
+ """ + + num_comps_pp = precisions_pp.shape[1] + num_comps_d = precisions_d.shape[1] + + # Compute the ratio of the logits similar to eq (10) in Appendix A.1 of [2] + logits_pp_rep = logits_pp.repeat_interleave(num_comps_d, dim=1) + logits_d_rep = logits_d.repeat(1, num_comps_pp) + logit_factors = logits_d_rep - logits_pp_rep + + # Compute the log-determinants + logdet_covariances_post = torch.logdet(covariances_post) + logdet_covariances_pp = -torch.logdet(precisions_pp) + logdet_covariances_d = -torch.logdet(precisions_d) + + # Repeat the proposal and density estimator terms such that there are LK terms. + # Same trick as has been used above. + logdet_covariances_pp_rep = logdet_covariances_pp.repeat_interleave( + num_comps_d, dim=1 + ) + logdet_covariances_d_rep = logdet_covariances_d.repeat(1, num_comps_pp) + + log_sqrt_det_ratio = 0.5 * ( # similar to eq (14) in Appendix A.1 of [2] + logdet_covariances_post + + logdet_covariances_pp_rep + - logdet_covariances_d_rep + ) + + # Compute for proposal, density estimator, and proposal posterior: + exponent_pp = utils.batched_mixture_vmv( + precisions_pp, means_pp # m_0 in eq (26) in Appendix C of [1] + ) + exponent_d = utils.batched_mixture_vmv( + precisions_d, means_d # m_k in eq (26) in Appendix C of [1] + ) + exponent_post = utils.batched_mixture_vmv( + precisions_post, means_post # m_k^\prime in eq (26) in Appendix C of [1] + ) + + # Extend proposal and density estimator exponents to get LK terms. 
+ exponent_pp_rep = exponent_pp.repeat_interleave(num_comps_d, dim=1) + exponent_d_rep = exponent_d.repeat(1, num_comps_pp) + exponent = -0.5 * ( + exponent_d_rep - exponent_pp_rep - exponent_post # eq (26) in [1] + ) + + logits_post = logit_factors + log_sqrt_det_ratio + exponent + return logits_post diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index 1a56f9c75..6d7b6ccef 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -1,597 +1,597 @@ -# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed -# under the Affero General Public License v3, see . -import time -from abc import ABC, abstractmethod -from copy import deepcopy -from typing import Any, Callable, Dict, Optional, Union -from warnings import warn - -import torch -from torch import Tensor, nn, ones, optim -from torch.distributions import Distribution -from torch.nn.utils.clip_grad import clip_grad_norm_ -from torch.utils import data -from torch.utils.tensorboard.writer import SummaryWriter - -from sbi import utils as utils -from sbi.inference import NeuralInference, check_if_proposal_has_default_x -from sbi.inference.posteriors import ( - DirectPosterior, - MCMCPosterior, - RejectionPosterior, - VIPosterior, -) -from sbi.inference.posteriors.base_posterior import NeuralPosterior -from sbi.inference.potentials import posterior_estimator_based_potential -from sbi.utils import ( - RestrictedPrior, - check_estimator_arg, - test_posterior_net_for_multi_d_x, - validate_theta_and_x, - x_shape_from_simulation, - handle_invalid_x, - warn_if_zscoring_changes_data, - warn_on_invalid_x, - warn_on_invalid_x_for_snpec_leakage, -) -from sbi.utils.sbiutils import ImproperEmpirical, mask_sims_from_prior - - -class PosteriorEstimator(NeuralInference, ABC): - def __init__( - self, - prior: Optional[Distribution] = None, - density_estimator: Union[str, Callable] = "maf", - device: str = "cpu", - logging_level: Union[int, str] = 
"WARNING", - summary_writer: Optional[SummaryWriter] = None, - show_progress_bars: bool = True, - ): - """Base class for Sequential Neural Posterior Estimation methods. - - Args: - density_estimator: If it is a string, use a pre-configured network of the - provided type (one of nsf, maf, mdn, made). Alternatively, a function - that builds a custom neural network can be provided. The function will - be called with the first batch of simulations (theta, x), which can - thus be used for shape inference and potentially for z-scoring. It - needs to return a PyTorch `nn.Module` implementing the density - estimator. The density estimator needs to provide the methods - `.log_prob` and `.sample()`. - - See docstring of `NeuralInference` class for all other arguments. - """ - - super().__init__( - prior=prior, - device=device, - logging_level=logging_level, - summary_writer=summary_writer, - show_progress_bars=show_progress_bars, - ) - - # As detailed in the docstring, `density_estimator` is either a string or - # a callable. The function creating the neural network is attached to - # `_build_neural_net`. It will be called in the first round and receive - # thetas and xs as inputs, so that they can be used for shape inference and - # potentially for z-scoring. - check_estimator_arg(density_estimator) - if isinstance(density_estimator, str): - self._build_neural_net = utils.posterior_nn(model=density_estimator) - else: - self._build_neural_net = density_estimator - - self._proposal_roundwise = [] - self.use_non_atomic_loss = False - - # Extra SNPE-specific fields summary_writer. 
- self._summary.update({"rejection_sampling_acceptance_rates": []}) # type:ignore - - def append_simulations( - self, - theta: Tensor, - x: Tensor, - proposal: Optional[DirectPosterior] = None, - exclude_invalid_x: bool = True, - warn_on_invalid: bool = True, - warn_if_zscoring: bool = True, - return_self: bool = True, - data_device: str = None, - ) -> "PosteriorEstimator": - r"""Store parameters and simulation outputs to use them for later training. - - Data are stored as entries in lists for each type of variable (parameter/data). - - Stores $\theta$, $x$, prior_masks (indicating if simulations are coming from the - prior or not) and an index indicating which round the batch of simulations came - from. - - Args: - theta: Parameter sets. - x: Simulation outputs. - proposal: The distribution that the parameters $\theta$ were sampled from. - Pass `None` if the parameters were sampled from the prior. If not - `None`, it will trigger a different loss-function. - exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞` - during training. Expect errors, silent or explicit, when `False`. - warn_on_invalid: Whether to warn if data is invalid - warn_if_zscoring: Whether to test if z-scoring causes duplicates - return_self: Whether to return a instance of the class, allows chaining - with `.train()`. Setting `False` decreases memory overhead. - data_device: Where to store the data, default is on the same device where - the training is happening. If training a large dataset on a GPU with not - much VRAM can set to 'cpu' to store data on system memory instead. - - Returns: - NeuralInference object (returned so that this function is chainable). 
- """ - - # Add ability to specify device data is saved on - if data_device is None: data_device = self._device - theta, x = validate_theta_and_x(theta, x, training_device=data_device) - - - is_valid_x, num_nans, num_infs = handle_invalid_x(x, exclude_invalid_x) - - # Check for problematic z-scoring - if warn_if_zscoring: - warn_if_zscoring_changes_data(x[is_valid_x]) - if warn_on_invalid: - warn_on_invalid_x(num_nans, num_infs, exclude_invalid_x) - warn_on_invalid_x_for_snpec_leakage( - num_nans, num_infs, exclude_invalid_x, type(self).__name__, self._round - ) - - x = x[is_valid_x] - theta = theta[is_valid_x] - - - self._check_proposal(proposal) - - if ( - proposal is None - or proposal is self._prior - or ( - isinstance(proposal, RestrictedPrior) and proposal._prior is self._prior - ) - ): - # The `_data_round_index` will later be used to infer if one should train - # with MLE loss or with atomic loss (see, in `train()`: - # self._round = max(self._data_round_index)) - self._data_round_index.append(0) - prior_masks = mask_sims_from_prior(0, theta.size(0)) - else: - if not self._data_round_index: - # This catches a pretty specific case: if, in the first round, one - # passes data that does not come from the prior. 
- self._data_round_index.append(1) - else: - self._data_round_index.append(max(self._data_round_index) + 1) - prior_masks = mask_sims_from_prior(1, theta.size(0)) - - - if self._dataset is None: - #If first round, set up ConcatDataset - self._dataset = data.ConcatDataset( [data.TensorDataset(theta,x,prior_masks),] ) - else: - #Otherwise append to Dataset - self._dataset = data.ConcatDataset( self._dataset.datasets + [data.TensorDataset(theta,x,prior_masks),] ) - - self._num_sims_per_round.append(theta.size(0)) - self._proposal_roundwise.append(proposal) - - if self._prior is None or isinstance(self._prior, ImproperEmpirical): - if proposal is not None: - raise ValueError( - "You had not passed a prior at initialization, but now you " - "passed a proposal. If you want to run multi-round SNPE, you have " - "to specify a prior (set the `.prior` argument or re-initialize " - "the object with a prior distribution). If the samples you passed " - "to `append_simulations()` were sampled from the prior, you can " - "run single-round inference with " - "`append_simulations(..., proposal=None)`." - ) - theta_prior = self.get_simulations()[0] - self._prior = ImproperEmpirical(theta_prior, ones(theta_prior.shape[0])) - - #Add ability to not return self - if return_self: - return self - - def train( - self, - training_batch_size: int = 50, - learning_rate: float = 5e-4, - validation_fraction: float = 0.1, - stop_after_epochs: int = 20, - max_num_epochs: int = 2**31 - 1, - clip_max_norm: Optional[float] = 5.0, - calibration_kernel: Optional[Callable] = None, - resume_training: bool = False, - force_first_round_loss: bool = False, - discard_prior_samples: bool = False, - retrain_from_scratch: bool = False, - show_train_summary: bool = False, - dataloader_kwargs: Optional[dict] = None, - ) -> nn.Module: - r"""Return density estimator that approximates the distribution $p(\theta|x)$. - - Args: - training_batch_size: Training batch size. 
- learning_rate: Learning rate for Adam optimizer. - validation_fraction: The fraction of data to use for validation. - stop_after_epochs: The number of epochs to wait for improvement on the - validation set before terminating training. - max_num_epochs: Maximum number of epochs to run. If reached, we stop - training even when the validation loss is still decreasing. Otherwise, - we train until validation loss increases (see also `stop_after_epochs`). - clip_max_norm: Value at which to clip the total gradient norm in order to - prevent exploding gradients. Use None for no clipping. - calibration_kernel: A function to calibrate the loss with respect to the - simulations `x`. See Lueckmann, Gonçalves et al., NeurIPS 2017. - resume_training: Can be used in case training time is limited, e.g. on a - cluster. If `True`, the split between train and validation set, the - optimizer, the number of epochs, and the best validation log-prob will - be restored from the last time `.train()` was called. - force_first_round_loss: If `True`, train with maximum likelihood, - i.e., potentially ignoring the correction for using a proposal - distribution different from the prior. - discard_prior_samples: Whether to discard samples simulated in round 1, i.e. - from the prior. Training may be sped up by ignoring such less targeted - samples. - retrain_from_scratch: Whether to retrain the conditional density - estimator for the posterior from scratch each round. - show_train_summary: Whether to print the number of epochs and validation - loss after the training. - dataloader_kwargs: Additional or updated kwargs to be passed to the training - and validation dataloaders (like, e.g., a collate_fn) - - Returns: - Density estimator that approximates the distribution $p(\theta|x)$. - """ - if self._round == 0 and self._neural_net is not None: - assert force_first_round_loss, ( - "You have already trained this neural network. 
After you had trained " - "the network, you again appended simulations with `append_simulations" - "(theta, x)`, but you did not provide a proposal. If the new " - "simulations are sampled from the prior, you can set " - "`.train(..., force_first_round_loss=True`). However, if the new " - "simulations were not sampled from the prior, you should pass the " - "proposal, i.e. `append_simulations(theta, x, proposal)`. If " - "your samples are not sampled from the prior and you do not pass a " - "proposal and you set `force_first_round_loss=True`, the result of " - "SNPE will not be the true posterior. Instead, it will be the proposal " - "posterior, which (usually) is more narrow than the true posterior." - ) - - # Calibration kernels proposed in Lueckmann, Gonçalves et al., 2017. - if calibration_kernel is None: - calibration_kernel = lambda x: ones([len(x)], device=self._device) - - # Starting index for the training set (1 = discard round-0 samples). - start_idx = int(discard_prior_samples and self._round > 0) - - # For non-atomic loss, we can not reuse samples from previous rounds as of now. - # SNPE-A can, by construction of the algorithm, only use samples from the last - # round. SNPE-A is the only algorithm that has an attribute `_ran_final_round`, - # so this is how we check for whether or not we are using SNPE-A. - if self.use_non_atomic_loss or hasattr(self, "_ran_final_round"): - start_idx = self._round - - # Set the proposal to the last proposal that was passed by the user. For - # atomic SNPE, it does not matter what the proposal is. For non-atomic - # SNPE, we only use the latest data that was passed, i.e. the one from the - # last proposal. 
- proposal = self._proposal_roundwise[-1] - - train_loader, val_loader = self.get_dataloaders( - start_idx, - training_batch_size, - validation_fraction, - resume_training, - dataloader_kwargs=dataloader_kwargs, - ) - # First round or if retraining from scratch: - # Call the `self._build_neural_net` with the rounds' thetas and xs as - # arguments, which will build the neural network. - # This is passed into NeuralPosterior, to create a neural posterior which - # can `sample()` and `log_prob()`. The network is accessible via `.net`. - if self._neural_net is None or retrain_from_scratch: - - #Get theta,x from dataset to initialize NN - test_theta = self._dataset.datasets[0].tensors[0][:100] - test_x = self._dataset.datasets[0].tensors[1][:100] - - self._neural_net = self._build_neural_net( - test_theta, test_x - ) - # If data on training device already move net as well. - if ( - not self._device == "cpu" - and f"{test_x.device.type}:{test_x.device.index}" == self._device - ): - self._neural_net.to(self._device) - - test_posterior_net_for_multi_d_x(self._neural_net, test_theta, test_x) - self._x_shape = x_shape_from_simulation(test_x) - - # Move entire net to device for training. - self._neural_net.to(self._device) - - if not resume_training: - self.optimizer = optim.Adam( - list(self._neural_net.parameters()), lr=learning_rate - ) - self.epoch, self._val_log_prob = 0, float("-Inf") - - while self.epoch <= max_num_epochs and not self._converged( - self.epoch, stop_after_epochs - ): - - # Train for a single epoch. - self._neural_net.train() - train_log_probs_sum = 0 - epoch_start_time = time.time() - for batch in train_loader: - self.optimizer.zero_grad() - # Get batches on current device. 
- theta_batch, x_batch, masks_batch = ( - batch[0].to(self._device), - batch[1].to(self._device), - batch[2].to(self._device), - ) - - train_losses = self._loss( - theta_batch, x_batch, masks_batch, proposal, calibration_kernel - ) - train_loss = torch.mean(train_losses) - train_log_probs_sum -= train_losses.sum().item() - - train_loss.backward() - if clip_max_norm is not None: - clip_grad_norm_( - self._neural_net.parameters(), max_norm=clip_max_norm - ) - self.optimizer.step() - - self.epoch += 1 - - train_log_prob_average = train_log_probs_sum / ( - len(train_loader) * train_loader.batch_size # type: ignore - ) - self._summary["train_log_probs"].append(train_log_prob_average) - - # Calculate validation performance. - self._neural_net.eval() - val_log_prob_sum = 0 - - with torch.no_grad(): - for batch in val_loader: - theta_batch, x_batch, masks_batch = ( - batch[0].to(self._device), - batch[1].to(self._device), - batch[2].to(self._device), - ) - # Take negative loss here to get validation log_prob. - val_losses = self._loss( - theta_batch, - x_batch, - masks_batch, - proposal, - calibration_kernel, - ) - val_log_prob_sum -= val_losses.sum().item() - - # Take mean over all validation samples. - self._val_log_prob = val_log_prob_sum / ( - len(val_loader) * val_loader.batch_size # type: ignore - ) - # Log validation log prob for every epoch. - self._summary["validation_log_probs"].append(self._val_log_prob) - self._summary["epoch_durations_sec"].append(time.time() - epoch_start_time) - - self._maybe_show_progress(self._show_progress_bars, self.epoch) - - self._report_convergence_at_end(self.epoch, stop_after_epochs, max_num_epochs) - - # Update summary. - self._summary["epochs"].append(self.epoch) - self._summary["best_validation_log_probs"].append(self._best_val_log_prob) - - # Update tensorboard and summary dict. - self._summarize(round_=self._round, x_o=None, theta_bank=None, x_bank=None) - - # Update description for progress bar. 
- if show_train_summary: - print(self._describe_round(self._round, self._summary)) - - # Avoid keeping the gradients in the resulting network, which can - # cause memory leakage when benchmarking. - self._neural_net.zero_grad(set_to_none=True) - - return deepcopy(self._neural_net) - - def build_posterior( - self, - density_estimator: Optional[nn.Module] = None, - prior: Optional[Distribution] = None, - sample_with: str = "rejection", - mcmc_method: str = "slice_np", - vi_method: str = "rKL", - mcmc_parameters: Dict[str, Any] = {}, - vi_parameters: Dict[str, Any] = {}, - rejection_sampling_parameters: Dict[str, Any] = {}, - ) -> Union[MCMCPosterior, RejectionPosterior, VIPosterior, DirectPosterior]: - r"""Build posterior from the neural density estimator. - - For SNPE, the posterior distribution that is returned here implements the - following functionality over the raw neural density estimator: - - correct the calculation of the log probability such that it compensates for - the leakage. - - reject samples that lie outside of the prior bounds. - - alternatively, if leakage is very high (which can happen for multi-round - SNPE), sample from the posterior with MCMC. - - Args: - density_estimator: The density estimator that the posterior is based on. - If `None`, use the latest neural density estimator that was trained. - prior: Prior distribution. - sample_with: Method to use for sampling from the posterior. Must be one of - [`mcmc` | `rejection` | `vi`]. - mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`, - `hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy - implementation of slice sampling; select `hmc`, `nuts` or `slice` for - Pyro-based sampling. - vi_method: Method used for VI, one of [`rKL`, `fKL`, `IW`, `alpha`]. Note - some of the methods admit a `mode seeking` property (e.g. rKL) whereas - some admit a `mass covering` one (e.g fKL). - mcmc_parameters: Additional kwargs passed to `MCMCPosterior`. 
- vi_parameters: Additional kwargs passed to `VIPosterior`. - rejection_sampling_parameters: Additional kwargs passed to - `RejectionPosterior` or `DirectPosterior`. By default, - `DirectPosterior` is used. Only if `rejection_sampling_parameters` - contains `proposal`, a `RejectionPosterior` is instantiated. - - Returns: - Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods - (the returned log-probability is unnormalized). - """ - if prior is None: - assert self._prior is not None, ( - "You did not pass a prior. You have to pass the prior either at " - "initialization `inference = SNPE(prior)` or to " - "`.build_posterior(prior=prior)`." - ) - prior = self._prior - else: - utils.check_prior(prior) - - if density_estimator is None: - posterior_estimator = self._neural_net - # If internal net is used device is defined. - device = self._device - else: - posterior_estimator = density_estimator - # Otherwise, infer it from the device of the net parameters. - device = next(density_estimator.parameters()).device.type - - potential_fn, theta_transform = posterior_estimator_based_potential( - posterior_estimator=posterior_estimator, prior=prior, x_o=None - ) - - if sample_with == "rejection": - if "proposal" in rejection_sampling_parameters.keys(): - self._posterior = RejectionPosterior( - potential_fn=potential_fn, - device=device, - x_shape=self._x_shape, - **rejection_sampling_parameters, - ) - else: - self._posterior = DirectPosterior( - posterior_estimator=posterior_estimator, - prior=prior, - x_shape=self._x_shape, - device=device, - ) - elif sample_with == "mcmc": - self._posterior = MCMCPosterior( - potential_fn=potential_fn, - theta_transform=theta_transform, - proposal=prior, - method=mcmc_method, - device=device, - x_shape=self._x_shape, - **mcmc_parameters, - ) - elif sample_with == "vi": - self._posterior = VIPosterior( - potential_fn=potential_fn, - theta_transform=theta_transform, - prior=prior, # type: ignore - vi_method=vi_method, - 
device=device, - x_shape=self._x_shape, - **vi_parameters, - ) - else: - raise NotImplementedError - - # Store models at end of each round. - self._model_bank.append(deepcopy(self._posterior)) - - return deepcopy(self._posterior) - - @abstractmethod - def _log_prob_proposal_posterior( - self, - theta: Tensor, - x: Tensor, - masks: Tensor, - proposal: Optional[Any], - ) -> Tensor: - raise NotImplementedError - - def _loss( - self, - theta: Tensor, - x: Tensor, - masks: Tensor, - proposal: Optional[Any], - calibration_kernel: Callable, - ) -> Tensor: - """Return loss with proposal correction (`round_>0`) or without it (`round_=0`). - - The loss is the negative log prob. Irrespective of the round or SNPE method - (A, B, or C), it can be weighted with a calibration kernel. - - Returns: - Calibration kernel-weighted negative log prob. - """ - if self._round == 0: - # Use posterior log prob (without proposal correction) for first round. - log_prob = self._neural_net.log_prob(theta, x) - else: - log_prob = self._log_prob_proposal_posterior(theta, x, masks, proposal) - - return -(calibration_kernel(x) * log_prob) - - def _check_proposal(self, proposal): - """ - Check for validity of the provided proposal distribution. - - If the proposal is a `NeuralPosterior`, we check if the default_x is set. - If the proposal is **not** a `NeuralPosterior`, we warn since it is likely that - the user simply passed the prior, but this would still trigger atomic loss. - """ - if proposal is not None: - check_if_proposal_has_default_x(proposal) - - if isinstance(proposal, RestrictedPrior): - if proposal._prior is not self._prior: - warn( - "The proposal you passed is a `RestrictedPrior`, but the " - "proposal distribution it uses is not the prior (it can be " - "accessed via `RestrictedPrior._prior`). We do not " - "recommend to mix the `RestrictedPrior` with multi-round " - "SNPE." 
- ) - elif ( - not isinstance(proposal, NeuralPosterior) - and proposal is not self._prior - ): - warn( - "The proposal you passed is neither the prior nor a " - "`NeuralPosterior` object. If you are an expert user and did so " - "for research purposes, this is fine. If not, you might be doing " - "something wrong: feel free to create an issue on Github." - ) - elif self._round > 0: - raise ValueError( - "A proposal was passed but no prior was passed at initialisation. When " - "running multi-round inference, a prior needs to be specified upon " - "initialisation. Potential fix: setting the `._prior` attribute or " - "re-initialisation. If the samples passed to `append_simulations()` " - "were sampled from the prior, single-round inference can be performed " - "with `append_simulations(..., proprosal=None)`." - ) +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Affero General Public License v3, see . +import time +from abc import ABC, abstractmethod +from copy import deepcopy +from typing import Any, Callable, Dict, Optional, Union +from warnings import warn + +import torch +from torch import Tensor, nn, ones, optim +from torch.distributions import Distribution +from torch.nn.utils.clip_grad import clip_grad_norm_ +from torch.utils import data +from torch.utils.tensorboard.writer import SummaryWriter + +from sbi import utils as utils +from sbi.inference import NeuralInference, check_if_proposal_has_default_x +from sbi.inference.posteriors import ( + DirectPosterior, + MCMCPosterior, + RejectionPosterior, + VIPosterior, +) +from sbi.inference.posteriors.base_posterior import NeuralPosterior +from sbi.inference.potentials import posterior_estimator_based_potential +from sbi.utils import ( + RestrictedPrior, + check_estimator_arg, + test_posterior_net_for_multi_d_x, + validate_theta_and_x, + x_shape_from_simulation, + handle_invalid_x, + warn_if_zscoring_changes_data, + warn_on_invalid_x, + 
warn_on_invalid_x_for_snpec_leakage, +) +from sbi.utils.sbiutils import ImproperEmpirical, mask_sims_from_prior + + +class PosteriorEstimator(NeuralInference, ABC): + def __init__( + self, + prior: Optional[Distribution] = None, + density_estimator: Union[str, Callable] = "maf", + device: str = "cpu", + logging_level: Union[int, str] = "WARNING", + summary_writer: Optional[SummaryWriter] = None, + show_progress_bars: bool = True, + ): + """Base class for Sequential Neural Posterior Estimation methods. + + Args: + density_estimator: If it is a string, use a pre-configured network of the + provided type (one of nsf, maf, mdn, made). Alternatively, a function + that builds a custom neural network can be provided. The function will + be called with the first batch of simulations (theta, x), which can + thus be used for shape inference and potentially for z-scoring. It + needs to return a PyTorch `nn.Module` implementing the density + estimator. The density estimator needs to provide the methods + `.log_prob` and `.sample()`. + + See docstring of `NeuralInference` class for all other arguments. + """ + + super().__init__( + prior=prior, + device=device, + logging_level=logging_level, + summary_writer=summary_writer, + show_progress_bars=show_progress_bars, + ) + + # As detailed in the docstring, `density_estimator` is either a string or + # a callable. The function creating the neural network is attached to + # `_build_neural_net`. It will be called in the first round and receive + # thetas and xs as inputs, so that they can be used for shape inference and + # potentially for z-scoring. + check_estimator_arg(density_estimator) + if isinstance(density_estimator, str): + self._build_neural_net = utils.posterior_nn(model=density_estimator) + else: + self._build_neural_net = density_estimator + + self._proposal_roundwise = [] + self.use_non_atomic_loss = False + + # Extra SNPE-specific fields summary_writer. 
+ self._summary.update({"rejection_sampling_acceptance_rates": []}) # type:ignore + + def append_simulations( + self, + theta: Tensor, + x: Tensor, + proposal: Optional[DirectPosterior] = None, + exclude_invalid_x: bool = True, + warn_on_invalid: bool = True, + warn_if_zscoring: bool = True, + return_self: bool = True, + data_device: str = None, + ) -> "PosteriorEstimator": + r"""Store parameters and simulation outputs to use them for later training. + + Data are stored as entries in lists for each type of variable (parameter/data). + + Stores $\theta$, $x$, prior_masks (indicating if simulations are coming from the + prior or not) and an index indicating which round the batch of simulations came + from. + + Args: + theta: Parameter sets. + x: Simulation outputs. + proposal: The distribution that the parameters $\theta$ were sampled from. + Pass `None` if the parameters were sampled from the prior. If not + `None`, it will trigger a different loss-function. + exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞` + during training. Expect errors, silent or explicit, when `False`. + warn_on_invalid: Whether to warn if data is invalid + warn_if_zscoring: Whether to test if z-scoring causes duplicates + return_self: Whether to return a instance of the class, allows chaining + with `.train()`. Setting `False` decreases memory overhead. + data_device: Where to store the data, default is on the same device where + the training is happening. If training a large dataset on a GPU with not + much VRAM can set to 'cpu' to store data on system memory instead. + + Returns: + NeuralInference object (returned so that this function is chainable). 
+ """ + + # Add ability to specify device data is saved on + if data_device is None: data_device = self._device + theta, x = validate_theta_and_x(theta, x, training_device=data_device) + + + is_valid_x, num_nans, num_infs = handle_invalid_x(x, exclude_invalid_x) + + # Check for problematic z-scoring + if warn_if_zscoring: + warn_if_zscoring_changes_data(x[is_valid_x]) + if warn_on_invalid: + warn_on_invalid_x(num_nans, num_infs, exclude_invalid_x) + warn_on_invalid_x_for_snpec_leakage( + num_nans, num_infs, exclude_invalid_x, type(self).__name__, self._round + ) + + x = x[is_valid_x] + theta = theta[is_valid_x] + + + self._check_proposal(proposal) + + if ( + proposal is None + or proposal is self._prior + or ( + isinstance(proposal, RestrictedPrior) and proposal._prior is self._prior + ) + ): + # The `_data_round_index` will later be used to infer if one should train + # with MLE loss or with atomic loss (see, in `train()`: + # self._round = max(self._data_round_index)) + self._data_round_index.append(0) + prior_masks = mask_sims_from_prior(0, theta.size(0)) + else: + if not self._data_round_index: + # This catches a pretty specific case: if, in the first round, one + # passes data that does not come from the prior. 
+ self._data_round_index.append(1) + else: + self._data_round_index.append(max(self._data_round_index) + 1) + prior_masks = mask_sims_from_prior(1, theta.size(0)) + + + if self._dataset is None: + #If first round, set up ConcatDataset + self._dataset = data.ConcatDataset( [data.TensorDataset(theta,x,prior_masks),] ) + else: + #Otherwise append to Dataset + self._dataset = data.ConcatDataset( self._dataset.datasets + [data.TensorDataset(theta,x,prior_masks),] ) + + self._num_sims_per_round.append(theta.size(0)) + self._proposal_roundwise.append(proposal) + + if self._prior is None or isinstance(self._prior, ImproperEmpirical): + if proposal is not None: + raise ValueError( + "You had not passed a prior at initialization, but now you " + "passed a proposal. If you want to run multi-round SNPE, you have " + "to specify a prior (set the `.prior` argument or re-initialize " + "the object with a prior distribution). If the samples you passed " + "to `append_simulations()` were sampled from the prior, you can " + "run single-round inference with " + "`append_simulations(..., proposal=None)`." + ) + theta_prior = self.get_simulations()[0] + self._prior = ImproperEmpirical(theta_prior, ones(theta_prior.shape[0])) + + #Add ability to not return self + if return_self: + return self + + def train( + self, + training_batch_size: int = 50, + learning_rate: float = 5e-4, + validation_fraction: float = 0.1, + stop_after_epochs: int = 20, + max_num_epochs: int = 2**31 - 1, + clip_max_norm: Optional[float] = 5.0, + calibration_kernel: Optional[Callable] = None, + resume_training: bool = False, + force_first_round_loss: bool = False, + discard_prior_samples: bool = False, + retrain_from_scratch: bool = False, + show_train_summary: bool = False, + dataloader_kwargs: Optional[dict] = None, + ) -> nn.Module: + r"""Return density estimator that approximates the distribution $p(\theta|x)$. + + Args: + training_batch_size: Training batch size. 
+ learning_rate: Learning rate for Adam optimizer. + validation_fraction: The fraction of data to use for validation. + stop_after_epochs: The number of epochs to wait for improvement on the + validation set before terminating training. + max_num_epochs: Maximum number of epochs to run. If reached, we stop + training even when the validation loss is still decreasing. Otherwise, + we train until validation loss increases (see also `stop_after_epochs`). + clip_max_norm: Value at which to clip the total gradient norm in order to + prevent exploding gradients. Use None for no clipping. + calibration_kernel: A function to calibrate the loss with respect to the + simulations `x`. See Lueckmann, Gonçalves et al., NeurIPS 2017. + resume_training: Can be used in case training time is limited, e.g. on a + cluster. If `True`, the split between train and validation set, the + optimizer, the number of epochs, and the best validation log-prob will + be restored from the last time `.train()` was called. + force_first_round_loss: If `True`, train with maximum likelihood, + i.e., potentially ignoring the correction for using a proposal + distribution different from the prior. + discard_prior_samples: Whether to discard samples simulated in round 1, i.e. + from the prior. Training may be sped up by ignoring such less targeted + samples. + retrain_from_scratch: Whether to retrain the conditional density + estimator for the posterior from scratch each round. + show_train_summary: Whether to print the number of epochs and validation + loss after the training. + dataloader_kwargs: Additional or updated kwargs to be passed to the training + and validation dataloaders (like, e.g., a collate_fn) + + Returns: + Density estimator that approximates the distribution $p(\theta|x)$. + """ + if self._round == 0 and self._neural_net is not None: + assert force_first_round_loss, ( + "You have already trained this neural network. 
After you had trained " + "the network, you again appended simulations with `append_simulations" + "(theta, x)`, but you did not provide a proposal. If the new " + "simulations are sampled from the prior, you can set " + "`.train(..., force_first_round_loss=True`). However, if the new " + "simulations were not sampled from the prior, you should pass the " + "proposal, i.e. `append_simulations(theta, x, proposal)`. If " + "your samples are not sampled from the prior and you do not pass a " + "proposal and you set `force_first_round_loss=True`, the result of " + "SNPE will not be the true posterior. Instead, it will be the proposal " + "posterior, which (usually) is more narrow than the true posterior." + ) + + # Calibration kernels proposed in Lueckmann, Gonçalves et al., 2017. + if calibration_kernel is None: + calibration_kernel = lambda x: ones([len(x)], device=self._device) + + # Starting index for the training set (1 = discard round-0 samples). + start_idx = int(discard_prior_samples and self._round > 0) + + # For non-atomic loss, we can not reuse samples from previous rounds as of now. + # SNPE-A can, by construction of the algorithm, only use samples from the last + # round. SNPE-A is the only algorithm that has an attribute `_ran_final_round`, + # so this is how we check for whether or not we are using SNPE-A. + if self.use_non_atomic_loss or hasattr(self, "_ran_final_round"): + start_idx = self._round + + # Set the proposal to the last proposal that was passed by the user. For + # atomic SNPE, it does not matter what the proposal is. For non-atomic + # SNPE, we only use the latest data that was passed, i.e. the one from the + # last proposal. 
+ proposal = self._proposal_roundwise[-1] + + train_loader, val_loader = self.get_dataloaders( + start_idx, + training_batch_size, + validation_fraction, + resume_training, + dataloader_kwargs=dataloader_kwargs, + ) + # First round or if retraining from scratch: + # Call the `self._build_neural_net` with the rounds' thetas and xs as + # arguments, which will build the neural network. + # This is passed into NeuralPosterior, to create a neural posterior which + # can `sample()` and `log_prob()`. The network is accessible via `.net`. + if self._neural_net is None or retrain_from_scratch: + + #Get theta,x from dataset to initialize NN + test_theta = self._dataset.datasets[0].tensors[0][:100] + test_x = self._dataset.datasets[0].tensors[1][:100] + + self._neural_net = self._build_neural_net( + test_theta, test_x + ) + # If data on training device already move net as well. + if ( + not self._device == "cpu" + and f"{test_x.device.type}:{test_x.device.index}" == self._device + ): + self._neural_net.to(self._device) + + test_posterior_net_for_multi_d_x(self._neural_net, test_theta, test_x) + self._x_shape = x_shape_from_simulation(test_x) + + # Move entire net to device for training. + self._neural_net.to(self._device) + + if not resume_training: + self.optimizer = optim.Adam( + list(self._neural_net.parameters()), lr=learning_rate + ) + self.epoch, self._val_log_prob = 0, float("-Inf") + + while self.epoch <= max_num_epochs and not self._converged( + self.epoch, stop_after_epochs + ): + + # Train for a single epoch. + self._neural_net.train() + train_log_probs_sum = 0 + epoch_start_time = time.time() + for batch in train_loader: + self.optimizer.zero_grad() + # Get batches on current device. 
+ theta_batch, x_batch, masks_batch = ( + batch[0].to(self._device), + batch[1].to(self._device), + batch[2].to(self._device), + ) + + train_losses = self._loss( + theta_batch, x_batch, masks_batch, proposal, calibration_kernel + ) + train_loss = torch.mean(train_losses) + train_log_probs_sum -= train_losses.sum().item() + + train_loss.backward() + if clip_max_norm is not None: + clip_grad_norm_( + self._neural_net.parameters(), max_norm=clip_max_norm + ) + self.optimizer.step() + + self.epoch += 1 + + train_log_prob_average = train_log_probs_sum / ( + len(train_loader) * train_loader.batch_size # type: ignore + ) + self._summary["train_log_probs"].append(train_log_prob_average) + + # Calculate validation performance. + self._neural_net.eval() + val_log_prob_sum = 0 + + with torch.no_grad(): + for batch in val_loader: + theta_batch, x_batch, masks_batch = ( + batch[0].to(self._device), + batch[1].to(self._device), + batch[2].to(self._device), + ) + # Take negative loss here to get validation log_prob. + val_losses = self._loss( + theta_batch, + x_batch, + masks_batch, + proposal, + calibration_kernel, + ) + val_log_prob_sum -= val_losses.sum().item() + + # Take mean over all validation samples. + self._val_log_prob = val_log_prob_sum / ( + len(val_loader) * val_loader.batch_size # type: ignore + ) + # Log validation log prob for every epoch. + self._summary["validation_log_probs"].append(self._val_log_prob) + self._summary["epoch_durations_sec"].append(time.time() - epoch_start_time) + + self._maybe_show_progress(self._show_progress_bars, self.epoch) + + self._report_convergence_at_end(self.epoch, stop_after_epochs, max_num_epochs) + + # Update summary. + self._summary["epochs"].append(self.epoch) + self._summary["best_validation_log_probs"].append(self._best_val_log_prob) + + # Update tensorboard and summary dict. + self._summarize(round_=self._round, x_o=None, theta_bank=None, x_bank=None) + + # Update description for progress bar. 
+ if show_train_summary: + print(self._describe_round(self._round, self._summary)) + + # Avoid keeping the gradients in the resulting network, which can + # cause memory leakage when benchmarking. + self._neural_net.zero_grad(set_to_none=True) + + return deepcopy(self._neural_net) + + def build_posterior( + self, + density_estimator: Optional[nn.Module] = None, + prior: Optional[Distribution] = None, + sample_with: str = "rejection", + mcmc_method: str = "slice_np", + vi_method: str = "rKL", + mcmc_parameters: Dict[str, Any] = {}, + vi_parameters: Dict[str, Any] = {}, + rejection_sampling_parameters: Dict[str, Any] = {}, + ) -> Union[MCMCPosterior, RejectionPosterior, VIPosterior, DirectPosterior]: + r"""Build posterior from the neural density estimator. + + For SNPE, the posterior distribution that is returned here implements the + following functionality over the raw neural density estimator: + - correct the calculation of the log probability such that it compensates for + the leakage. + - reject samples that lie outside of the prior bounds. + - alternatively, if leakage is very high (which can happen for multi-round + SNPE), sample from the posterior with MCMC. + + Args: + density_estimator: The density estimator that the posterior is based on. + If `None`, use the latest neural density estimator that was trained. + prior: Prior distribution. + sample_with: Method to use for sampling from the posterior. Must be one of + [`mcmc` | `rejection` | `vi`]. + mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`, + `hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy + implementation of slice sampling; select `hmc`, `nuts` or `slice` for + Pyro-based sampling. + vi_method: Method used for VI, one of [`rKL`, `fKL`, `IW`, `alpha`]. Note + some of the methods admit a `mode seeking` property (e.g. rKL) whereas + some admit a `mass covering` one (e.g fKL). + mcmc_parameters: Additional kwargs passed to `MCMCPosterior`. 
+ vi_parameters: Additional kwargs passed to `VIPosterior`. + rejection_sampling_parameters: Additional kwargs passed to + `RejectionPosterior` or `DirectPosterior`. By default, + `DirectPosterior` is used. Only if `rejection_sampling_parameters` + contains `proposal`, a `RejectionPosterior` is instantiated. + + Returns: + Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods + (the returned log-probability is unnormalized). + """ + if prior is None: + assert self._prior is not None, ( + "You did not pass a prior. You have to pass the prior either at " + "initialization `inference = SNPE(prior)` or to " + "`.build_posterior(prior=prior)`." + ) + prior = self._prior + else: + utils.check_prior(prior) + + if density_estimator is None: + posterior_estimator = self._neural_net + # If internal net is used device is defined. + device = self._device + else: + posterior_estimator = density_estimator + # Otherwise, infer it from the device of the net parameters. + device = next(density_estimator.parameters()).device.type + + potential_fn, theta_transform = posterior_estimator_based_potential( + posterior_estimator=posterior_estimator, prior=prior, x_o=None + ) + + if sample_with == "rejection": + if "proposal" in rejection_sampling_parameters.keys(): + self._posterior = RejectionPosterior( + potential_fn=potential_fn, + device=device, + x_shape=self._x_shape, + **rejection_sampling_parameters, + ) + else: + self._posterior = DirectPosterior( + posterior_estimator=posterior_estimator, + prior=prior, + x_shape=self._x_shape, + device=device, + ) + elif sample_with == "mcmc": + self._posterior = MCMCPosterior( + potential_fn=potential_fn, + theta_transform=theta_transform, + proposal=prior, + method=mcmc_method, + device=device, + x_shape=self._x_shape, + **mcmc_parameters, + ) + elif sample_with == "vi": + self._posterior = VIPosterior( + potential_fn=potential_fn, + theta_transform=theta_transform, + prior=prior, # type: ignore + vi_method=vi_method, + 
device=device, + x_shape=self._x_shape, + **vi_parameters, + ) + else: + raise NotImplementedError + + # Store models at end of each round. + self._model_bank.append(deepcopy(self._posterior)) + + return deepcopy(self._posterior) + + @abstractmethod + def _log_prob_proposal_posterior( + self, + theta: Tensor, + x: Tensor, + masks: Tensor, + proposal: Optional[Any], + ) -> Tensor: + raise NotImplementedError + + def _loss( + self, + theta: Tensor, + x: Tensor, + masks: Tensor, + proposal: Optional[Any], + calibration_kernel: Callable, + ) -> Tensor: + """Return loss with proposal correction (`round_>0`) or without it (`round_=0`). + + The loss is the negative log prob. Irrespective of the round or SNPE method + (A, B, or C), it can be weighted with a calibration kernel. + + Returns: + Calibration kernel-weighted negative log prob. + """ + if self._round == 0: + # Use posterior log prob (without proposal correction) for first round. + log_prob = self._neural_net.log_prob(theta, x) + else: + log_prob = self._log_prob_proposal_posterior(theta, x, masks, proposal) + + return -(calibration_kernel(x) * log_prob) + + def _check_proposal(self, proposal): + """ + Check for validity of the provided proposal distribution. + + If the proposal is a `NeuralPosterior`, we check if the default_x is set. + If the proposal is **not** a `NeuralPosterior`, we warn since it is likely that + the user simply passed the prior, but this would still trigger atomic loss. + """ + if proposal is not None: + check_if_proposal_has_default_x(proposal) + + if isinstance(proposal, RestrictedPrior): + if proposal._prior is not self._prior: + warn( + "The proposal you passed is a `RestrictedPrior`, but the " + "proposal distribution it uses is not the prior (it can be " + "accessed via `RestrictedPrior._prior`). We do not " + "recommend to mix the `RestrictedPrior` with multi-round " + "SNPE." 
+ ) + elif ( + not isinstance(proposal, NeuralPosterior) + and proposal is not self._prior + ): + warn( + "The proposal you passed is neither the prior nor a " + "`NeuralPosterior` object. If you are an expert user and did so " + "for research purposes, this is fine. If not, you might be doing " + "something wrong: feel free to create an issue on Github." + ) + elif self._round > 0: + raise ValueError( + "A proposal was passed but no prior was passed at initialisation. When " + "running multi-round inference, a prior needs to be specified upon " + "initialisation. Potential fix: setting the `._prior` attribute or " + "re-initialisation. If the samples passed to `append_simulations()` " + "were sampled from the prior, single-round inference can be performed " + "with `append_simulations(..., proprosal=None)`." + ) diff --git a/sbi/inference/snpe/snpe_c.py b/sbi/inference/snpe/snpe_c.py index b9cd11984..34c53ab4b 100644 --- a/sbi/inference/snpe/snpe_c.py +++ b/sbi/inference/snpe/snpe_c.py @@ -1,625 +1,625 @@ -# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed -# under the Affero General Public License v3, see . 
- - -from typing import Callable, Dict, Optional, Union - -import torch -from pyknos.mdn.mdn import MultivariateGaussianMDN as mdn -from pyknos.nflows.transforms import CompositeTransform -from torch import Tensor, eye, nn, ones -from torch.distributions import Distribution, MultivariateNormal, Uniform - -from sbi import utils as utils -from sbi.inference.posteriors.direct_posterior import DirectPosterior -from sbi.inference.snpe.snpe_base import PosteriorEstimator -from sbi.types import TensorboardSummaryWriter -from sbi.utils import ( - batched_mixture_mv, - batched_mixture_vmv, - check_dist_class, - clamp_and_warn, - del_entries, - repeat_rows, -) - - -class SNPE_C(PosteriorEstimator): - def __init__( - self, - prior: Optional[Distribution] = None, - density_estimator: Union[str, Callable] = "maf", - device: str = "cpu", - logging_level: Union[int, str] = "WARNING", - summary_writer: Optional[TensorboardSummaryWriter] = None, - show_progress_bars: bool = True, - ): - r"""SNPE-C / APT [1]. - - [1] _Automatic Posterior Transformation for Likelihood-free Inference_, - Greenberg et al., ICML 2019, https://arxiv.org/abs/1905.07488. - - This class implements two loss variants of SNPE-C: the non-atomic and the atomic - version. The atomic loss of SNPE-C can be used for any density estimator, - i.e. also for normalizing flows. However, it suffers from leakage issues. On - the other hand, the non-atomic loss can only be used only if the proposal - distribution is a mixture of Gaussians, the density estimator is a mixture of - Gaussians, and the prior is either Gaussian or Uniform. It does not suffer from - leakage issues. At the beginning of each round, we print whether the non-atomic - or the atomic version is used. - - In this codebase, we will automatically switch to the non-atomic loss if the - following criteria are fulfilled:
- - proposal is a `DirectPosterior` with density_estimator `mdn`, as built - with `utils.sbi.posterior_nn()`.
- - the density estimator is a `mdn`, as built with - `utils.sbi.posterior_nn()`.
- - `isinstance(prior, MultivariateNormal)` (from `torch.distributions`) or - `isinstance(prior, sbi.utils.BoxUniform)` - - Note that custom implementations of any of these densities (or estimators) will - not trigger the non-atomic loss, and the algorithm will fall back onto using - the atomic loss. - - Args: - prior: A probability distribution that expresses prior knowledge about the - parameters, e.g. which ranges are meaningful for them. - density_estimator: If it is a string, use a pre-configured network of the - provided type (one of nsf, maf, mdn, made). Alternatively, a function - that builds a custom neural network can be provided. The function will - be called with the first batch of simulations (theta, x), which can - thus be used for shape inference and potentially for z-scoring. It - needs to return a PyTorch `nn.Module` implementing the density - estimator. The density estimator needs to provide the methods - `.log_prob` and `.sample()`. - device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}". - logging_level: Minimum severity of messages to log. One of the strings - INFO, WARNING, DEBUG, ERROR and CRITICAL. - summary_writer: A tensorboard `SummaryWriter` to control, among others, log - file location (default is `/logs`.) - show_progress_bars: Whether to show a progressbar during training. 
- """ - - kwargs = del_entries(locals(), entries=("self", "__class__")) - super().__init__(**kwargs) - - def train( - self, - num_atoms: int = 10, - training_batch_size: int = 50, - learning_rate: float = 5e-4, - validation_fraction: float = 0.1, - stop_after_epochs: int = 20, - max_num_epochs: int = 2**31 - 1, - clip_max_norm: Optional[float] = 5.0, - calibration_kernel: Optional[Callable] = None, - resume_training: bool = False, - force_first_round_loss: bool = False, - discard_prior_samples: bool = False, - use_combined_loss: bool = False, - retrain_from_scratch: bool = False, - show_train_summary: bool = False, - dataloader_kwargs: Optional[Dict] = None, - ) -> nn.Module: - r"""Return density estimator that approximates the distribution $p(\theta|x)$. - - Args: - num_atoms: Number of atoms to use for classification. - training_batch_size: Training batch size. - learning_rate: Learning rate for Adam optimizer. - validation_fraction: The fraction of data to use for validation. - stop_after_epochs: The number of epochs to wait for improvement on the - validation set before terminating training. - max_num_epochs: Maximum number of epochs to run. If reached, we stop - training even when the validation loss is still decreasing. Otherwise, - we train until validation loss increases (see also `stop_after_epochs`). - clip_max_norm: Value at which to clip the total gradient norm in order to - prevent exploding gradients. Use None for no clipping. - calibration_kernel: A function to calibrate the loss with respect to the - simulations `x`. See Lueckmann, Gonçalves et al., NeurIPS 2017. - resume_training: Can be used in case training time is limited, e.g. on a - cluster. If `True`, the split between train and validation set, the - optimizer, the number of epochs, and the best validation log-prob will - be restored from the last time `.train()` was called. 
- force_first_round_loss: If `True`, train with maximum likelihood, - i.e., potentially ignoring the correction for using a proposal - distribution different from the prior. - discard_prior_samples: Whether to discard samples simulated in round 1, i.e. - from the prior. Training may be sped up by ignoring such less targeted - samples. - use_combined_loss: Whether to train the neural net also on prior samples - using maximum likelihood in addition to training it on all samples using - atomic loss. The extra MLE loss helps prevent density leaking with - bounded priors. - retrain_from_scratch: Whether to retrain the conditional density - estimator for the posterior from scratch each round. - show_train_summary: Whether to print the number of epochs and validation - loss and leakage after the training. - dataloader_kwargs: Additional or updated kwargs to be passed to the training - and validation dataloaders (like, e.g., a collate_fn) - - Returns: - Density estimator that approximates the distribution $p(\theta|x)$. - """ - - # WARNING: sneaky trick ahead. We proxy the parent's `train` here, - # requiring the signature to have `num_atoms`, save it for use below, and - # continue. It's sneaky because we are using the object (self) as a namespace - # to pass arguments between functions, and that's implicit state management. - self._num_atoms = num_atoms - self._use_combined_loss = use_combined_loss - kwargs = del_entries( - locals(), entries=("self", "__class__", "num_atoms", "use_combined_loss") - ) - - self._round = max(self._data_round_index) - - if self._round > 0: - # Set the proposal to the last proposal that was passed by the user. For - # atomic SNPE, it does not matter what the proposal is. For non-atomic - # SNPE, we only use the latest data that was passed, i.e. the one from the - # last proposal. 
- proposal = self._proposal_roundwise[-1] - self.use_non_atomic_loss = ( - isinstance(proposal.posterior_estimator._distribution, mdn) - and isinstance(self._neural_net._distribution, mdn) - and check_dist_class( - self._prior, class_to_check=(Uniform, MultivariateNormal) - )[0] - ) - - algorithm = "non-atomic" if self.use_non_atomic_loss else "atomic" - print(f"Using SNPE-C with {algorithm} loss") - - if self.use_non_atomic_loss: - # Take care of z-scoring, pre-compute and store prior terms. - self._set_state_for_mog_proposal() - - return super().train(**kwargs) - - def _set_state_for_mog_proposal(self) -> None: - """Set state variables that are used at each training step of non-atomic SNPE-C. - - Three things are computed: - 1) Check if z-scoring was requested. To do so, we check if the `_transform` - argument of the net had been a `CompositeTransform`. See pyknos mdn.py. - 2) Define a (potentially standardized) prior. It's standardized if z-scoring - had been requested. - 3) Compute (Precision * mean) for the prior. This quantity is used at every - training step if the prior is Gaussian. - """ - - self.z_score_theta = isinstance(self._neural_net._transform, CompositeTransform) - - self._set_maybe_z_scored_prior() - - if isinstance(self._maybe_z_scored_prior, MultivariateNormal): - self.prec_m_prod_prior = torch.mv( - self._maybe_z_scored_prior.precision_matrix, # type: ignore - self._maybe_z_scored_prior.loc, # type: ignore - ) - - def _set_maybe_z_scored_prior(self) -> None: - r"""Compute and store potentially standardized prior (if z-scoring was done). - - The proposal posterior is: - $pp(\theta|x) = 1/Z * q(\theta|x) * prop(\theta) / p(\theta)$ - - Let's denote z-scored theta by `a`: a = (theta - mean) / std - Then pp'(a|x) = 1/Z_2 * q'(a|x) * prop'(a) / p'(a)$ - - The ' indicates that the evaluation occurs in standardized space. The constant - scaling factor has been absorbed into Z_2. 
- From the above equation, we see that we need to evaluate the prior **in - standardized space**. We build the standardized prior in this function. - - The standardize transform that is applied to the samples theta does not use - the exact prior mean and std (due to implementation issues). Hence, the z-scored - prior will not be exactly have mean=0 and std=1. - """ - - if self.z_score_theta: - scale = self._neural_net._transform._transforms[0]._scale - shift = self._neural_net._transform._transforms[0]._shift - - # Following the definintion of the linear transform in - # `standardizing_transform` in `sbiutils.py`: - # shift=-mean / std - # scale=1 / std - # Solving these equations for mean and std: - estim_prior_std = 1 / scale - estim_prior_mean = -shift * estim_prior_std - - # Compute the discrepancy of the true prior mean and std and the mean and - # std that was empirically estimated from samples. - # N(theta|m,s) = N((theta-m_e)/s_e|(m-m_e)/s_e, s/s_e) - # Above: m,s are true prior mean and std. m_e,s_e are estimated prior mean - # and std (estimated from samples and used to build standardize transform). - almost_zero_mean = (self._prior.mean - estim_prior_mean) / estim_prior_std - almost_one_std = torch.sqrt(self._prior.variance) / estim_prior_std - - if isinstance(self._prior, MultivariateNormal): - self._maybe_z_scored_prior = MultivariateNormal( - almost_zero_mean, torch.diag(almost_one_std) - ) - else: - range_ = torch.sqrt(almost_one_std * 3.0) - self._maybe_z_scored_prior = utils.BoxUniform( - almost_zero_mean - range_, almost_zero_mean + range_ - ) - else: - self._maybe_z_scored_prior = self._prior - - def _log_prob_proposal_posterior( - self, - theta: Tensor, - x: Tensor, - masks: Tensor, - proposal: DirectPosterior, - ) -> Tensor: - """Return the log-probability of the proposal posterior. - - If the proposal is a MoG, the density estimator is a MoG, and the prior is - either Gaussian or uniform, we use non-atomic loss. 
Else, use atomic loss (which - suffers from leakage). - - Args: - theta: Batch of parameters θ. - x: Batch of data. - masks: Mask that is True for prior samples in the batch in order to train - them with prior loss. - proposal: Proposal distribution. - - Returns: Log-probability of the proposal posterior. - """ - - if self.use_non_atomic_loss: - return self._log_prob_proposal_posterior_mog(theta, x, proposal) - else: - return self._log_prob_proposal_posterior_atomic(theta, x, masks) - - def _log_prob_proposal_posterior_atomic( - self, theta: Tensor, x: Tensor, masks: Tensor - ): - """Return log probability of the proposal posterior for atomic proposals. - - We have two main options when evaluating the proposal posterior. - (1) Generate atoms from the proposal prior. - (2) Generate atoms from a more targeted distribution, such as the most - recent posterior. - If we choose the latter, it is likely beneficial not to do this in the first - round, since we would be sampling from a randomly-initialized neural density - estimator. - - Args: - theta: Batch of parameters θ. - x: Batch of data. - masks: Mask that is True for prior samples in the batch in order to train - them with prior loss. - - Returns: - Log-probability of the proposal posterior. - """ - - batch_size = theta.shape[0] - - num_atoms = int( - clamp_and_warn("num_atoms", self._num_atoms, min_val=2, max_val=batch_size) - ) - - # Each set of parameter atoms is evaluated using the same x, - # so we repeat rows of the data x, e.g. [1, 2] -> [1, 1, 2, 2] - repeated_x = repeat_rows(x, num_atoms) - - # To generate the full set of atoms for a given item in the batch, - # we sample without replacement num_atoms - 1 times from the rest - # of the theta in the batch. 
- probs = ones(batch_size, batch_size) * (1 - eye(batch_size)) / (batch_size - 1) - - choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False) - contrasting_theta = theta[choices] - - # We can now create our sets of atoms from the contrasting parameter sets - # we have generated. - atomic_theta = torch.cat((theta[:, None, :], contrasting_theta), dim=1).reshape( - batch_size * num_atoms, -1 - ) - - # Evaluate large batch giving (batch_size * num_atoms) log prob posterior evals. - log_prob_posterior = self._neural_net.log_prob(atomic_theta, repeated_x) - utils.assert_all_finite(log_prob_posterior, "posterior eval") - log_prob_posterior = log_prob_posterior.reshape(batch_size, num_atoms) - - # Get (batch_size * num_atoms) log prob prior evals. - log_prob_prior = self._prior.log_prob(atomic_theta) - log_prob_prior = log_prob_prior.reshape(batch_size, num_atoms) - utils.assert_all_finite(log_prob_prior, "prior eval") - - # Compute unnormalized proposal posterior. - unnormalized_log_prob = log_prob_posterior - log_prob_prior - - # Normalize proposal posterior across discrete set of atoms. - log_prob_proposal_posterior = unnormalized_log_prob[:, 0] - torch.logsumexp( - unnormalized_log_prob, dim=-1 - ) - utils.assert_all_finite(log_prob_proposal_posterior, "proposal posterior eval") - - # XXX This evaluates the posterior on _all_ prior samples - if self._use_combined_loss: - log_prob_posterior_non_atomic = self._neural_net.log_prob(theta, x) - masks = masks.reshape(-1) - log_prob_proposal_posterior = ( - masks * log_prob_posterior_non_atomic + log_prob_proposal_posterior - ) - - return log_prob_proposal_posterior - - def _log_prob_proposal_posterior_mog( - self, theta: Tensor, x: Tensor, proposal: DirectPosterior - ) -> Tensor: - """Return log-probability of the proposal posterior for MoG proposal. - - For MoG proposals and MoG density estimators, this can be done in closed form - and does not require atomic loss (i.e. 
there will be no leakage issues). - - Notation: - - m are mean vectors. - prec are precision matrices. - cov are covariance matrices. - - _p at the end indicates that it is the proposal. - _d indicates that it is the density estimator. - _pp indicates the proposal posterior. - - All tensors will have shapes (batch_dim, num_components, ...) - - Args: - theta: Batch of parameters θ. - x: Batch of data. - proposal: Proposal distribution. - - Returns: - Log-probability of the proposal posterior. - """ - - # Evaluate the proposal. MDNs do not have functionality to run the embedding_net - # and then get the mixture_components (**without** calling log_prob()). Hence, - # we call them separately here. - encoded_x = proposal.posterior_estimator._embedding_net(proposal.default_x) - dist = ( - proposal.posterior_estimator._distribution - ) # defined to avoid ugly black formatting. - logits_p, m_p, prec_p, _, _ = dist.get_mixture_components(encoded_x) - norm_logits_p = logits_p - torch.logsumexp(logits_p, dim=-1, keepdim=True) - - # Evaluate the density estimator. - encoded_x = self._neural_net._embedding_net(x) - dist = self._neural_net._distribution # defined to avoid black formatting. - logits_d, m_d, prec_d, _, _ = dist.get_mixture_components(encoded_x) - norm_logits_d = logits_d - torch.logsumexp(logits_d, dim=-1, keepdim=True) - - # z-score theta if it z-scoring had been requested. - theta = self._maybe_z_score_theta(theta) - - # Compute the MoG parameters of the proposal posterior. - logits_pp, m_pp, prec_pp, cov_pp = self._automatic_posterior_transformation( - norm_logits_p, m_p, prec_p, norm_logits_d, m_d, prec_d - ) - - # Compute the log_prob of theta under the product. 
- log_prob_proposal_posterior = utils.mog_log_prob( - theta, logits_pp, m_pp, prec_pp - ) - utils.assert_all_finite(log_prob_proposal_posterior, "proposal posterior eval") - - return log_prob_proposal_posterior - - def _automatic_posterior_transformation( - self, - logits_p: Tensor, - means_p: Tensor, - precisions_p: Tensor, - logits_d: Tensor, - means_d: Tensor, - precisions_d: Tensor, - ): - r"""Returns the MoG parameters of the proposal posterior. - - The proposal posterior is: - $pp(\theta|x) = 1/Z * q(\theta|x) * prop(\theta) / p(\theta)$ - In words: proposal posterior = posterior estimate * proposal / prior. - - If the posterior estimate and the proposal are MoG and the prior is either - Gaussian or uniform, we can solve this in closed-form. The is implemented in - this function. - - This function implements Appendix A1 from Greenberg et al. 2019. - - We have to build L*K components. How do we do this? - Example: proposal has two components, density estimator has three components. - Let's call the two components of the proposal i,j and the three components - of the density estimator x,y,z. We have to multiply every component of the - proposal with every component of the density estimator. So, what we do is: - 1) for the proposal, build: i,i,i,j,j,j. Done with torch.repeat_interleave() - 2) for the density estimator, build: x,y,z,x,y,z. Done with torch.repeat() - 3) Multiply them with simple matrix operations. - - Args: - logits_p: Component weight of each Gaussian of the proposal. - means_p: Mean of each Gaussian of the proposal. - precisions_p: Precision matrix of each Gaussian of the proposal. - logits_d: Component weight for each Gaussian of the density estimator. - means_d: Mean of each Gaussian of the density estimator. - precisions_d: Precision matrix of each Gaussian of the density estimator. - - Returns: (Component weight, mean, precision matrix, covariance matrix) of each - Gaussian of the proposal posterior. 
Has L*K terms (proposal has L terms, - density estimator has K terms). - """ - - precisions_pp, covariances_pp = self._precisions_proposal_posterior( - precisions_p, precisions_d - ) - - means_pp = self._means_proposal_posterior( - covariances_pp, means_p, precisions_p, means_d, precisions_d - ) - - logits_pp = self._logits_proposal_posterior( - means_pp, - precisions_pp, - covariances_pp, - logits_p, - means_p, - precisions_p, - logits_d, - means_d, - precisions_d, - ) - - return logits_pp, means_pp, precisions_pp, covariances_pp - - def _precisions_proposal_posterior( - self, precisions_p: Tensor, precisions_d: Tensor - ): - """Return the precisions and covariances of the proposal posterior. - - Args: - precisions_p: Precision matrices of the proposal distribution. - precisions_d: Precision matrices of the density estimator. - - Returns: (Precisions, Covariances) of the proposal posterior. L*K terms. - """ - - num_comps_p = precisions_p.shape[1] - num_comps_d = precisions_d.shape[1] - - precisions_p_rep = precisions_p.repeat_interleave(num_comps_d, dim=1) - precisions_d_rep = precisions_d.repeat(1, num_comps_p, 1, 1) - - precisions_pp = precisions_p_rep + precisions_d_rep - if isinstance(self._maybe_z_scored_prior, MultivariateNormal): - precisions_pp -= self._maybe_z_scored_prior.precision_matrix - - covariances_pp = torch.inverse(precisions_pp) - - return precisions_pp, covariances_pp - - def _means_proposal_posterior( - self, - covariances_pp: Tensor, - means_p: Tensor, - precisions_p: Tensor, - means_d: Tensor, - precisions_d: Tensor, - ): - """Return the means of the proposal posterior. - - means_pp = C_ix * (P_i * m_i + P_x * m_x - P_o * m_o). - - Args: - covariances_pp: Covariance matrices of the proposal posterior. - means_p: Means of the proposal distribution. - precisions_p: Precision matrices of the proposal distribution. - means_d: Means of the density estimator. - precisions_d: Precision matrices of the density estimator. 
- - Returns: Means of the proposal posterior. L*K terms. - """ - - num_comps_p = precisions_p.shape[1] - num_comps_d = precisions_d.shape[1] - - # First, compute the product P_i * m_i and P_j * m_j - prec_m_prod_p = batched_mixture_mv(precisions_p, means_p) - prec_m_prod_d = batched_mixture_mv(precisions_d, means_d) - - # Repeat them to allow for matrix operations: same trick as for the precisions. - prec_m_prod_p_rep = prec_m_prod_p.repeat_interleave(num_comps_d, dim=1) - prec_m_prod_d_rep = prec_m_prod_d.repeat(1, num_comps_p, 1) - - # Means = C_ij * (P_i * m_i + P_x * m_x - P_o * m_o). - summed_cov_m_prod_rep = prec_m_prod_p_rep + prec_m_prod_d_rep - if isinstance(self._maybe_z_scored_prior, MultivariateNormal): - summed_cov_m_prod_rep -= self.prec_m_prod_prior - - means_pp = batched_mixture_mv(covariances_pp, summed_cov_m_prod_rep) - - return means_pp - - @staticmethod - def _logits_proposal_posterior( - means_pp: Tensor, - precisions_pp: Tensor, - covariances_pp: Tensor, - logits_p: Tensor, - means_p: Tensor, - precisions_p: Tensor, - logits_d: Tensor, - means_d: Tensor, - precisions_d: Tensor, - ): - """Return the component weights (i.e. logits) of the proposal posterior. - - Args: - means_pp: Means of the proposal posterior. - precisions_pp: Precision matrices of the proposal posterior. - covariances_pp: Covariance matrices of the proposal posterior. - logits_p: Component weights (i.e. logits) of the proposal distribution. - means_p: Means of the proposal distribution. - precisions_p: Precision matrices of the proposal distribution. - logits_d: Component weights (i.e. logits) of the density estimator. - means_d: Means of the density estimator. - precisions_d: Precision matrices of the density estimator. - - Returns: Component weights of the proposal posterior. L*K terms. 
- """ - - num_comps_p = precisions_p.shape[1] - num_comps_d = precisions_d.shape[1] - - # Compute log(alpha_i * beta_j) - logits_p_rep = logits_p.repeat_interleave(num_comps_d, dim=1) - logits_d_rep = logits_d.repeat(1, num_comps_p) - logit_factors = logits_p_rep + logits_d_rep - - # Compute sqrt(det()/(det()*det())) - logdet_covariances_pp = torch.logdet(covariances_pp) - logdet_covariances_p = -torch.logdet(precisions_p) - logdet_covariances_d = -torch.logdet(precisions_d) - - # Repeat the proposal and density estimator terms such that there are LK terms. - # Same trick as has been used above. - logdet_covariances_p_rep = logdet_covariances_p.repeat_interleave( - num_comps_d, dim=1 - ) - logdet_covariances_d_rep = logdet_covariances_d.repeat(1, num_comps_p) - - log_sqrt_det_ratio = 0.5 * ( - logdet_covariances_pp - - (logdet_covariances_p_rep + logdet_covariances_d_rep) - ) - - # Compute for proposal, density estimator, and proposal posterior: - # mu_i.T * P_i * mu_i - exponent_p = batched_mixture_vmv(precisions_p, means_p) - exponent_d = batched_mixture_vmv(precisions_d, means_d) - exponent_pp = batched_mixture_vmv(precisions_pp, means_pp) - - # Extend proposal and density estimator exponents to get LK terms. - exponent_p_rep = exponent_p.repeat_interleave(num_comps_d, dim=1) - exponent_d_rep = exponent_d.repeat(1, num_comps_p) - exponent = -0.5 * (exponent_p_rep + exponent_d_rep - exponent_pp) - - logits_pp = logit_factors + log_sqrt_det_ratio + exponent - - return logits_pp - - def _maybe_z_score_theta(self, theta: Tensor) -> Tensor: - """Return potentially standardized theta if z-scoring was requested.""" - - if self.z_score_theta: - theta, _ = self._neural_net._transform(theta) - - return theta +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Affero General Public License v3, see . 
+ + +from typing import Callable, Dict, Optional, Union + +import torch +from pyknos.mdn.mdn import MultivariateGaussianMDN as mdn +from pyknos.nflows.transforms import CompositeTransform +from torch import Tensor, eye, nn, ones +from torch.distributions import Distribution, MultivariateNormal, Uniform + +from sbi import utils as utils +from sbi.inference.posteriors.direct_posterior import DirectPosterior +from sbi.inference.snpe.snpe_base import PosteriorEstimator +from sbi.types import TensorboardSummaryWriter +from sbi.utils import ( + batched_mixture_mv, + batched_mixture_vmv, + check_dist_class, + clamp_and_warn, + del_entries, + repeat_rows, +) + + +class SNPE_C(PosteriorEstimator): + def __init__( + self, + prior: Optional[Distribution] = None, + density_estimator: Union[str, Callable] = "maf", + device: str = "cpu", + logging_level: Union[int, str] = "WARNING", + summary_writer: Optional[TensorboardSummaryWriter] = None, + show_progress_bars: bool = True, + ): + r"""SNPE-C / APT [1]. + + [1] _Automatic Posterior Transformation for Likelihood-free Inference_, + Greenberg et al., ICML 2019, https://arxiv.org/abs/1905.07488. + + This class implements two loss variants of SNPE-C: the non-atomic and the atomic + version. The atomic loss of SNPE-C can be used for any density estimator, + i.e. also for normalizing flows. However, it suffers from leakage issues. On + the other hand, the non-atomic loss can only be used only if the proposal + distribution is a mixture of Gaussians, the density estimator is a mixture of + Gaussians, and the prior is either Gaussian or Uniform. It does not suffer from + leakage issues. At the beginning of each round, we print whether the non-atomic + or the atomic version is used. + + In this codebase, we will automatically switch to the non-atomic loss if the + following criteria are fulfilled:
+ - proposal is a `DirectPosterior` with density_estimator `mdn`, as built + with `sbi.utils.posterior_nn()`.
+ - the density estimator is a `mdn`, as built with + `sbi.utils.posterior_nn()`.
+ - `isinstance(prior, MultivariateNormal)` (from `torch.distributions`) or + `isinstance(prior, sbi.utils.BoxUniform)` + + Note that custom implementations of any of these densities (or estimators) will + not trigger the non-atomic loss, and the algorithm will fall back onto using + the atomic loss. + + Args: + prior: A probability distribution that expresses prior knowledge about the + parameters, e.g. which ranges are meaningful for them. + density_estimator: If it is a string, use a pre-configured network of the + provided type (one of nsf, maf, mdn, made). Alternatively, a function + that builds a custom neural network can be provided. The function will + be called with the first batch of simulations (theta, x), which can + thus be used for shape inference and potentially for z-scoring. It + needs to return a PyTorch `nn.Module` implementing the density + estimator. The density estimator needs to provide the methods + `.log_prob` and `.sample()`. + device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}". + logging_level: Minimum severity of messages to log. One of the strings + INFO, WARNING, DEBUG, ERROR and CRITICAL. + summary_writer: A tensorboard `SummaryWriter` to control, among others, log + file location (default is `/logs`.) + show_progress_bars: Whether to show a progressbar during training. 
+ """ + + kwargs = del_entries(locals(), entries=("self", "__class__")) + super().__init__(**kwargs) + + def train( + self, + num_atoms: int = 10, + training_batch_size: int = 50, + learning_rate: float = 5e-4, + validation_fraction: float = 0.1, + stop_after_epochs: int = 20, + max_num_epochs: int = 2**31 - 1, + clip_max_norm: Optional[float] = 5.0, + calibration_kernel: Optional[Callable] = None, + resume_training: bool = False, + force_first_round_loss: bool = False, + discard_prior_samples: bool = False, + use_combined_loss: bool = False, + retrain_from_scratch: bool = False, + show_train_summary: bool = False, + dataloader_kwargs: Optional[Dict] = None, + ) -> nn.Module: + r"""Return density estimator that approximates the distribution $p(\theta|x)$. + + Args: + num_atoms: Number of atoms to use for classification. + training_batch_size: Training batch size. + learning_rate: Learning rate for Adam optimizer. + validation_fraction: The fraction of data to use for validation. + stop_after_epochs: The number of epochs to wait for improvement on the + validation set before terminating training. + max_num_epochs: Maximum number of epochs to run. If reached, we stop + training even when the validation loss is still decreasing. Otherwise, + we train until validation loss increases (see also `stop_after_epochs`). + clip_max_norm: Value at which to clip the total gradient norm in order to + prevent exploding gradients. Use None for no clipping. + calibration_kernel: A function to calibrate the loss with respect to the + simulations `x`. See Lueckmann, Gonçalves et al., NeurIPS 2017. + resume_training: Can be used in case training time is limited, e.g. on a + cluster. If `True`, the split between train and validation set, the + optimizer, the number of epochs, and the best validation log-prob will + be restored from the last time `.train()` was called. 
+ force_first_round_loss: If `True`, train with maximum likelihood, + i.e., potentially ignoring the correction for using a proposal + distribution different from the prior. + discard_prior_samples: Whether to discard samples simulated in round 1, i.e. + from the prior. Training may be sped up by ignoring such less targeted + samples. + use_combined_loss: Whether to train the neural net also on prior samples + using maximum likelihood in addition to training it on all samples using + atomic loss. The extra MLE loss helps prevent density leaking with + bounded priors. + retrain_from_scratch: Whether to retrain the conditional density + estimator for the posterior from scratch each round. + show_train_summary: Whether to print the number of epochs and validation + loss and leakage after the training. + dataloader_kwargs: Additional or updated kwargs to be passed to the training + and validation dataloaders (like, e.g., a collate_fn) + + Returns: + Density estimator that approximates the distribution $p(\theta|x)$. + """ + + # WARNING: sneaky trick ahead. We proxy the parent's `train` here, + # requiring the signature to have `num_atoms`, save it for use below, and + # continue. It's sneaky because we are using the object (self) as a namespace + # to pass arguments between functions, and that's implicit state management. + self._num_atoms = num_atoms + self._use_combined_loss = use_combined_loss + kwargs = del_entries( + locals(), entries=("self", "__class__", "num_atoms", "use_combined_loss") + ) + + self._round = max(self._data_round_index) + + if self._round > 0: + # Set the proposal to the last proposal that was passed by the user. For + # atomic SNPE, it does not matter what the proposal is. For non-atomic + # SNPE, we only use the latest data that was passed, i.e. the one from the + # last proposal. 
+ proposal = self._proposal_roundwise[-1] + self.use_non_atomic_loss = ( + isinstance(proposal.posterior_estimator._distribution, mdn) + and isinstance(self._neural_net._distribution, mdn) + and check_dist_class( + self._prior, class_to_check=(Uniform, MultivariateNormal) + )[0] + ) + + algorithm = "non-atomic" if self.use_non_atomic_loss else "atomic" + print(f"Using SNPE-C with {algorithm} loss") + + if self.use_non_atomic_loss: + # Take care of z-scoring, pre-compute and store prior terms. + self._set_state_for_mog_proposal() + + return super().train(**kwargs) + + def _set_state_for_mog_proposal(self) -> None: + """Set state variables that are used at each training step of non-atomic SNPE-C. + + Three things are computed: + 1) Check if z-scoring was requested. To do so, we check if the `_transform` + argument of the net had been a `CompositeTransform`. See pyknos mdn.py. + 2) Define a (potentially standardized) prior. It's standardized if z-scoring + had been requested. + 3) Compute (Precision * mean) for the prior. This quantity is used at every + training step if the prior is Gaussian. + """ + + self.z_score_theta = isinstance(self._neural_net._transform, CompositeTransform) + + self._set_maybe_z_scored_prior() + + if isinstance(self._maybe_z_scored_prior, MultivariateNormal): + self.prec_m_prod_prior = torch.mv( + self._maybe_z_scored_prior.precision_matrix, # type: ignore + self._maybe_z_scored_prior.loc, # type: ignore + ) + + def _set_maybe_z_scored_prior(self) -> None: + r"""Compute and store potentially standardized prior (if z-scoring was done). + + The proposal posterior is: + $pp(\theta|x) = 1/Z * q(\theta|x) * prop(\theta) / p(\theta)$ + + Let's denote z-scored theta by `a`: a = (theta - mean) / std + Then pp'(a|x) = 1/Z_2 * q'(a|x) * prop'(a) / p'(a)$ + + The ' indicates that the evaluation occurs in standardized space. The constant + scaling factor has been absorbed into Z_2. 
+ From the above equation, we see that we need to evaluate the prior **in + standardized space**. We build the standardized prior in this function. + + The standardize transform that is applied to the samples theta does not use + the exact prior mean and std (due to implementation issues). Hence, the z-scored + prior will not be exactly have mean=0 and std=1. + """ + + if self.z_score_theta: + scale = self._neural_net._transform._transforms[0]._scale + shift = self._neural_net._transform._transforms[0]._shift + + # Following the definintion of the linear transform in + # `standardizing_transform` in `sbiutils.py`: + # shift=-mean / std + # scale=1 / std + # Solving these equations for mean and std: + estim_prior_std = 1 / scale + estim_prior_mean = -shift * estim_prior_std + + # Compute the discrepancy of the true prior mean and std and the mean and + # std that was empirically estimated from samples. + # N(theta|m,s) = N((theta-m_e)/s_e|(m-m_e)/s_e, s/s_e) + # Above: m,s are true prior mean and std. m_e,s_e are estimated prior mean + # and std (estimated from samples and used to build standardize transform). + almost_zero_mean = (self._prior.mean - estim_prior_mean) / estim_prior_std + almost_one_std = torch.sqrt(self._prior.variance) / estim_prior_std + + if isinstance(self._prior, MultivariateNormal): + self._maybe_z_scored_prior = MultivariateNormal( + almost_zero_mean, torch.diag(almost_one_std) + ) + else: + range_ = torch.sqrt(almost_one_std * 3.0) + self._maybe_z_scored_prior = utils.BoxUniform( + almost_zero_mean - range_, almost_zero_mean + range_ + ) + else: + self._maybe_z_scored_prior = self._prior + + def _log_prob_proposal_posterior( + self, + theta: Tensor, + x: Tensor, + masks: Tensor, + proposal: DirectPosterior, + ) -> Tensor: + """Return the log-probability of the proposal posterior. + + If the proposal is a MoG, the density estimator is a MoG, and the prior is + either Gaussian or uniform, we use non-atomic loss. 
Else, use atomic loss (which + suffers from leakage). + + Args: + theta: Batch of parameters θ. + x: Batch of data. + masks: Mask that is True for prior samples in the batch in order to train + them with prior loss. + proposal: Proposal distribution. + + Returns: Log-probability of the proposal posterior. + """ + + if self.use_non_atomic_loss: + return self._log_prob_proposal_posterior_mog(theta, x, proposal) + else: + return self._log_prob_proposal_posterior_atomic(theta, x, masks) + + def _log_prob_proposal_posterior_atomic( + self, theta: Tensor, x: Tensor, masks: Tensor + ): + """Return log probability of the proposal posterior for atomic proposals. + + We have two main options when evaluating the proposal posterior. + (1) Generate atoms from the proposal prior. + (2) Generate atoms from a more targeted distribution, such as the most + recent posterior. + If we choose the latter, it is likely beneficial not to do this in the first + round, since we would be sampling from a randomly-initialized neural density + estimator. + + Args: + theta: Batch of parameters θ. + x: Batch of data. + masks: Mask that is True for prior samples in the batch in order to train + them with prior loss. + + Returns: + Log-probability of the proposal posterior. + """ + + batch_size = theta.shape[0] + + num_atoms = int( + clamp_and_warn("num_atoms", self._num_atoms, min_val=2, max_val=batch_size) + ) + + # Each set of parameter atoms is evaluated using the same x, + # so we repeat rows of the data x, e.g. [1, 2] -> [1, 1, 2, 2] + repeated_x = repeat_rows(x, num_atoms) + + # To generate the full set of atoms for a given item in the batch, + # we sample without replacement num_atoms - 1 times from the rest + # of the theta in the batch. 
+ probs = ones(batch_size, batch_size) * (1 - eye(batch_size)) / (batch_size - 1) + + choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False) + contrasting_theta = theta[choices] + + # We can now create our sets of atoms from the contrasting parameter sets + # we have generated. + atomic_theta = torch.cat((theta[:, None, :], contrasting_theta), dim=1).reshape( + batch_size * num_atoms, -1 + ) + + # Evaluate large batch giving (batch_size * num_atoms) log prob posterior evals. + log_prob_posterior = self._neural_net.log_prob(atomic_theta, repeated_x) + utils.assert_all_finite(log_prob_posterior, "posterior eval") + log_prob_posterior = log_prob_posterior.reshape(batch_size, num_atoms) + + # Get (batch_size * num_atoms) log prob prior evals. + log_prob_prior = self._prior.log_prob(atomic_theta) + log_prob_prior = log_prob_prior.reshape(batch_size, num_atoms) + utils.assert_all_finite(log_prob_prior, "prior eval") + + # Compute unnormalized proposal posterior. + unnormalized_log_prob = log_prob_posterior - log_prob_prior + + # Normalize proposal posterior across discrete set of atoms. + log_prob_proposal_posterior = unnormalized_log_prob[:, 0] - torch.logsumexp( + unnormalized_log_prob, dim=-1 + ) + utils.assert_all_finite(log_prob_proposal_posterior, "proposal posterior eval") + + # XXX This evaluates the posterior on _all_ prior samples + if self._use_combined_loss: + log_prob_posterior_non_atomic = self._neural_net.log_prob(theta, x) + masks = masks.reshape(-1) + log_prob_proposal_posterior = ( + masks * log_prob_posterior_non_atomic + log_prob_proposal_posterior + ) + + return log_prob_proposal_posterior + + def _log_prob_proposal_posterior_mog( + self, theta: Tensor, x: Tensor, proposal: DirectPosterior + ) -> Tensor: + """Return log-probability of the proposal posterior for MoG proposal. + + For MoG proposals and MoG density estimators, this can be done in closed form + and does not require atomic loss (i.e. 
there will be no leakage issues). + + Notation: + + m are mean vectors. + prec are precision matrices. + cov are covariance matrices. + + _p at the end indicates that it is the proposal. + _d indicates that it is the density estimator. + _pp indicates the proposal posterior. + + All tensors will have shapes (batch_dim, num_components, ...) + + Args: + theta: Batch of parameters θ. + x: Batch of data. + proposal: Proposal distribution. + + Returns: + Log-probability of the proposal posterior. + """ + + # Evaluate the proposal. MDNs do not have functionality to run the embedding_net + # and then get the mixture_components (**without** calling log_prob()). Hence, + # we call them separately here. + encoded_x = proposal.posterior_estimator._embedding_net(proposal.default_x) + dist = ( + proposal.posterior_estimator._distribution + ) # defined to avoid ugly black formatting. + logits_p, m_p, prec_p, _, _ = dist.get_mixture_components(encoded_x) + norm_logits_p = logits_p - torch.logsumexp(logits_p, dim=-1, keepdim=True) + + # Evaluate the density estimator. + encoded_x = self._neural_net._embedding_net(x) + dist = self._neural_net._distribution # defined to avoid black formatting. + logits_d, m_d, prec_d, _, _ = dist.get_mixture_components(encoded_x) + norm_logits_d = logits_d - torch.logsumexp(logits_d, dim=-1, keepdim=True) + + # z-score theta if it z-scoring had been requested. + theta = self._maybe_z_score_theta(theta) + + # Compute the MoG parameters of the proposal posterior. + logits_pp, m_pp, prec_pp, cov_pp = self._automatic_posterior_transformation( + norm_logits_p, m_p, prec_p, norm_logits_d, m_d, prec_d + ) + + # Compute the log_prob of theta under the product. 
+ log_prob_proposal_posterior = utils.mog_log_prob( + theta, logits_pp, m_pp, prec_pp + ) + utils.assert_all_finite(log_prob_proposal_posterior, "proposal posterior eval") + + return log_prob_proposal_posterior + + def _automatic_posterior_transformation( + self, + logits_p: Tensor, + means_p: Tensor, + precisions_p: Tensor, + logits_d: Tensor, + means_d: Tensor, + precisions_d: Tensor, + ): + r"""Returns the MoG parameters of the proposal posterior. + + The proposal posterior is: + $pp(\theta|x) = 1/Z * q(\theta|x) * prop(\theta) / p(\theta)$ + In words: proposal posterior = posterior estimate * proposal / prior. + + If the posterior estimate and the proposal are MoG and the prior is either + Gaussian or uniform, we can solve this in closed-form. The is implemented in + this function. + + This function implements Appendix A1 from Greenberg et al. 2019. + + We have to build L*K components. How do we do this? + Example: proposal has two components, density estimator has three components. + Let's call the two components of the proposal i,j and the three components + of the density estimator x,y,z. We have to multiply every component of the + proposal with every component of the density estimator. So, what we do is: + 1) for the proposal, build: i,i,i,j,j,j. Done with torch.repeat_interleave() + 2) for the density estimator, build: x,y,z,x,y,z. Done with torch.repeat() + 3) Multiply them with simple matrix operations. + + Args: + logits_p: Component weight of each Gaussian of the proposal. + means_p: Mean of each Gaussian of the proposal. + precisions_p: Precision matrix of each Gaussian of the proposal. + logits_d: Component weight for each Gaussian of the density estimator. + means_d: Mean of each Gaussian of the density estimator. + precisions_d: Precision matrix of each Gaussian of the density estimator. + + Returns: (Component weight, mean, precision matrix, covariance matrix) of each + Gaussian of the proposal posterior. 
Has L*K terms (proposal has L terms, + density estimator has K terms). + """ + + precisions_pp, covariances_pp = self._precisions_proposal_posterior( + precisions_p, precisions_d + ) + + means_pp = self._means_proposal_posterior( + covariances_pp, means_p, precisions_p, means_d, precisions_d + ) + + logits_pp = self._logits_proposal_posterior( + means_pp, + precisions_pp, + covariances_pp, + logits_p, + means_p, + precisions_p, + logits_d, + means_d, + precisions_d, + ) + + return logits_pp, means_pp, precisions_pp, covariances_pp + + def _precisions_proposal_posterior( + self, precisions_p: Tensor, precisions_d: Tensor + ): + """Return the precisions and covariances of the proposal posterior. + + Args: + precisions_p: Precision matrices of the proposal distribution. + precisions_d: Precision matrices of the density estimator. + + Returns: (Precisions, Covariances) of the proposal posterior. L*K terms. + """ + + num_comps_p = precisions_p.shape[1] + num_comps_d = precisions_d.shape[1] + + precisions_p_rep = precisions_p.repeat_interleave(num_comps_d, dim=1) + precisions_d_rep = precisions_d.repeat(1, num_comps_p, 1, 1) + + precisions_pp = precisions_p_rep + precisions_d_rep + if isinstance(self._maybe_z_scored_prior, MultivariateNormal): + precisions_pp -= self._maybe_z_scored_prior.precision_matrix + + covariances_pp = torch.inverse(precisions_pp) + + return precisions_pp, covariances_pp + + def _means_proposal_posterior( + self, + covariances_pp: Tensor, + means_p: Tensor, + precisions_p: Tensor, + means_d: Tensor, + precisions_d: Tensor, + ): + """Return the means of the proposal posterior. + + means_pp = C_ix * (P_i * m_i + P_x * m_x - P_o * m_o). + + Args: + covariances_pp: Covariance matrices of the proposal posterior. + means_p: Means of the proposal distribution. + precisions_p: Precision matrices of the proposal distribution. + means_d: Means of the density estimator. + precisions_d: Precision matrices of the density estimator. 
+ + Returns: Means of the proposal posterior. L*K terms. + """ + + num_comps_p = precisions_p.shape[1] + num_comps_d = precisions_d.shape[1] + + # First, compute the product P_i * m_i and P_j * m_j + prec_m_prod_p = batched_mixture_mv(precisions_p, means_p) + prec_m_prod_d = batched_mixture_mv(precisions_d, means_d) + + # Repeat them to allow for matrix operations: same trick as for the precisions. + prec_m_prod_p_rep = prec_m_prod_p.repeat_interleave(num_comps_d, dim=1) + prec_m_prod_d_rep = prec_m_prod_d.repeat(1, num_comps_p, 1) + + # Means = C_ij * (P_i * m_i + P_x * m_x - P_o * m_o). + summed_cov_m_prod_rep = prec_m_prod_p_rep + prec_m_prod_d_rep + if isinstance(self._maybe_z_scored_prior, MultivariateNormal): + summed_cov_m_prod_rep -= self.prec_m_prod_prior + + means_pp = batched_mixture_mv(covariances_pp, summed_cov_m_prod_rep) + + return means_pp + + @staticmethod + def _logits_proposal_posterior( + means_pp: Tensor, + precisions_pp: Tensor, + covariances_pp: Tensor, + logits_p: Tensor, + means_p: Tensor, + precisions_p: Tensor, + logits_d: Tensor, + means_d: Tensor, + precisions_d: Tensor, + ): + """Return the component weights (i.e. logits) of the proposal posterior. + + Args: + means_pp: Means of the proposal posterior. + precisions_pp: Precision matrices of the proposal posterior. + covariances_pp: Covariance matrices of the proposal posterior. + logits_p: Component weights (i.e. logits) of the proposal distribution. + means_p: Means of the proposal distribution. + precisions_p: Precision matrices of the proposal distribution. + logits_d: Component weights (i.e. logits) of the density estimator. + means_d: Means of the density estimator. + precisions_d: Precision matrices of the density estimator. + + Returns: Component weights of the proposal posterior. L*K terms. 
+ """ + + num_comps_p = precisions_p.shape[1] + num_comps_d = precisions_d.shape[1] + + # Compute log(alpha_i * beta_j) + logits_p_rep = logits_p.repeat_interleave(num_comps_d, dim=1) + logits_d_rep = logits_d.repeat(1, num_comps_p) + logit_factors = logits_p_rep + logits_d_rep + + # Compute sqrt(det()/(det()*det())) + logdet_covariances_pp = torch.logdet(covariances_pp) + logdet_covariances_p = -torch.logdet(precisions_p) + logdet_covariances_d = -torch.logdet(precisions_d) + + # Repeat the proposal and density estimator terms such that there are LK terms. + # Same trick as has been used above. + logdet_covariances_p_rep = logdet_covariances_p.repeat_interleave( + num_comps_d, dim=1 + ) + logdet_covariances_d_rep = logdet_covariances_d.repeat(1, num_comps_p) + + log_sqrt_det_ratio = 0.5 * ( + logdet_covariances_pp + - (logdet_covariances_p_rep + logdet_covariances_d_rep) + ) + + # Compute for proposal, density estimator, and proposal posterior: + # mu_i.T * P_i * mu_i + exponent_p = batched_mixture_vmv(precisions_p, means_p) + exponent_d = batched_mixture_vmv(precisions_d, means_d) + exponent_pp = batched_mixture_vmv(precisions_pp, means_pp) + + # Extend proposal and density estimator exponents to get LK terms. 
+ exponent_p_rep = exponent_p.repeat_interleave(num_comps_d, dim=1) + exponent_d_rep = exponent_d.repeat(1, num_comps_p) + exponent = -0.5 * (exponent_p_rep + exponent_d_rep - exponent_pp) + + logits_pp = logit_factors + log_sqrt_det_ratio + exponent + + return logits_pp + + def _maybe_z_score_theta(self, theta: Tensor) -> Tensor: + """Return potentially standardized theta if z-scoring was requested.""" + + if self.z_score_theta: + theta, _ = self._neural_net._transform(theta) + + return theta diff --git a/tests/linearGaussian_snpe_test.py b/tests/linearGaussian_snpe_test.py index b43bd0144..2d84aa759 100644 --- a/tests/linearGaussian_snpe_test.py +++ b/tests/linearGaussian_snpe_test.py @@ -1,583 +1,583 @@ -# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed -# under the Affero General Public License v3, see . - -from __future__ import annotations - -import numpy as np -import pytest -import torch -from scipy.stats import gaussian_kde -from torch import eye, ones, zeros -from torch.distributions import MultivariateNormal - -from sbi import analysis as analysis -from sbi import utils as utils -from sbi.analysis import ConditionedMDN, conditonal_potential -from sbi.inference import ( - SNPE_A, - SNPE_B, - SNPE_C, - DirectPosterior, - MCMCPosterior, - RejectionPosterior, - posterior_estimator_based_potential, - prepare_for_sbi, - simulate_for_sbi, -) -from sbi.simulators.linear_gaussian import ( - linear_gaussian, - samples_true_posterior_linear_gaussian_mvn_prior_different_dims, - samples_true_posterior_linear_gaussian_uniform_prior, - true_posterior_linear_gaussian_mvn_prior, -) -from tests.sbiutils_test import conditional_of_mvn -from tests.test_utils import ( - check_c2st, - get_dkl_gaussian_prior, - get_normalization_uniform_prior, - get_prob_outside_uniform_prior, -) - - -@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) -@pytest.mark.parametrize( - "num_dim, prior_str, num_trials", - ( - (2, "gaussian", 1), - (2, 
"uniform", 1), - (1, "gaussian", 1), - # no iid x in snpe. - pytest.param(1, "gaussian", 2, marks=pytest.mark.xfail), - pytest.param(2, "gaussian", 2, marks=pytest.mark.xfail), - ), -) -def test_c2st_snpe_on_linearGaussian( - snpe_method, num_dim: int, prior_str: str, num_trials: int -): - """Test whether SNPE infers well a simple example with available ground truth.""" - - x_o = zeros(num_trials, num_dim) - num_samples = 1000 - num_simulations = 2600 - - # likelihood_mean will be likelihood_shift+theta - likelihood_shift = -1.0 * ones(num_dim) - likelihood_cov = 0.3 * eye(num_dim) - - if prior_str == "gaussian": - prior_mean = zeros(num_dim) - prior_cov = eye(num_dim) - prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) - gt_posterior = true_posterior_linear_gaussian_mvn_prior( - x_o, likelihood_shift, likelihood_cov, prior_mean, prior_cov - ) - target_samples = gt_posterior.sample((num_samples,)) - else: - prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim)) - target_samples = samples_true_posterior_linear_gaussian_uniform_prior( - x_o, likelihood_shift, likelihood_cov, prior=prior, num_samples=num_samples - ) - - simulator, prior = prepare_for_sbi( - lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior - ) - - inference = snpe_method(prior, show_progress_bars=False) - - theta, x = simulate_for_sbi( - simulator, prior, num_simulations, simulation_batch_size=1000 - ) - posterior_estimator = inference.append_simulations(theta, x).train( - training_batch_size=100 - ) - posterior = DirectPosterior( - prior=prior, posterior_estimator=posterior_estimator - ).set_default_x(x_o) - samples = posterior.sample((num_samples,)) - - # Compute the c2st and assert it is near chance level of 0.5. 
- check_c2st(samples, target_samples, alg="snpe_c") - - map_ = posterior.map(num_init_samples=1_000, show_progress_bars=False) - - # Checks for log_prob() - if prior_str == "gaussian": - # For the Gaussian prior, we compute the KLd between ground truth and posterior. - dkl = get_dkl_gaussian_prior( - posterior, x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov - ) - - max_dkl = 0.15 - - assert ( - dkl < max_dkl - ), f"D-KL={dkl} is more than 2 stds above the average performance." - - assert ((map_ - gt_posterior.mean) ** 2).sum() < 0.5 - - elif prior_str == "uniform": - # Check whether the returned probability outside of the support is zero. - posterior_prob = get_prob_outside_uniform_prior(posterior, prior, num_dim) - assert ( - posterior_prob == 0.0 - ), "The posterior probability outside of the prior support is not zero" - - # Check whether normalization (i.e. scaling up the density due - # to leakage into regions without prior support) scales up the density by the - # correct factor. - ( - posterior_likelihood_unnorm, - posterior_likelihood_norm, - acceptance_prob, - ) = get_normalization_uniform_prior(posterior, prior, x=x_o) - # The acceptance probability should be *exactly* the ratio of the unnormalized - # and the normalized likelihood. However, we allow for an error margin of 1%, - # since the estimation of the acceptance probability is random (based on - # rejection sampling). - assert ( - acceptance_prob * 0.99 - < posterior_likelihood_unnorm / posterior_likelihood_norm - < acceptance_prob * 1.01 - ), "Normalizing the posterior density using the acceptance probability failed." - - assert ((map_ - ones(num_dim)) ** 2).sum() < 0.5 - - -def test_c2st_snpe_on_linearGaussian_different_dims(): - """Test whether SNPE B/C infer well a simple example with available ground truth. - - This example has different number of parameters theta than number of x. Also - this implicitly tests simulation_batch_size=1. 
It also impleictly tests whether the - prior can be `None` and whether we can stop and resume training. - - """ - - theta_dim = 3 - x_dim = 2 - discard_dims = theta_dim - x_dim - - x_o = zeros(1, x_dim) - num_samples = 1000 - - # likelihood_mean will be likelihood_shift+theta - likelihood_shift = -1.0 * ones(x_dim) - likelihood_cov = 0.3 * eye(x_dim) - - prior_mean = zeros(theta_dim) - prior_cov = eye(theta_dim) - prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) - target_samples = samples_true_posterior_linear_gaussian_mvn_prior_different_dims( - x_o, - likelihood_shift, - likelihood_cov, - prior_mean, - prior_cov, - num_discarded_dims=discard_dims, - num_samples=num_samples, - ) - - simulator, prior = prepare_for_sbi( - lambda theta: linear_gaussian( - theta, likelihood_shift, likelihood_cov, num_discarded_dims=discard_dims - ), - prior, - ) - # Test whether prior can be `None`. - inference = SNPE_C(prior=None, density_estimator="maf", show_progress_bars=False) - - # type: ignore - theta, x = simulate_for_sbi(simulator, prior, 2000, simulation_batch_size=1) - - inference = inference.append_simulations(theta, x) - posterior_estimator = inference.train( - max_num_epochs=10 - ) # Test whether we can stop and resume. - posterior_estimator = inference.train( - resume_training=True, force_first_round_loss=True - ) - posterior = DirectPosterior( - prior=prior, posterior_estimator=posterior_estimator - ).set_default_x(x_o) - samples = posterior.sample((num_samples,)) - - # Compute the c2st and assert it is near chance level of 0.5. - check_c2st(samples, target_samples, alg="snpe_c") - - -# Test multi-round SNPE. 
-@pytest.mark.slow -@pytest.mark.parametrize( - "method_str", - ( - "snpe_a", - pytest.param( - "snpe_b", - marks=pytest.mark.xfail( - raises=NotImplementedError, reason="""SNPE-B not implemented""" - ), - ), - "snpe_c", - "snpe_c_non_atomic", - ), -) -def test_c2st_multi_round_snpe_on_linearGaussian(method_str: str): - """Test whether SNPE B/C infer well a simple example with available ground truth. - . - """ - - num_dim = 2 - x_o = zeros((1, num_dim)) - num_samples = 1000 - - # likelihood_mean will be likelihood_shift+theta - likelihood_shift = -1.0 * ones(num_dim) - likelihood_cov = 0.3 * eye(num_dim) - - prior_mean = zeros(num_dim) - prior_cov = eye(num_dim) - prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) - - gt_posterior = true_posterior_linear_gaussian_mvn_prior( - x_o, likelihood_shift, likelihood_cov, prior_mean, prior_cov - ) - target_samples = gt_posterior.sample((num_samples,)) - - if method_str == "snpe_c_non_atomic": - # Test whether SNPE works properly with structured z-scoring. 
- density_estimator = utils.posterior_nn( - "mdn", z_score_x="structured", num_components=5 - ) - method_str = "snpe_c" - elif method_str == "snpe_a": - density_estimator = "mdn_snpe_a" - else: - density_estimator = "maf" - - simulator, prior = prepare_for_sbi( - lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior - ) - creation_args = dict( - prior=prior, - density_estimator=density_estimator, - show_progress_bars=False, - ) - - if method_str == "snpe_b": - inference = SNPE_B(**creation_args) - theta, x = simulate_for_sbi(simulator, prior, 500, simulation_batch_size=10) - posterior_estimator = inference.append_simulations(theta, x).train() - posterior1 = DirectPosterior( - prior=prior, posterior_estimator=posterior_estimator - ).set_default_x(x_o) - theta, x = simulate_for_sbi( - simulator, posterior1, 1000, simulation_batch_size=10 - ) - posterior_estimator = inference.append_simulations( - theta, x, proposal=posterior1 - ).train() - posterior = DirectPosterior( - prior=prior, posterior_estimator=posterior_estimator - ).set_default_x(x_o) - elif method_str == "snpe_c": - inference = SNPE_C(**creation_args) - theta, x = simulate_for_sbi(simulator, prior, 900, simulation_batch_size=50) - posterior_estimator = inference.append_simulations(theta, x).train() - posterior1 = DirectPosterior( - prior=prior, posterior_estimator=posterior_estimator - ).set_default_x(x_o) - theta = posterior1.sample((1000,)) - x = simulator(theta) - _ = inference.append_simulations(theta, x, proposal=posterior1).train() - posterior = inference.build_posterior().set_default_x(x_o) - elif method_str == "snpe_a": - inference = SNPE_A(**creation_args) - proposal = prior - final_round = False - num_rounds = 3 - for r in range(num_rounds): - if r == 2: - final_round = True - theta, x = simulate_for_sbi( - simulator, proposal, 500, simulation_batch_size=50 - ) - inference = inference.append_simulations(theta, x, proposal=proposal) - _ = inference.train(max_num_epochs=200, 
final_round=final_round) - posterior = inference.build_posterior().set_default_x(x_o) - proposal = posterior - - samples = posterior.sample((num_samples,)) - - # Compute the c2st and assert it is near chance level of 0.5. - check_c2st(samples, target_samples, alg=method_str) - - -# Testing rejection and mcmc sampling methods. -@pytest.mark.slow -@pytest.mark.parametrize( - "sample_with, mcmc_method, prior_str", - ( - ("mcmc", "slice_np", "gaussian"), - ("mcmc", "slice", "gaussian"), - # XXX (True, "slice", "uniform"), - # XXX takes very long. fix when refactoring pyro sampling - ("rejection", "rejection", "uniform"), - ), -) -def test_api_snpe_c_posterior_correction(sample_with, mcmc_method, prior_str): - """Test that leakage correction applied to sampling works, with both MCMC and - rejection. - - """ - - num_dim = 2 - x_o = zeros(1, num_dim) - - # likelihood_mean will be likelihood_shift+theta - likelihood_shift = -1.0 * ones(num_dim) - likelihood_cov = 0.3 * eye(num_dim) - - if prior_str == "gaussian": - prior_mean = zeros(num_dim) - prior_cov = eye(num_dim) - prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) - else: - prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim)) - - simulator, prior = prepare_for_sbi( - lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior - ) - inference = SNPE_C(prior, show_progress_bars=False) - - theta, x = simulate_for_sbi(simulator, prior, 1000) - posterior_estimator = inference.append_simulations(theta, x).train() - potential_fn, theta_transform = posterior_estimator_based_potential( - posterior_estimator, prior, x_o - ) - if sample_with == "mcmc": - posterior = MCMCPosterior( - potential_fn=potential_fn, - theta_transform=theta_transform, - proposal=prior, - method=mcmc_method, - ) - elif sample_with == "rejection": - posterior = RejectionPosterior( - potential_fn=potential_fn, - proposal=prior, - theta_transform=theta_transform, - ) - - # Posterior should be corrected 
for leakage even if num_rounds just 1. - samples = posterior.sample((10,)) - - # Evaluate the samples to check correction factor. - _ = posterior.log_prob(samples) - - -@pytest.mark.slow -def test_sample_conditional(): - """ - Test whether sampling from the conditional gives the same results as evaluating. - - This compares samples that get smoothed with a Gaussian kde to evaluating the - conditional log-probability with `eval_conditional_density`. - - `eval_conditional_density` is itself tested in `sbiutils_test.py`. Here, we use - a bimodal posterior to test the conditional. - """ - - num_dim = 3 - dim_to_sample_1 = 0 - dim_to_sample_2 = 2 - - x_o = zeros(1, num_dim) - - likelihood_shift = -1.0 * ones(num_dim) - likelihood_cov = 0.1 * eye(num_dim) - - prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim)) - - def simulator(theta): - if torch.rand(1) > 0.5: - return linear_gaussian(theta, likelihood_shift, likelihood_cov) - else: - return linear_gaussian(theta, -likelihood_shift, likelihood_cov) - - # Test whether SNPE works properly with structured z-scoring. - net = utils.posterior_nn("maf", z_score_x="structured", hidden_features=20) - - simulator, prior = prepare_for_sbi(simulator, prior) - - inference = SNPE_C(prior, density_estimator=net, show_progress_bars=False) - - # We need a pretty big dataset to properly model the bimodality. - theta, x = simulate_for_sbi(simulator, prior, 10000) - posterior_estimator = inference.append_simulations(theta, x).train( - max_num_epochs=60 - ) - - posterior = DirectPosterior( - prior=prior, posterior_estimator=posterior_estimator - ).set_default_x(x_o) - samples = posterior.sample((50,)) - - # Evaluate the conditional density be drawing samples and smoothing with a Gaussian - # kde. 
- potential_fn, theta_transform = posterior_estimator_based_potential( - posterior_estimator, prior=prior, x_o=x_o - ) - (conditioned_potential_fn, restricted_tf, restricted_prior,) = conditonal_potential( - potential_fn=potential_fn, - theta_transform=theta_transform, - prior=prior, - condition=samples[0], - dims_to_sample=[dim_to_sample_1, dim_to_sample_2], - ) - mcmc_posterior = MCMCPosterior( - potential_fn=conditioned_potential_fn, - theta_transform=restricted_tf, - proposal=restricted_prior, - ) - cond_samples = mcmc_posterior.sample((500,)) - - _ = analysis.pairplot( - cond_samples, - limits=[[-2, 2], [-2, 2], [-2, 2]], - figsize=(2, 2), - diag="kde", - upper="kde", - ) - - limits = [[-2, 2], [-2, 2], [-2, 2]] - - density = gaussian_kde(cond_samples.numpy().T, bw_method="scott") - - X, Y = np.meshgrid( - np.linspace(limits[0][0], limits[0][1], 50), - np.linspace(limits[1][0], limits[1][1], 50), - ) - positions = np.vstack([X.ravel(), Y.ravel()]) - sample_kde_grid = np.reshape(density(positions).T, X.shape) - - # Evaluate the conditional with eval_conditional_density. - eval_grid = analysis.eval_conditional_density( - posterior, - condition=samples[0], - dim1=dim_to_sample_1, - dim2=dim_to_sample_2, - limits=torch.tensor([[-2, 2], [-2, 2], [-2, 2]]), - ) - - # Compare the two densities. - sample_kde_grid = sample_kde_grid / np.sum(sample_kde_grid) - eval_grid = eval_grid / torch.sum(eval_grid) - - error = np.abs(sample_kde_grid - eval_grid.numpy()) - - max_err = np.max(error) - assert max_err < 0.0027 - - -def test_mdn_conditional_density(num_dim: int = 3, cond_dim: int = 1): - """Test whether the conditional density infered from MDN parameters of a - `DirectPosterior` matches analytical results for MVN. This uses a n-D joint and - conditions on the last m values to generate a conditional. - - Gaussian prior used for easier ground truthing of conditional posterior. - - Args: - num_dim: Dimensionality of the MVM. - cond_dim: Dimensionality of the condition. 
- """ - - assert ( - num_dim > cond_dim - ), "The number of dimensions needs to be greater than that of the condition!" - - x_o = zeros(1, num_dim) - num_samples = 1000 - num_simulations = 2700 - condition = 0.1 * ones(1, num_dim) - - dims = list(range(num_dim)) - dims2sample = dims[-cond_dim:] - dims2condition = dims[:-cond_dim] - - # likelihood_mean will be likelihood_shift+theta - likelihood_shift = -1.0 * ones(num_dim) - likelihood_cov = 0.3 * eye(num_dim) - - prior_mean = zeros(num_dim) - prior_cov = eye(num_dim) - prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) - - joint_posterior = true_posterior_linear_gaussian_mvn_prior( - x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov - ) - joint_cov = joint_posterior.covariance_matrix - joint_mean = joint_posterior.loc - - conditional_mean, conditional_cov = conditional_of_mvn( - joint_mean, joint_cov, condition[0, dims2condition] - ) - conditional_dist_gt = MultivariateNormal(conditional_mean, conditional_cov) - - conditional_samples_gt = conditional_dist_gt.sample((num_samples,)) - - def simulator(theta): - return linear_gaussian(theta, likelihood_shift, likelihood_cov) - - simulator, prior = prepare_for_sbi(simulator, prior) - inference = SNPE_C(density_estimator="mdn", show_progress_bars=False) - - theta, x = simulate_for_sbi( - simulator, prior, num_simulations, simulation_batch_size=1000 - ) - posterior_mdn = inference.append_simulations(theta, x).train( - training_batch_size=100 - ) - conditioned_mdn = ConditionedMDN( - posterior_mdn, x_o, condition=condition, dims_to_sample=[0] - ) - conditional_samples_sbi = conditioned_mdn.sample((num_samples,)) - check_c2st( - conditional_samples_sbi, - conditional_samples_gt, - alg="analytic_mdn_conditioning_of_direct_posterior", - ) - - -@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) -def test_example_posterior(snpe_method: type): - """Return an inferred `NeuralPosterior` for interactive examination.""" - num_dim = 2 - x_o 
= zeros(1, num_dim) - - # likelihood_mean will be likelihood_shift+theta - likelihood_shift = -1.0 * ones(num_dim) - likelihood_cov = 0.3 * eye(num_dim) - - prior_mean = zeros(num_dim) - prior_cov = eye(num_dim) - prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) - - if snpe_method == SNPE_A: - extra_kwargs = dict(final_round=True) - else: - extra_kwargs = dict() - - simulator, prior = prepare_for_sbi( - lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior - ) - inference = snpe_method(prior, show_progress_bars=False) - - theta, x = simulate_for_sbi( - simulator, prior, 1000, simulation_batch_size=10, num_workers=6 - ) - posterior_estimator = inference.append_simulations(theta, x).train(**extra_kwargs) - if snpe_method == SNPE_A: - posterior_estimator = inference.correct_for_proposal() - posterior = DirectPosterior( - prior=prior, posterior_estimator=posterior_estimator - ).set_default_x(x_o) - assert posterior is not None +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Affero General Public License v3, see . 
+ +from __future__ import annotations + +import numpy as np +import pytest +import torch +from scipy.stats import gaussian_kde +from torch import eye, ones, zeros +from torch.distributions import MultivariateNormal + +from sbi import analysis as analysis +from sbi import utils as utils +from sbi.analysis import ConditionedMDN, conditonal_potential +from sbi.inference import ( + SNPE_A, + SNPE_B, + SNPE_C, + DirectPosterior, + MCMCPosterior, + RejectionPosterior, + posterior_estimator_based_potential, + prepare_for_sbi, + simulate_for_sbi, +) +from sbi.simulators.linear_gaussian import ( + linear_gaussian, + samples_true_posterior_linear_gaussian_mvn_prior_different_dims, + samples_true_posterior_linear_gaussian_uniform_prior, + true_posterior_linear_gaussian_mvn_prior, +) +from tests.sbiutils_test import conditional_of_mvn +from tests.test_utils import ( + check_c2st, + get_dkl_gaussian_prior, + get_normalization_uniform_prior, + get_prob_outside_uniform_prior, +) + + +@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) +@pytest.mark.parametrize( + "num_dim, prior_str, num_trials", + ( + (2, "gaussian", 1), + (2, "uniform", 1), + (1, "gaussian", 1), + # no iid x in snpe. 
+ pytest.param(1, "gaussian", 2, marks=pytest.mark.xfail), + pytest.param(2, "gaussian", 2, marks=pytest.mark.xfail), + ), +) +def test_c2st_snpe_on_linearGaussian( + snpe_method, num_dim: int, prior_str: str, num_trials: int +): + """Test whether SNPE infers well a simple example with available ground truth.""" + + x_o = zeros(num_trials, num_dim) + num_samples = 1000 + num_simulations = 2600 + + # likelihood_mean will be likelihood_shift+theta + likelihood_shift = -1.0 * ones(num_dim) + likelihood_cov = 0.3 * eye(num_dim) + + if prior_str == "gaussian": + prior_mean = zeros(num_dim) + prior_cov = eye(num_dim) + prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) + gt_posterior = true_posterior_linear_gaussian_mvn_prior( + x_o, likelihood_shift, likelihood_cov, prior_mean, prior_cov + ) + target_samples = gt_posterior.sample((num_samples,)) + else: + prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim)) + target_samples = samples_true_posterior_linear_gaussian_uniform_prior( + x_o, likelihood_shift, likelihood_cov, prior=prior, num_samples=num_samples + ) + + simulator, prior = prepare_for_sbi( + lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior + ) + + inference = snpe_method(prior, show_progress_bars=False) + + theta, x = simulate_for_sbi( + simulator, prior, num_simulations, simulation_batch_size=1000 + ) + posterior_estimator = inference.append_simulations(theta, x).train( + training_batch_size=100 + ) + posterior = DirectPosterior( + prior=prior, posterior_estimator=posterior_estimator + ).set_default_x(x_o) + samples = posterior.sample((num_samples,)) + + # Compute the c2st and assert it is near chance level of 0.5. + check_c2st(samples, target_samples, alg="snpe_c") + + map_ = posterior.map(num_init_samples=1_000, show_progress_bars=False) + + # Checks for log_prob() + if prior_str == "gaussian": + # For the Gaussian prior, we compute the KLd between ground truth and posterior. 
+ dkl = get_dkl_gaussian_prior( + posterior, x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov + ) + + max_dkl = 0.15 + + assert ( + dkl < max_dkl + ), f"D-KL={dkl} is more than 2 stds above the average performance." + + assert ((map_ - gt_posterior.mean) ** 2).sum() < 0.5 + + elif prior_str == "uniform": + # Check whether the returned probability outside of the support is zero. + posterior_prob = get_prob_outside_uniform_prior(posterior, prior, num_dim) + assert ( + posterior_prob == 0.0 + ), "The posterior probability outside of the prior support is not zero" + + # Check whether normalization (i.e. scaling up the density due + # to leakage into regions without prior support) scales up the density by the + # correct factor. + ( + posterior_likelihood_unnorm, + posterior_likelihood_norm, + acceptance_prob, + ) = get_normalization_uniform_prior(posterior, prior, x=x_o) + # The acceptance probability should be *exactly* the ratio of the unnormalized + # and the normalized likelihood. However, we allow for an error margin of 1%, + # since the estimation of the acceptance probability is random (based on + # rejection sampling). + assert ( + acceptance_prob * 0.99 + < posterior_likelihood_unnorm / posterior_likelihood_norm + < acceptance_prob * 1.01 + ), "Normalizing the posterior density using the acceptance probability failed." + + assert ((map_ - ones(num_dim)) ** 2).sum() < 0.5 + + +def test_c2st_snpe_on_linearGaussian_different_dims(): + """Test whether SNPE B/C infer well a simple example with available ground truth. + + This example has different number of parameters theta than number of x. Also + this implicitly tests simulation_batch_size=1. It also impleictly tests whether the + prior can be `None` and whether we can stop and resume training. 
+ + """ + + theta_dim = 3 + x_dim = 2 + discard_dims = theta_dim - x_dim + + x_o = zeros(1, x_dim) + num_samples = 1000 + + # likelihood_mean will be likelihood_shift+theta + likelihood_shift = -1.0 * ones(x_dim) + likelihood_cov = 0.3 * eye(x_dim) + + prior_mean = zeros(theta_dim) + prior_cov = eye(theta_dim) + prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) + target_samples = samples_true_posterior_linear_gaussian_mvn_prior_different_dims( + x_o, + likelihood_shift, + likelihood_cov, + prior_mean, + prior_cov, + num_discarded_dims=discard_dims, + num_samples=num_samples, + ) + + simulator, prior = prepare_for_sbi( + lambda theta: linear_gaussian( + theta, likelihood_shift, likelihood_cov, num_discarded_dims=discard_dims + ), + prior, + ) + # Test whether prior can be `None`. + inference = SNPE_C(prior=None, density_estimator="maf", show_progress_bars=False) + + # type: ignore + theta, x = simulate_for_sbi(simulator, prior, 2000, simulation_batch_size=1) + + inference = inference.append_simulations(theta, x) + posterior_estimator = inference.train( + max_num_epochs=10 + ) # Test whether we can stop and resume. + posterior_estimator = inference.train( + resume_training=True, force_first_round_loss=True + ) + posterior = DirectPosterior( + prior=prior, posterior_estimator=posterior_estimator + ).set_default_x(x_o) + samples = posterior.sample((num_samples,)) + + # Compute the c2st and assert it is near chance level of 0.5. + check_c2st(samples, target_samples, alg="snpe_c") + + +# Test multi-round SNPE. +@pytest.mark.slow +@pytest.mark.parametrize( + "method_str", + ( + "snpe_a", + pytest.param( + "snpe_b", + marks=pytest.mark.xfail( + raises=NotImplementedError, reason="""SNPE-B not implemented""" + ), + ), + "snpe_c", + "snpe_c_non_atomic", + ), +) +def test_c2st_multi_round_snpe_on_linearGaussian(method_str: str): + """Test whether SNPE B/C infer well a simple example with available ground truth. + . 
+ """ + + num_dim = 2 + x_o = zeros((1, num_dim)) + num_samples = 1000 + + # likelihood_mean will be likelihood_shift+theta + likelihood_shift = -1.0 * ones(num_dim) + likelihood_cov = 0.3 * eye(num_dim) + + prior_mean = zeros(num_dim) + prior_cov = eye(num_dim) + prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) + + gt_posterior = true_posterior_linear_gaussian_mvn_prior( + x_o, likelihood_shift, likelihood_cov, prior_mean, prior_cov + ) + target_samples = gt_posterior.sample((num_samples,)) + + if method_str == "snpe_c_non_atomic": + # Test whether SNPE works properly with structured z-scoring. + density_estimator = utils.posterior_nn( + "mdn", z_score_x="structured", num_components=5 + ) + method_str = "snpe_c" + elif method_str == "snpe_a": + density_estimator = "mdn_snpe_a" + else: + density_estimator = "maf" + + simulator, prior = prepare_for_sbi( + lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior + ) + creation_args = dict( + prior=prior, + density_estimator=density_estimator, + show_progress_bars=False, + ) + + if method_str == "snpe_b": + inference = SNPE_B(**creation_args) + theta, x = simulate_for_sbi(simulator, prior, 500, simulation_batch_size=10) + posterior_estimator = inference.append_simulations(theta, x).train() + posterior1 = DirectPosterior( + prior=prior, posterior_estimator=posterior_estimator + ).set_default_x(x_o) + theta, x = simulate_for_sbi( + simulator, posterior1, 1000, simulation_batch_size=10 + ) + posterior_estimator = inference.append_simulations( + theta, x, proposal=posterior1 + ).train() + posterior = DirectPosterior( + prior=prior, posterior_estimator=posterior_estimator + ).set_default_x(x_o) + elif method_str == "snpe_c": + inference = SNPE_C(**creation_args) + theta, x = simulate_for_sbi(simulator, prior, 900, simulation_batch_size=50) + posterior_estimator = inference.append_simulations(theta, x).train() + posterior1 = DirectPosterior( + prior=prior, 
posterior_estimator=posterior_estimator + ).set_default_x(x_o) + theta = posterior1.sample((1000,)) + x = simulator(theta) + _ = inference.append_simulations(theta, x, proposal=posterior1).train() + posterior = inference.build_posterior().set_default_x(x_o) + elif method_str == "snpe_a": + inference = SNPE_A(**creation_args) + proposal = prior + final_round = False + num_rounds = 3 + for r in range(num_rounds): + if r == 2: + final_round = True + theta, x = simulate_for_sbi( + simulator, proposal, 500, simulation_batch_size=50 + ) + inference = inference.append_simulations(theta, x, proposal=proposal) + _ = inference.train(max_num_epochs=200, final_round=final_round) + posterior = inference.build_posterior().set_default_x(x_o) + proposal = posterior + + samples = posterior.sample((num_samples,)) + + # Compute the c2st and assert it is near chance level of 0.5. + check_c2st(samples, target_samples, alg=method_str) + + +# Testing rejection and mcmc sampling methods. +@pytest.mark.slow +@pytest.mark.parametrize( + "sample_with, mcmc_method, prior_str", + ( + ("mcmc", "slice_np", "gaussian"), + ("mcmc", "slice", "gaussian"), + # XXX (True, "slice", "uniform"), + # XXX takes very long. fix when refactoring pyro sampling + ("rejection", "rejection", "uniform"), + ), +) +def test_api_snpe_c_posterior_correction(sample_with, mcmc_method, prior_str): + """Test that leakage correction applied to sampling works, with both MCMC and + rejection. 
+ + """ + + num_dim = 2 + x_o = zeros(1, num_dim) + + # likelihood_mean will be likelihood_shift+theta + likelihood_shift = -1.0 * ones(num_dim) + likelihood_cov = 0.3 * eye(num_dim) + + if prior_str == "gaussian": + prior_mean = zeros(num_dim) + prior_cov = eye(num_dim) + prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) + else: + prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim)) + + simulator, prior = prepare_for_sbi( + lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior + ) + inference = SNPE_C(prior, show_progress_bars=False) + + theta, x = simulate_for_sbi(simulator, prior, 1000) + posterior_estimator = inference.append_simulations(theta, x).train() + potential_fn, theta_transform = posterior_estimator_based_potential( + posterior_estimator, prior, x_o + ) + if sample_with == "mcmc": + posterior = MCMCPosterior( + potential_fn=potential_fn, + theta_transform=theta_transform, + proposal=prior, + method=mcmc_method, + ) + elif sample_with == "rejection": + posterior = RejectionPosterior( + potential_fn=potential_fn, + proposal=prior, + theta_transform=theta_transform, + ) + + # Posterior should be corrected for leakage even if num_rounds just 1. + samples = posterior.sample((10,)) + + # Evaluate the samples to check correction factor. + _ = posterior.log_prob(samples) + + +@pytest.mark.slow +def test_sample_conditional(): + """ + Test whether sampling from the conditional gives the same results as evaluating. + + This compares samples that get smoothed with a Gaussian kde to evaluating the + conditional log-probability with `eval_conditional_density`. + + `eval_conditional_density` is itself tested in `sbiutils_test.py`. Here, we use + a bimodal posterior to test the conditional. 
+ """ + + num_dim = 3 + dim_to_sample_1 = 0 + dim_to_sample_2 = 2 + + x_o = zeros(1, num_dim) + + likelihood_shift = -1.0 * ones(num_dim) + likelihood_cov = 0.1 * eye(num_dim) + + prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim)) + + def simulator(theta): + if torch.rand(1) > 0.5: + return linear_gaussian(theta, likelihood_shift, likelihood_cov) + else: + return linear_gaussian(theta, -likelihood_shift, likelihood_cov) + + # Test whether SNPE works properly with structured z-scoring. + net = utils.posterior_nn("maf", z_score_x="structured", hidden_features=20) + + simulator, prior = prepare_for_sbi(simulator, prior) + + inference = SNPE_C(prior, density_estimator=net, show_progress_bars=False) + + # We need a pretty big dataset to properly model the bimodality. + theta, x = simulate_for_sbi(simulator, prior, 10000) + posterior_estimator = inference.append_simulations(theta, x).train( + max_num_epochs=60 + ) + + posterior = DirectPosterior( + prior=prior, posterior_estimator=posterior_estimator + ).set_default_x(x_o) + samples = posterior.sample((50,)) + + # Evaluate the conditional density be drawing samples and smoothing with a Gaussian + # kde. 
+ potential_fn, theta_transform = posterior_estimator_based_potential( + posterior_estimator, prior=prior, x_o=x_o + ) + (conditioned_potential_fn, restricted_tf, restricted_prior,) = conditonal_potential( + potential_fn=potential_fn, + theta_transform=theta_transform, + prior=prior, + condition=samples[0], + dims_to_sample=[dim_to_sample_1, dim_to_sample_2], + ) + mcmc_posterior = MCMCPosterior( + potential_fn=conditioned_potential_fn, + theta_transform=restricted_tf, + proposal=restricted_prior, + ) + cond_samples = mcmc_posterior.sample((500,)) + + _ = analysis.pairplot( + cond_samples, + limits=[[-2, 2], [-2, 2], [-2, 2]], + figsize=(2, 2), + diag="kde", + upper="kde", + ) + + limits = [[-2, 2], [-2, 2], [-2, 2]] + + density = gaussian_kde(cond_samples.numpy().T, bw_method="scott") + + X, Y = np.meshgrid( + np.linspace(limits[0][0], limits[0][1], 50), + np.linspace(limits[1][0], limits[1][1], 50), + ) + positions = np.vstack([X.ravel(), Y.ravel()]) + sample_kde_grid = np.reshape(density(positions).T, X.shape) + + # Evaluate the conditional with eval_conditional_density. + eval_grid = analysis.eval_conditional_density( + posterior, + condition=samples[0], + dim1=dim_to_sample_1, + dim2=dim_to_sample_2, + limits=torch.tensor([[-2, 2], [-2, 2], [-2, 2]]), + ) + + # Compare the two densities. + sample_kde_grid = sample_kde_grid / np.sum(sample_kde_grid) + eval_grid = eval_grid / torch.sum(eval_grid) + + error = np.abs(sample_kde_grid - eval_grid.numpy()) + + max_err = np.max(error) + assert max_err < 0.0027 + + +def test_mdn_conditional_density(num_dim: int = 3, cond_dim: int = 1): + """Test whether the conditional density infered from MDN parameters of a + `DirectPosterior` matches analytical results for MVN. This uses a n-D joint and + conditions on the last m values to generate a conditional. + + Gaussian prior used for easier ground truthing of conditional posterior. + + Args: + num_dim: Dimensionality of the MVM. + cond_dim: Dimensionality of the condition. 
+ """ + + assert ( + num_dim > cond_dim + ), "The number of dimensions needs to be greater than that of the condition!" + + x_o = zeros(1, num_dim) + num_samples = 1000 + num_simulations = 2700 + condition = 0.1 * ones(1, num_dim) + + dims = list(range(num_dim)) + dims2sample = dims[-cond_dim:] + dims2condition = dims[:-cond_dim] + + # likelihood_mean will be likelihood_shift+theta + likelihood_shift = -1.0 * ones(num_dim) + likelihood_cov = 0.3 * eye(num_dim) + + prior_mean = zeros(num_dim) + prior_cov = eye(num_dim) + prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) + + joint_posterior = true_posterior_linear_gaussian_mvn_prior( + x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov + ) + joint_cov = joint_posterior.covariance_matrix + joint_mean = joint_posterior.loc + + conditional_mean, conditional_cov = conditional_of_mvn( + joint_mean, joint_cov, condition[0, dims2condition] + ) + conditional_dist_gt = MultivariateNormal(conditional_mean, conditional_cov) + + conditional_samples_gt = conditional_dist_gt.sample((num_samples,)) + + def simulator(theta): + return linear_gaussian(theta, likelihood_shift, likelihood_cov) + + simulator, prior = prepare_for_sbi(simulator, prior) + inference = SNPE_C(density_estimator="mdn", show_progress_bars=False) + + theta, x = simulate_for_sbi( + simulator, prior, num_simulations, simulation_batch_size=1000 + ) + posterior_mdn = inference.append_simulations(theta, x).train( + training_batch_size=100 + ) + conditioned_mdn = ConditionedMDN( + posterior_mdn, x_o, condition=condition, dims_to_sample=[0] + ) + conditional_samples_sbi = conditioned_mdn.sample((num_samples,)) + check_c2st( + conditional_samples_sbi, + conditional_samples_gt, + alg="analytic_mdn_conditioning_of_direct_posterior", + ) + + +@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) +def test_example_posterior(snpe_method: type): + """Return an inferred `NeuralPosterior` for interactive examination.""" + num_dim = 2 + x_o 
= zeros(1, num_dim) + + # likelihood_mean will be likelihood_shift+theta + likelihood_shift = -1.0 * ones(num_dim) + likelihood_cov = 0.3 * eye(num_dim) + + prior_mean = zeros(num_dim) + prior_cov = eye(num_dim) + prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) + + if snpe_method == SNPE_A: + extra_kwargs = dict(final_round=True) + else: + extra_kwargs = dict() + + simulator, prior = prepare_for_sbi( + lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior + ) + inference = snpe_method(prior, show_progress_bars=False) + + theta, x = simulate_for_sbi( + simulator, prior, 1000, simulation_batch_size=10, num_workers=6 + ) + posterior_estimator = inference.append_simulations(theta, x).train(**extra_kwargs) + if snpe_method == SNPE_A: + posterior_estimator = inference.correct_for_proposal() + posterior = DirectPosterior( + prior=prior, posterior_estimator=posterior_estimator + ).set_default_x(x_o) + assert posterior is not None diff --git a/tutorials/07_conditional_distributions.ipynb b/tutorials/07_conditional_distributions.ipynb index 2e46f8cd0..f84014a88 100644 --- a/tutorials/07_conditional_distributions.ipynb +++ b/tutorials/07_conditional_distributions.ipynb @@ -1,16588 +1,16588 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Analysing variability and compensation mechansims with conditional distributions\n", - "\n", - "A central advantage of `sbi` over parameter search methods such as genetic algorithms is that the posterior captures **all** models that can reproduce experimental data. This allows us to analyse whether parameters can be variable or have to be narrowly tuned, and to analyse compensation mechanisms between different parameters. See also [Marder and Taylor, 2011](https://www.nature.com/articles/nn.2735?page=2) for further motivation to identify all models that capture experimental data. 
\n", - "\n", - "In this tutorial, we will show how one can use the posterior distribution to identify whether parameters can be variable or have to be finely tuned, and how we can use the posterior to find potential compensation mechanisms between model parameters. To investigate this, we will extract **conditional distributions** from the posterior inferred with `sbi`." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Note, you can find the original version of this notebook at [https://github.com/mackelab/sbi/blob/main/tutorials/07_conditional_distributions.ipynb](https://github.com/mackelab/sbi/blob/main/tutorials/07_conditional_distributions.ipynb) in the `sbi` repository." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Main syntax" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sbi.analysis import conditional_pairplot, conditional_corrcoeff\n", - "\n", - "# Plot slices through posterior, i.e. conditionals.\n", - "_ = conditional_pairplot(\n", - " density=posterior,\n", - " condition=posterior.sample((1,)),\n", - " limits=torch.tensor([[-2., 2.], [-2., 2.]]),\n", - ")\n", - "\n", - "# Compute the matrix of correlation coefficients of the slices.\n", - "cond_coeff_mat = conditional_corrcoeff(\n", - " density=posterior,\n", - " condition=posterior.sample((1,)),\n", - " limits=torch.tensor([[-2., 2.], [-2., 2.]]),\n", - ")\n", - "plt.imshow(cond_coeff_mat, clim=[-1, 1])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Analysing variability and compensation mechanisms in a toy example\n", - "Below, we use a simple toy example to demonstrate the above described features. For an application of these features to a neuroscience problem, see figure 6 in [Gonçalves, Lueckmann, Deistler et al., 2019](https://arxiv.org/abs/1907.00770)." 
- ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from sbi import utils as utils\n", - "from sbi.analysis import pairplot, conditional_pairplot, conditional_corrcoeff\n", - "import torch\n", - "import numpy as np\n", - "\n", - "import matplotlib.pyplot as plt\n", - "from mpl_toolkits.mplot3d import Axes3D\n", - "from matplotlib import animation, rc\n", - "from IPython.display import HTML, Image\n", - "\n", - "_ = torch.manual_seed(0)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's say we have used SNPE to obtain a posterior distribution over three parameters. In this tutorial, we just load the posterior from a file:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from toy_posterior_for_07_cc import ExamplePosterior\n", - "posterior = ExamplePosterior()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "First, we specify the experimental observation $x_o$ at which we want to evaluate and sample the posterior $p(\\theta|x_o)$:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "x_o = torch.ones(1, 20) # simulator output was 20-dimensional\n", - "posterior.set_default_x(x_o)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As always, we can inspect the posterior marginals with the `pairplot()` function:" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "posterior_samples = posterior.sample((5000,))\n", - "\n", - "fig, ax = pairplot(\n", - " samples=posterior_samples,\n", - " limits=torch.tensor([[-2., 2.]]*3),\n", - " upper=['kde'],\n", - " diag=['kde'],\n", - " figsize=(5,5)\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The 1D and 2D marginals of the posterior fill almost the entire parameter space! Also, the Pearson correlation coefficient matrix of the marginal shows rather weak interactions (low correlations):" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAARsAAADxCAYAAAD7hRNxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWYklEQVR4nO3df6xcZZ3H8fen97Y0u6wWKCmI2EJokBpWkAYwRGUBtRBDu4rabpTiQroY3F1FDSKJbHBJym4iiz+xgQooCygi1ohhkR+LRmEpbKH8WOQKKq2VSvmhhFK4vd/94zxThunMnXM7Z849c+7nZU46c84z85yJuV+eX+f5KiIwM+u3aZN9A2Y2NTjYmFkpHGzMrBQONmZWCgcbMyuFg42ZlcLBxqymJK2WtFnSgx2uS9KXJY1IekDS25quLZf0WDqWF3E/DjZm9XUFsGic6ycC89OxAvgGgKQ9gfOBo4AjgfMl7dHrzTjYmNVURNwJPDNOkcXAVZG5C5glaV/gvcAtEfFMRDwL3ML4QSuX4V6/wMyK85f7z4ztL43lKrvt6VceAl5qOrUqIlZNoLr9gCeb3m9I5zqd74mDjVmFjL00xgHvn52r7P+t2vRSRCzs8y0Vxt0osyoRTJumXEcBNgL7N71/YzrX6XxPHGzMKkbKdxRgDXBqmpU6Gng+IjYBNwPvkbRHGhh+TzrXE3ejzCpEwLSCmgCSrgGOBWZL2kA2wzQdICIuBW4CTgJGgBeBj6Vrz0j6InBP+qoLImK8geZcHGzMqkQwNFxMsyUilnW5HsBZHa6tBlYXciOJg41ZxaimgxsONmYVIsG0ggZkqsbBxqxi3LIxs1IUNUBcNQ42ZhUiuWVjZiUZGvKYjZn1meRulJmVQqiYRxEqx8HGrErcsjGzsniA2Mz6TvIAcVtp+8DrgHnAb4APpZ29WsttB9ant7+LiJN7qdeszurajer1Z30OuDUi5gO3pvftbI2Iw9LhQGPWgQBNU65j0PQabBYDV6bXVwJLevw+s6ktDRDnOQZNr2M2c9JmOwB/AOZ0KDdT0lpgFFgZETe2KyRpBdku72hYR8yYNXWGlA7+iwMn+xZKs23W1sm+hVI9+sDvno6IvfOWr+lzmN2DjaSfAvu0uXRe85uICEnR4WvmRsRGSQcCt0laHxG/bi2UNmteBTBz7xkxb0m+vVjr4OYjLp/sWyjNyOK2aYxq6x1vOPO3ectmm2fVM9p0DTYRcUKna5Kek
rRvRGxKKSA2d/iOjenfxyXdARwO7BRszKa8AjfPqppee35rgEa2vOXAD1sLpH1Md0uvZwPHAA/3WK9ZLQkxTfmOQdNrsFkJvFvSY8AJ6T2SFkq6LJU5BFgr6X7gdrIxGwcbs3YKzq4gaZGkR1OK3Z1miyVdLGldOn4l6bmma9ubrq3p9af1NAIbEVuA49ucXwuckV7/Aji0l3rMpooix2wkDQFfA95NlmjuHklrmv9jHxGfair/j2RDHA1bI+KwQm4Gp3Ixq5xpmpbryOFIYCQiHo+Il4FryZardLIMuKaAn9CWg41ZlShfFypn6yd3Gl1Jc4EDgNuaTs+UtFbSXZKW7OIv2mHqLGQxGwAChodytwFmp/VrDRPN9d1sKXB9RGxvOpdryUpeDjZmFZJtnpU72DzdJdf3RNLoLqUlh1TRS1bcjTKrlEK7UfcA8yUdIGkGWUDZaVZJ0puBPYBfNp0rfMmKWzZmVVJg3qiIGJX0CbI83UPA6oh4SNIFwNqIaASepcC1KUNmwyHANyWNkTVKel6y4mBjViFFP64QETeR5fRuPveFlvf/0uZzhS9ZcbAxqxKJoaGhyb6LvnCwMauQKf0gppmVy8HGzPpOIu/q4IHjYGNWKfkfshw0DjZmFTOI20fk4WBjViESDA97NsrM+qyxeVYdOdiYVYk8G2VmJXGwMbO+E576NrMyyFPfZlYCIYaHpk/2bfRFIe21HDu47ybpunT9bknziqjXrG4EDGko1zFoeg42TTu4nwgsAJZJWtBS7HTg2Yg4CLgYuKjXes1qSWLatKFcx6ApomWTZwf3xcCV6fX1wPFSTRcTmPVomoZyHYOmiDGbdju4H9WpTNo97HlgL+DpAuo3qw2hiexBPFAqNUAsaQWwAmB498GL3Ga9ksT0oRmTfRt9UUSwybODe6PMBknDwOuBLa1flNJQrAKYufeMaL1uVn+q7TqbIn5Vnh3c1wDL0+tTgNtaNlc2MxqpXIobIM4xU3yapD825fQ+o+nackmPpWN562cnqueWTc4d3C8Hvi1pBHiGLCCZ2U5U2LR2nlzfyXUR8YmWz+4JnA8sBAK4N3322V29n0LGbLrt4B4RLwEfLKIuszor+HGFHTPFAJIaM8V5UrK8F7glIp5Jn70FWEQPucDr2Tk0G1gTWmczO+XibhwrWr4sb67vD0h6QNL1khrjr7nzhOdVqdkos6lugrNR3dLv5vEj4JqI2CbpH8jWwx3X43e25ZaNWYU0ulF5jhy6zhRHxJaI2JbeXgYckfezE+VgY1YlxT6u0HWmWNK+TW9PBh5Jr28G3pNyfu8BvCed22XuRplVigp7FCHnTPE/SToZGCWbKT4tffYZSV8kC1gAFzQGi3eVg41ZhWQZMYvrcOSYKT4XOLfDZ1cDq4u6Fwcbs0opbp1N1TjYmFWIJIb9bJSZ9Zv3IDazckgMDeDGWHk42JhVSNaycbAxs76r7xYTDjZmFSLE8DQPEJtZv0nI3SgzK4PHbMys74SYhoONmZXALRsz6zsV+CBm1TjYmFWKGJJno8yszySvszGzktS1G1VICO0lN42ZNZNzfXfSS24aM3st4UV94+klN42ZtfA6m87a5Zc5qk25D0h6J/Ar4FMR8WRrgZT3ZgXAfnvO4Y4jvlfA7Q2GY++dOjn8PnLwoZN9C5UlFftslKRFwCVkexBfFhErW66fDZxBtgfxH4G/j4jfpmvbgfWp6O8i4uRe7qWsYe8fAfMi4q+BW8hy0+wkIlZFxMKIWLjX7rNKujWzKiluzKZpiONEYAGwTNKClmL/CyxMf5vXA//WdG1rRByWjp4CDRQTbHrJTWNmr5GN2eQ5ctgxxBERLwONIY4dIuL2iHgxvb2L7O+3L4oINr3kpjGzJiIbs8lzUFz63YbTgZ80vZ+ZvvcuSUt6/W09j9n0kpvGzFpNaFFfEel3s1qljwALgXc1nZ4bERslHQjcJml9RPx6V+soZFFfL7lpzOxVBQ8Q50qhK
+kE4DzgXU3DHUTExvTv45LuAA4HdjnY1HNdtNkAE0O5jhzyDHEcDnwTODkiNjed30PSbun1bOAYelzO4scVzCqkyKe+cw5x/DuwO/A9SfDqFPchwDcljZE1Sla2Wag7IQ42ZpVS7BYTOYY4TujwuV8AhS6IcrAxqxjVdHTDwcascjTZN9AXDjZmFeI9iM2sRO5GmVkJ5G6UmfWfkLcFNbNyuGVjZn3nAWIzK427UWbWZ8IDxGZWCnkFsZmVxS0bMyuBWzZmVgLl3atm4DjYmFVINkDslo2ZlcCzUWbWfxL4cQUzK0NdWzaFhFBJqyVtlvRgh+uS9GVJI5IekPS2Iuo1q59snU2eI9e3SYskPZr+9j7X5vpukq5L1++WNK/p2rnp/KOS3tvrLyuqvXYFsGic6ycC89OxAvhGQfWa1U5R2RVypt89HXg2Ig4CLgYuSp9dQJaN4S1kf9tfV840nJ0UEmwi4k6y5HOdLAauisxdwKyWLJlmxquPK+T5Xw5d0++m91em19cDxytLs7AYuDYitkXEE8BI+r5dVtZIVK40oJJWNFKJbnnhuZJuzaxKNIGjkPS7O8pExCjwPLBXzs9OSKUGiCNiFbAK4K1z3xyTfDtm5Yt05FNY+t0ylNWyyZUG1MwCRb4jhzx/dzvKSBoGXg9syfnZCSkr2KwBTk2zUkcDz0fEppLqNhssYzmP7rqm303vl6fXpwC3RUSk80vTbNUBZJM7/9PDryqmGyXpGuBYsj7kBuB8YDpARFxKlpHvJLJBpheBjxVRr1kt5Wu15PiaXOl3Lwe+LWmEbJJnafrsQ5K+S5bfexQ4KyK293I/hQSbiFjW5XoAZxVRl1mtBajA0coc6XdfAj7Y4bMXAhcWdS+VGiA2MyYyQDxQHGzMqqagblTVONiYVU09Y42DjVmlBHmntQeOg41Z1dQz1jjYmFWOg42ZlcLdKDMrQ5HrbKrEwcasSib2IOZAcbAxq5SAsXpGGwcbswoR9e1G1XMbdzOrHLdszKrGs1Fm1nceIDazssgDxGZWinrGGgcbs0oJPPVtZmUIoqYDxJ76Nqua4jY870jSnpJukfRY+nePNmUOk/RLSQ+ltNkfbrp2haQnJK1Lx2Hd6nSwMauQCIixyHX06HPArRExH7g1vW/1InBqRDRS8P6HpFlN1z8bEYelY123Ct2NMquY2N5jsyWfxWQZUSBLv3sHcM5r7iPiV02vfy9pM7A38NyuVFhIy0bSakmbJT3Y4fqxkp5vanJ9oV05symvMUCc5+iefnc8c5pyt/0BmDNeYUlHAjOAXzedvjB1ry6WtFu3Cotq2VwBfBW4apwyP4uI9xVUn1lNTWiAeNz0u5J+CuzT5tJ5r6kxIqTOT2RJ2hf4NrA8IhrNrnPJgtQMspTZ5wAXjHezReWNulPSvCK+q2HbrK2MLG7bUKqljxx86GTfQmm+85/rJ/sWqq2gXlREnNDpmqSnJO0bEZtSMNncodzrgB8D50XEXU3f3WgVbZP0LeAz3e6nzAHit0u6X9JPJL2lXQFJKxpNwue2vFDirZlVR0TkOnrUnHZ3OfDD1gIpZe8PgKsi4vqWa/umfwUsAbq2DMoKNvcBcyPircBXgBvbFYqIVRGxMCIWztpr95JuzaxCJjZm04uVwLslPQackN4jaaGky1KZDwHvBE5rM8V9taT1wHpgNvCv3SosZTYqIv7U9PomSV+XNDsini6jfrNBUsZsVERsAY5vc34tcEZ6/R3gOx0+f9xE6ywl2EjaB3gqDUQdSdai2lJG3WaDJKKQNTSVVEiwkXQN2Zz9bEkbgPOB6QARcSlwCvBxSaPAVmBp1HVNtlmvSllmU76iZqOWdbn+VbKpcTProq7/HfYKYrMq8VPfZlaWkh5XKJ2DjVmVBB6zMbMyeDbKzMriAWIz67u0n00dOdiYVY2DjZn1W0R4NsrMyuFgY2b95zEbMyuHu1FmVoYAxhxszKzPIoKxV7ZP9m30hYONWcW4G
2Vm/VfjbpQzYppVSr5smL3OWOVJv5vKbW/af3hN0/kDJN0taUTSdWlz9HE52JhVSWTdqDxHj/Kk3wXY2pRi9+Sm8xcBF0fEQcCzwOndKnSwMauQAGJsLNfRo8VkaXdJ/y7J+8GUvuU4oJHeJdfnPWZjViURRP7ZqNmS1ja9XxURq3J+Nm/63ZmpjlFgZUTcCOwFPBcRo6nMBmC/bhU62JhVSUxoNqqM9LtzI2KjpAOB21KuqOfz3mCznoONpP3JcnzPIWsFroqIS1rKCLgEOAl4ETgtIu7rtW6z+okiukjZNxWQfjciNqZ/H5d0B3A48H1glqTh1Lp5I7Cx2/0UMWYzCnw6IhYARwNnSVrQUuZEYH46VgDfKKBes/oJYHvkO3qTJ/3uHpJ2S69nA8cAD6c0TLeTpWjq+PlWPQebiNjUaKVExJ+BR9i5/7aYLF9wpOTksxq5gs3stUoaIM6TfvcQYK2k+8mCy8qIeDhdOwc4W9II2RjO5d0qLHTMRtI8smbW3S2X9gOebHrfGFDahJntUNZ+NjnT7/4COLTD5x8HjpxInYUFG0m7k/XlPtmc23uC37GCrJvFnP32LOrWzAZHMJHZqIFSyDobSdPJAs3VEXFDmyIbgf2b3rcdUIqIVRGxMCIWztpr9yJuzWzARFndqNL1HGzSTNPlwCMR8aUOxdYApypzNPB80xy/mTWUt4K4dEV0o44BPgqsl7Qunfs88CaAiLgUuIls2nuEbOr7YwXUa1ZDxU19V03PwSYifg6oS5kAzuq1LrPaa0x915BXEJtVSEQw9vJo94IDyMHGrEom9rjCQHGwMauSgBh1sDGzvnN2BTMrQbhlY2aliHCwMbMSBIxtq+fjCg42ZlVS0oOYk8HBxqxCPGZjZuXwmI2ZlcXdKDPrP3ejzKwMEcHYtno+G+UkdWZVksZs8hy9yJN+V9LfNKXeXSfpJUlL0rUrJD3RdO2wbnU62JhVSYXS70bE7Y3Uu2QZMF8E/qupyGebUvOu61ahg41ZlaQxm363bJh4+t1TgJ9ExIu7WqGDjVmllNONIn/63YalwDUt5y6U9ICkixv5pcbjAWKzComxCT2uMG6u74LS75JyvB0K3Nx0+lyyIDUDWEWWR+qC8W7WwcasUib0uMK4ub6LSL+bfAj4QUS80vTdjVbRNknfAj7T7WbdjTKrkvLGbLqm322yjJYuVCOjbcqusgR4sFuFbtmYVUl5jyusBL4r6XTgt2StFyQtBM6MiDPS+3lkOd/+u+XzV0vamyzZwTrgzG4V9hxsJO0PXEU2wBRk/cZLWsocSxY5n0inboiIcft3ZlNRlLQHcZ70u+n9b8hSZbeWO26idRbRshkFPh0R90n6K+BeSbc0JSBv+FlEvK+A+sxqzc9GdZAGijal13+W9AhZJGwNNmbWRUQwOubNs7pK/bvDgbvbXH67pPuB3wOfiYiH2nx+BbAivX3hHW8489Ei7y+n2cDTk1DvZJlKv3eyfuvciRTeHm7ZjEvS7sD3gU9GxJ9aLt8HzI2IFySdBNwIzG/9jrRGYFXr+TJJWjvedGLdTKXfOwi/NQjGahpsCpn6ljSdLNBcHRE3tF6PiD9FxAvp9U3AdEmzi6jbrG7GInIdg6aI2SgBlwOPRMSXOpTZB3gqrVQ8kizIbem1brM6qmvLpohu1DHAR4H1ktalc58H3gQQEZeSPcT1cUmjwFZgaURlQ/OkduMmwVT6vZX/rRH17Uapun/zZlPPQcNz4kuv+3Cusouf/cq9VR+DauYVxGaVEp6NMrP+CxjIwd88/CBmE0mLJD0qaUTSTjuX1Ymk1ZI2S+r6AN2gk7S/pNslPSzpIUn/PNn31FFkA8R5jkHjYJNIGgK+BpwILACWSVowuXfVV1cAiyb7JkrSeKRmAXA0cFZ1/7+N2gYbd6NedSQwEhGPA0i6lmzrxFo+dhERd6YV37U3SI/UBPhxhSlgP+DJpvcbgKMm6V6sT7o8UjPpwgPEZoOvy
yM11RBe1DcVbCTbJKjhjemc1UC3R2qqos6zUQ42r7oHmC/pALIgsxT4u8m9JStCnkdqqqO+K4g9G5VExCjwCbId5B8BvttuG4y6kHQN8EvgYEkb0vaQddV4pOa4pgyOJ032TbWTtWw8G1V76Yn0myb7PsoQEcsm+x7KEhE/J9srt/IigldqOhvllo1ZxZTRspH0wbTAcSxtct6pXNuFrpIOkHR3On+dpBnd6nSwMauYkvazeRB4P3BnpwJdFrpeBFwcEQcBzwJdu+EONmYVEiWtII6IRyKi27a7Oxa6RsTLwLXA4jTgfhxwfSqXJ1e4x2zMqmQDz918dtyQdxfLmeOl3y1Ap4WuewHPpUmVxvmd0r20crAxq5CIKOx5tfFyfUfEeBkw+8LBxqymxsv1nVOnha5bgFmShlPrJtcCWI/ZmFknOxa6ptmmpcCatKXv7WTb/UL3XOGAg43ZlCTpbyVtAN4O/FjSzen8GyTdBF0Xup4DnC1phGwM5/KudXoPYjMrg1s2ZlYKBxszK4WDjZmVwsHGzErhYGNmpXCwMbNSONiYWSn+H4jRLO0e+4NOAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "corr_matrix_marginal = np.corrcoef(posterior_samples.T)\n", - "fig, ax = plt.subplots(1,1, figsize=(4, 4))\n", - "im = plt.imshow(corr_matrix_marginal, clim=[-1, 1], cmap='PiYG')\n", - "_ = fig.colorbar(im)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "It might be tempting to conclude that the experimental data barely constrains our parameters and that almost all parameter combinations can reproduce the experimental data. As we will show below, this is not the case." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Because our toy posterior has only three parameters, we can plot posterior samples in a 3D plot:" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "rc('animation', html='html5')\n", - "\n", - "# First set up the figure, the axis, and the plot element we want to animate\n", - "fig = plt.figure(figsize=(6,6))\n", - "ax = fig.add_subplot(111, projection='3d')\n", - "\n", - "ax.set_xlim((-2, 2))\n", - "ax.set_ylim((-2, 2))\n", - "\n", - "def init():\n", - " line, = ax.plot([], [], lw=2)\n", - " line.set_data([], [])\n", - " return (line,)\n", - "\n", - "def animate(angle):\n", - " num_samples_vis = 1000\n", - " line = ax.scatter(posterior_samples[:num_samples_vis, 0], posterior_samples[:num_samples_vis, 1], posterior_samples[:num_samples_vis, 2], zdir='z', s=15, c='#2171b5', depthshade=False)\n", - " ax.view_init(20, angle)\n", - " return (line,)\n", - "\n", - "anim = animation.FuncAnimation(fig, animate, init_func=init,\n", - " frames=range(0,360,5), interval=150, blit=True)\n", - "\n", - "plt.close()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 7, - "metadata": {}, - 
"output_type": "execute_result" - } - ], - "source": [ - "HTML(anim.to_html5_video())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Clearly, the range of admissible parameters is constrained to a narrow region in parameter space, which had not been evident from the marginals." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "If the posterior has more than three dimensions, inspecting all dimensions at once will not be possible anymore. One way to still reveal structures in high-dimensional posteriors is to inspect 2D-slices through the posterior. In `sbi`, this can be done with the `conditional_pairplot()` function, which computes the conditional distributions within the posterior. We can slice (i.e. condition) the posterior at any location, given by the `condition`. In the plot below, for all upper diagonal plots, we keep all but two parameters constant at values sampled from the posterior, and inspect what combinations of the remaining two parameters can reproduce experimental data. For the plots on the diagonal (the 1D conditionals), we keep all but one parameter constant." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "condition = posterior.sample((1,))\n", - "\n", - "_ = conditional_pairplot(\n", - " density=posterior,\n", - " condition=condition,\n", - " limits=torch.tensor([[-2., 2.]]*3),\n", - " figsize=(5,5)\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This plot looks completely different from the marginals obtained with `pairplot()`. As it can be seen on the diagonal plots, if all parameters but one are kept constant, the remaining parameter has to be tuned to a narrow region in parameter space. In addition, the upper diagonal plots show strong correlations: deviations in one parameter can be compensated through changes in another parameter." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can summarize these correlations in a conditional correlation matrix, which computes the Pearson correlation coefficient of each of these pairwise plots. 
This matrix (below) shows strong correlations between many parameters, which can be interpreted as potential compensation mechansims:" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAARsAAADxCAYAAAD7hRNxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWV0lEQVR4nO3df6xcZZ3H8ffn3tvSdVGLFAEBKYZmlxp2QRvAkCgLiIUYWhG13awUF3JXg7urqEEkkRWXWHYTu/gTG6iAEsCtiDViWOTHglFYKlsoPxapoNJaqZRSJIXC7f3uH+eZMkzv3Dm3c+bcM+d+XuakM+c8M88ZoV+eX+f5KiIwM+u1gcm+ATObGhxszKwUDjZmVgoHGzMrhYONmZXCwcbMSuFgY1ZTklZI2iTpwTbXJekrktZJekDS25quLZH0WDqWFHE/DjZm9XUlMH+c6ycDc9IxDHwTQNIbgAuBo4GjgAsl7dXtzTjYmNVURNwJPDNOkQXA1ZG5G5gpaX/gPcAtEfFMRGwBbmH8oJXLULdfYGbF+fODZsSOF0dzld3+9MsPAS82nVoeEcsnUN0BwJNN79enc+3Od8XBxqxCRl8c5ZDTZuUq+3/LN74YEfN6fEuFcTfKrEoEAwPKdRRgA3BQ0/sD07l257viYGNWMVK+owCrgDPSrNQxwNaI2AjcDJwkaa80MHxSOtcVd6PMKkTAQEFNAEnXAscBsyStJ5thmgYQEZcBNwGnAOuAbcBH0rVnJH0RuDd91UURMd5Acy4ONmZVIhgcKqbZEhGLO1wP4Jw211YAKwq5kcTBxqxiVNPBDQcbswqRYKCgAZmqcbAxqxi3bMysFEUNEFeNg41ZhUhu2ZhZSQYHPWZjZj0muRtlZqUQKuZRhMpxsDGrErdszKwsHiA2s56TPEA8prR94PXAbOA3wAfTzl6t5XYAa9Pb30XEqd3Ua1Znde1GdfuzPgvcGhFzgFvT+7G8EBFHpMOBxqwNARpQrqPfdBtsFgBXpddXAQu7/D6zqS0NEOc5+k23Yzb7ps12AP4A7Num3AxJq4ERYGlE3DhWIUnDZLu8oyG9ffrMqTOkdOCWmZN9C6V56o3PTvYtlOr5jS8/HRH75C1f0+cwOwcbST8F9hvj0gXNbyIiJEWbrzk4IjZIegtwm6S1EfHr1kJps+blADP2mR6zF+bbi7UOvrTytMm+hdIsG1412bdQqru+8ORv85bNNs+qZ7TpGGwi4sR21yQ9JWn/iNiYUkBsavMdG9Kfj0u6AzgS2CXYmE15BW6eVTXd9vxWAY1seUuAH7YWSPuY7pFezwKOBR7usl6zWhJiQPmOftNtsFkKvFvSY8CJ6T2S5km6PJU5DFgt6X7gdrIxGwcbs7EUnF1B0nxJj6YUu7vMFktaJmlNOn4l6dmmazuarnXd9+1qBDYiNgMnjHF+NXB2ev1z4PBu6jGbKoocs5E0CHwdeDdZorl7Ja1q/o99RHyyqfw/kg1xNLwQEUcUcjM4lYtZ5QxoINeRw1HAuoh4PCJeAq4jW67SzmLg2gJ+wpgcbMyqRPm6UDlbP7nT6Eo6GDgEuK3p9AxJqyXdLWnhbv6inabOQhazPiBgaDB3G2BWWr/WMNFc380WASsjYkfTuVxLVvJysDGrkGzzrNzB5ukOub4nkkZ3ES05pIpesuJulFmlFNqNuheYI+kQSdPJAsous0qS/hLYC/hF07nCl6y4ZWNWJQXmjYqIEUkfJ8vT
PQisiIiHJF0ErI6IRuBZBFyXMmQ2HAZ8S9IoWaOk6yUrDjZmFVL04woRcRNZTu/mc59vef8vY3yu8CUrDjZmVSIxODg42XfREw42ZhUypR/ENLNyOdiYWc9J5F0d3HccbMwqJf9Dlv3GwcasYvpx+4g8HGzMKkSCoSHPRplZjzU2z6ojBxuzKpFno8ysJA42ZtZzwlPfZlYGeerbzEogxNDgtMm+jZ4opL2WYwf3PSRdn67fI2l2EfWa1Y2AQQ3mOvpN18GmaQf3k4G5wGJJc1uKnQVsiYhDgWXAJd3Wa1ZLEgMDg7mOflNEyybPDu4LgKvS65XACVJNFxOYdWlAg7mOflPEmM1YO7gf3a5M2j1sK7A38HQB9ZvVhtBE9iDuK5UaIJY0DAwDDO3Zf5HbrFuSmDY4fbJvoyeKCDZ5dnBvlFkvaQh4PbC59YtSGorlADP2mR6t183qT7VdZ1PEr8qzg/sqYEl6fTpwW8vmymZGI5VLcQPEOWaKz5T0x6ac3mc3XVsi6bF0LGn97ER13bLJuYP7FcB3JK0DniELSGa2CxU2rZ0n13dyfUR8vOWzbwAuBOYBAfwyfXbL7t5PIWM2nXZwj4gXgQ8UUZdZnRX8uMLOmWIASY2Z4jwpWd4D3BIRz6TP3gLMp4tc4PXsHJr1rQmts5mVcnE3juGWL8ub6/v9kh6QtFJSY/w1d57wvCo1G2U21U1wNqpT+t08fgRcGxHbJf0D2Xq447v8zjG5ZWNWIY1uVJ4jh44zxRGxOSK2p7eXA2/P+9mJcrAxq5JiH1foOFMsaf+mt6cCj6TXNwMnpZzfewEnpXO7zd0os0pRYY8i5Jwp/idJpwIjZDPFZ6bPPiPpi2QBC+CixmDx7nKwMauQLCNmcR2OHDPF5wPnt/nsCmBFUffiYGNWKcWts6kaBxuzCpHEkJ+NMrNe8x7EZlYOicE+3BgrDwcbswrJWjYONmbWc/XdYsLBxqxChBga8ACxmfWahNyNMrMyeMzGzHpOiAEcbMysBG7ZmFnPqcAHMavGwcasUsSgPBtlZj0meZ2NmZWkrt2oQkJoN7lpzKyZnOu7nW5y05jZqwkv6htPN7lpzKyF19m0N1Z+maPHKPd+Se8EfgV8MiKebC2Q8t4MA+yjPfnSytMKuL3+cP7pN0z2LZTmNVvrOQBaBKnYZ6MkzQcuJduD+PKIWNpy/VzgbLI9iP8I/H1E/DZd2wGsTUV/FxGndnMvZf1T/xEwOyL+CriFLDfNLiJieUTMi4h5rxv4s5JuzaxKihuzaRriOBmYCyyWNLel2P8C89LfzZXAvzVdeyEijkhHV4EGigk23eSmMbNXycZs8hw57BziiIiXgMYQx04RcXtEbEtv7yb7+9sTRQSbbnLTmFkTkY3Z5DkoLv1uw1nAT5rez0jfe7ekhd3+tq7HbLrJTWNmrSa0qK+I9LtZrdLfAfOAdzWdPjgiNkh6C3CbpLUR8evdraOQRX3d5KYxs1cUPECcK4WupBOBC4B3NQ13EBEb0p+PS7oDOBLY7WDjaQGzihGDuY4c8gxxHAl8Czg1IjY1nd9L0h7p9SzgWLpczuLHFcwqpMinvnMOcfw7sCfwn5LglSnuw4BvSRola5QsHWOh7oQ42JhVSrFbTOQY4jixzed+Dhxe2I3gYGNWOarp6IaDjVnlaLJvoCccbMwqxHsQm1mJ3I0ysxLI3Sgz6z0hbwtqZuVwy8bMes4DxGZWGnejzKzHhAeIzawU8gpiMyuLWzZmVgK3bMysBMq7V03fcbAxq5BsgNgtGzMrgWejzKz3JPDjCmZWhrq2bAoJoZJWSNok6cE21yXpK5LWSXpA0tuKqNesfrJ1NnmOXN8mzZf0aPq799kxru8h6fp0/R5Js5uunZ/OPyrpPd3+sqLaa1cC88e5fjIwJx3DwDcLqtesdorKrpAz/e5ZwJaIOBRYBlySPjuXLBvDW8n+bn9DOdNwtlNIsImIO8mSz7WzALg6MncDM1uy
ZJoZrzyukOd/OXRMv5veX5VerwROUJZmYQFwXURsj4gngHXp+3ZbWSNRudKAShpupBJ9bvSFkm7NrEo0gaOQ9Ls7y0TECLAV2DvnZyekUgPEEbEcWA5w6NAbY5Jvx6x8kY58Cku/W4ayWja50oCaWaDId+SQ5+/dzjKShoDXA5tzfnZCygo2q4Az0qzUMcDWiNhYUt1m/WU059FZx/S76f2S9Pp04LaIiHR+UZqtOoRscud/uvhVxXSjJF0LHEfWh1wPXAhMA4iIy8gy8p1CNsi0DfhIEfWa1VK+VkuOr8mVfvcK4DuS1pFN8ixKn31I0vfI8nuPAOdExI5u7qeQYBMRiztcD+CcIuoyq7UAFThamSP97ovAB9p89mLg4qLupVIDxGbGRAaI+4qDjVnVFNSNqhoHG7OqqWescbAxq5Qg77R233GwMauaesYaBxuzynGwMbNSuBtlZmUocp1NlTjYmFXJxB7E7CsONmaVEjBaz2jjYGNWIaK+3ah6buNuZpXjlo1Z1Xg2ysx6zgPEZlYWeYDYzEpRz1jjYGNWKYGnvs2sDEHUdIDYU99mVVPchudtSXqDpFskPZb+3GuMMkdI+oWkh1La7A81XbtS0hOS1qTjiE51OtiYVUgExGjkOrr0WeDWiJgD3Jret9oGnBERjRS8/yFpZtP1z0TEEelY06lCd6PMKiZ2dNlsyWcBWUYUyNLv3gGc96r7iPhV0+vfS9oE7AM8uzsVFtKykbRC0iZJD7a5fpykrU1Nrs+PVc5symsMEOc5OqffHc++Tbnb/gDsO15hSUcB04FfN52+OHWvlknao1OFRbVsrgS+Blw9Tpm7IuK9BdVnVlMTGiAeN/2upJ8C+41x6YJX1RgRUvsnsiTtD3wHWBIRjWbX+WRBajpZyuzzgIvGu9mi8kbdKWl2Ed/V8NQbn2XZcGvyvvp6zdapM3y27flSugn9q6D/eyLixHbXJD0laf+I2JiCyaY25V4H/Bi4ICLubvruRqtou6RvA5/udD9l/hv+Dkn3S/qJpLeOVUDScKNJ+PI2/wtpU1NE5Dq61Jx2dwnww9YCKWXvD4CrI2Jly7X9058CFgJjDqE0KyvY3AccHBF/DXwVuHGsQhGxPCLmRcS8aa+ZOv+lN9tpYmM23VgKvFvSY8CJ6T2S5km6PJX5IPBO4MwxprivkbQWWAvMAv61U4WlzEZFxHNNr2+S9A1JsyLi6TLqN+snZcxGRcRm4IQxzq8Gzk6vvwt8t83nj59onaUEG0n7AU+lgaijyFpUm8uo26yfRBSyhqaSCgk2kq4lm7OfJWk9cCEwDSAiLgNOBz4maQR4AVgUdV2Tbdatmg5XFjUbtbjD9a+RTY2bWQd1/e+wVxCbVYmf+jazspT0uELpHGzMqiTwmI2ZlcGzUWZWFg8Qm1nPpf1s6sjBxqxqHGzMrNciwrNRZlYOBxsz6z2P2ZhZOdyNMrMyBDDqYGNmPRYRjL68Y7JvoyccbMwqxt0oM+u9GnejvNGvWaXky4bZ7YxVnvS7qdyOpv2HVzWdP0TSPZLWSbo+bY4+LgcbsyqJrBuV5+hSnvS7AC80pdg9ten8JcCyiDgU2AKc1alCBxuzCgkgRkdzHV1aQJZ2l/TnwrwfTOlbjgca6V1yfd5jNmZVEkHkn42aJWl10/vlEbE852fzpt+dkeoYAZZGxI3A3sCzETGSyqwHDuhUoYONWZXEhGajyki/e3BEbJD0FuC2lCtqa94bbNZ1sJF0EFmO733JWoHLI+LSljICLgVOAbYBZ0bEfd3WbVY/UUQXKfumAtLvRsSG9Ofjku4AjgS+D8yUNJRaNwcCGzrdTxFjNiPApyJiLnAMcI6kuS1lTgbmpGMY+GYB9ZrVTwA7It/RnTzpd/eStEd6PQs4Fng4pWG6nSxFU9vPt+o62ETExkYrJSL+BDzCrv23BWT5giMlJ5/ZyBVsZq9W0gBxnvS7hwGrJd1PFlyWRsTD6dp5wLmS1pGN4VzR
qcJCx2wkzSZrZt3TcukA4Mmm940BpY2Y2U5l7WeTM/3uz4HD23z+ceCoidRZWLCRtCdZX+4Tzbm9J/gdw2TdLPZ4/WBRt2bWP4KJzEb1laLS704jCzTXRMQNYxTZABzU9H7MAaU0bbcc4LVvml7PTT3MxlXcAHHVdD1mk2aargAeiYgvtym2CjhDmWOArU1z/GbWUN4K4tIV0bI5FvgwsFbSmnTuc8CbASLiMuAmsmnvdWRT3x8poF6zGqpvy6brYBMRPwPUoUwA53Rbl1ntNaa+a8griM0qJCIYfWmkc8E+5GBjViUTe1yhrzjYmFVJQIw42JhZzzm7gpmVINyyMbNSRDjYmFkJAka3+3EFM+u1kh7EnAwONmYV4jEbMyuHx2zMrCzuRplZ77kbZWZliAhGt9fz2SgnqTOrkjRmk+foRp70u5L+pin17hpJL0pamK5dKemJpmtHdKrTwcasSiqUfjcibm+k3iXLgLkN+K+mIp9pSs27plOFDjZmVZLGbHrdsmHi6XdPB34SEdt2t0IHG7NKKacbRf70uw2LgGtbzl0s6QFJyxr5pcbjAWKzConRCT2uMG6u74LS75JyvB0O3Nx0+nyyIDWdLEnBecBF492sg41ZpUzocYVxc30XkX43+SDwg4h4uem7G62i7ZK+DXy60826G2VWJeWN2XRMv9tkMS1dqEZG25RdZSHwYKcK3bIxq5LyHldYCnxP0lnAb8laL0iaB3w0Is5O72eT5Xz775bPXyNpH7JkB2uAj3aqsOtgI+kg4GqyAaYg6zde2lLmOLLI+UQ6dUNEjNu/M5uKoqQ9iPOk303vf0OWKru13PETrbOIls0I8KmIuE/Sa4FfSrqlKQF5w10R8d4C6jOrNT8b1UYaKNqYXv9J0iNkkbA12JhZBxHByKg3z+oo9e+OBO4Z4/I7JN0P/B74dEQ8NMbnh4Hh9Pb5u77w5KNF3l9Os4CnJ6HeyTKVfu9k/daDJ1J4R7hlMy5JewLfBz4REc+1XL4PODginpd0CnAjMKf1O9IageWt58skafV404l1M5V+bz/81iAYrWmwKWTqW9I0skBzTUTc0Ho9Ip6LiOfT65uAaZJmFVG3Wd2MRuQ6+k0Rs1ECrgAeiYgvtymzH/BUWql4FFmQ29xt3WZ1VNeWTRHdqGOBDwNrJa1J5z4HvBkgIi4je4jrY5JGgBeARRGVDc2T2o2bBFPp91b+t0bUtxul6v6dN5t6Dh3aN778ug/lKrtgy1d/WfUxqGZeQWxWKeHZKDPrvYC+HPzNww9iNpE0X9KjktZJ2mXnsjqRtELSJkkdH6Drd5IOknS7pIclPSTpnyf7ntqKbIA4z9FvHGwSSYPA14GTgbnAYklzJ/eueupKYP5k30RJGo/UzAWOAc6p7j/bqG2wcTfqFUcB6yLicQBJ15FtnVjLxy4i4s604rv2+umRmgA/rjAFHAA82fR+PXD0JN2L9UiHR2omXXiA2Kz/dXikphrCi/qmgg1kmwQ1HJjOWQ10eqSmKuo8G+Vg84p7gTmSDiELMouAv53cW7Ii5Hmkpjrqu4LYs1FJRIwAHyfbQf4R4HtjbYNRF5KuBX4B/IWk9Wl7yLpqPFJzfFMGx1Mm+6bGkrVsPBtVe+mJ9Jsm+z7KEBGLJ/seyhIRPyPbK7fyIoKXazob5ZaNWcWU0bKR9IG0wHE0bXLertyYC10lHSLpnnT+eknTO9XpYGNWMSXtZ/MgcBpwZ7sCHRa6XgIsi4hDgS1Ax264g41ZhURJK4gj4pGI6LTt7s6FrhHxEnAdsCANuB8PrEzl8uQK95iNWZWs59mbz40b8u5iOWO89LsFaLfQdW/g2TSp0ji/S7qXVg42ZhUSEYU9rzZeru+IGC8DZk842JjV1Hi5vnNqt9B1MzBT0lBq3eRaAOsxGzNrZ+dC1zTbtAhYlbb0vZ1su1/onCsccLAxm5IkvU/SeuAdwI8l3ZzOv0nSTdBxoet5wLmS1pGN4VzRsU7v
QWxmZXDLxsxK4WBjZqVwsDGzUjjYmFkpHGzMrBQONmZWCgcbMyvF/wPZqxE8c4CsDAAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "cond_coeff_mat = conditional_corrcoeff(\n", - " density=posterior,\n", - " condition=condition,\n", - " limits=torch.tensor([[-2., 2.]]*3),\n", - ")\n", - "fig, ax = plt.subplots(1,1, figsize=(4,4))\n", - "im = plt.imshow(cond_coeff_mat, clim=[-1, 1], cmap='PiYG')\n", - "_ = fig.colorbar(im)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "So far, we have investigated the conditional distribution only at a specific `condition` sampled from the posterior. In many applications, it makes sense to repeat the above analyses with a different `condition` (another sample from the posterior), which can be interpreted as slicing the posterior at a different location. Note that `conditional_corrcoeff()` can directly compute the matrix for several `conditions` and then outputs the average over them. This can be done by passing a batch of $N$ conditions as the `condition` argument." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Sampling conditional distributions\n", - "\n", - "So far, we have demonstrated how one can plot 2D conditional distributions with `conditional_pairplot()` and how one can compute the pairwise conditional correlation coefficient with `conditional_corrcoeff()`. In some cases, it can be useful to keep a subset of parameters fixed and to vary **more than two** parameters. This can be done by sampling the conditonal posterior $p(\\theta_i | \\theta_{j \\neq i}, x_o)$. As of `sbi` `v0.18.0`, this functionality requires using the [sampler interface](https://www.mackelab.org/sbi/tutorial/11_sampler_interface/). In this tutorial, we demonstrate this functionality on a linear gaussian simulator with four parameters. We would like to fix the forth parameter to $\\theta_4=0.2$ and sample the first three parameters given that value, i.e. 
we want to sample $p(\\theta_1, \\theta_2, \\theta_3 | \\theta_4 = 0.2, x_o)$. For an application in neuroscience, see [Deistler, Gonçalves, Macke, 2021](https://www.biorxiv.org/content/10.1101/2021.07.30.454484v4.abstract)." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In this tutorial, we will use SNPE, but the same also works for SNLE and SNRE. First, we define the prior and the simulator and train the deep neural density estimator:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "4683b28ba0614e87b33854d207e30da2", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Running 1000 simulations.: 0%| | 0/1000 [00:00" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "from sbi.analysis import pairplot\n", - "\n", - "_ = pairplot(cond_samples, limits=[[-2, 2], [-2, 2], [-2, 2], [-2, 2]], figsize=(4, 4))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.5" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Analysing variability and compensation mechanisms with conditional distributions\n", + "\n", + "A central advantage of `sbi` over parameter search methods such as genetic algorithms is that the posterior captures **all** models that can reproduce experimental data. This allows us to analyse whether parameters can be variable or have to be narrowly tuned, and to analyse compensation mechanisms between different parameters. 
See also [Marder and Taylor, 2011](https://www.nature.com/articles/nn.2735?page=2) for further motivation to identify all models that capture experimental data. \n", + "\n", + "In this tutorial, we will show how one can use the posterior distribution to identify whether parameters can be variable or have to be finely tuned, and how we can use the posterior to find potential compensation mechanisms between model parameters. To investigate this, we will extract **conditional distributions** from the posterior inferred with `sbi`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note, you can find the original version of this notebook at [https://github.com/mackelab/sbi/blob/main/tutorials/07_conditional_distributions.ipynb](https://github.com/mackelab/sbi/blob/main/tutorials/07_conditional_distributions.ipynb) in the `sbi` repository." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Main syntax" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sbi.analysis import conditional_pairplot, conditional_corrcoeff\n", + "\n", + "# Plot slices through posterior, i.e. conditionals.\n", + "_ = conditional_pairplot(\n", + " density=posterior,\n", + " condition=posterior.sample((1,)),\n", + " limits=torch.tensor([[-2., 2.], [-2., 2.]]),\n", + ")\n", + "\n", + "# Compute the matrix of correlation coefficients of the slices.\n", + "cond_coeff_mat = conditional_corrcoeff(\n", + " density=posterior,\n", + " condition=posterior.sample((1,)),\n", + " limits=torch.tensor([[-2., 2.], [-2., 2.]]),\n", + ")\n", + "plt.imshow(cond_coeff_mat, clim=[-1, 1])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Analysing variability and compensation mechanisms in a toy example\n", + "Below, we use a simple toy example to demonstrate the features described above. 
For an application of these features to a neuroscience problem, see figure 6 in [Gonçalves, Lueckmann, Deistler et al., 2019](https://arxiv.org/abs/1907.00770)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from sbi import utils as utils\n", + "from sbi.analysis import pairplot, conditional_pairplot, conditional_corrcoeff\n", + "import torch\n", + "import numpy as np\n", + "\n", + "import matplotlib.pyplot as plt\n", + "from mpl_toolkits.mplot3d import Axes3D\n", + "from matplotlib import animation, rc\n", + "from IPython.display import HTML, Image\n", + "\n", + "_ = torch.manual_seed(0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's say we have used SNPE to obtain a posterior distribution over three parameters. In this tutorial, we just load the posterior from a file:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from toy_posterior_for_07_cc import ExamplePosterior\n", + "posterior = ExamplePosterior()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we specify the experimental observation $x_o$ at which we want to evaluate and sample the posterior $p(\\theta|x_o)$:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "x_o = torch.ones(1, 20) # simulator output was 20-dimensional\n", + "posterior.set_default_x(x_o)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As always, we can inspect the posterior marginals with the `pairplot()` function:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "posterior_samples = posterior.sample((5000,))\n", + "\n", + "fig, ax = pairplot(\n", + " samples=posterior_samples,\n", + " limits=torch.tensor([[-2., 2.]]*3),\n", + " upper=['kde'],\n", + " diag=['kde'],\n", + " figsize=(5,5)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The 1D and 2D marginals of the posterior fill almost the entire parameter space! Also, the Pearson correlation coefficient matrix of the marginal shows rather weak interactions (low correlations):" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAARsAAADxCAYAAAD7hRNxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWYklEQVR4nO3df6xcZZ3H8fen97Y0u6wWKCmI2EJokBpWkAYwRGUBtRBDu4rabpTiQroY3F1FDSKJbHBJym4iiz+xgQooCygi1ohhkR+LRmEpbKH8WOQKKq2VSvmhhFK4vd/94zxThunMnXM7Z849c+7nZU46c84z85yJuV+eX+f5KiIwM+u3aZN9A2Y2NTjYmFkpHGzMrBQONmZWCgcbMyuFg42ZlcLBxqymJK2WtFnSgx2uS9KXJY1IekDS25quLZf0WDqWF3E/DjZm9XUFsGic6ycC89OxAvgGgKQ9gfOBo4AjgfMl7dHrzTjYmNVURNwJPDNOkcXAVZG5C5glaV/gvcAtEfFMRDwL3ML4QSuX4V6/wMyK85f7z4ztL43lKrvt6VceAl5qOrUqIlZNoLr9gCeb3m9I5zqd74mDjVmFjL00xgHvn52r7P+t2vRSRCzs8y0Vxt0osyoRTJumXEcBNgL7N71/YzrX6XxPHGzMKkbKdxRgDXBqmpU6Gng+IjYBNwPvkbRHGhh+TzrXE3ejzCpEwLSCmgCSrgGOBWZL2kA2wzQdICIuBW4CTgJGgBeBj6Vrz0j6InBP+qoLImK8geZcHGzMqkQwNFxMsyUilnW5HsBZHa6tBlYXciOJg41ZxaimgxsONmYVIsG0ggZkqsbBxqxi3LIxs1IUNUBcNQ42ZhUiuWVjZiUZGvKYjZn1meRulJmVQqiYRxEqx8HGrErcsjGzsniA2Mz6TvIAcVtp+8DrgHnAb4APpZ29WsttB9ant7+LiJN7qdeszurajer1Z30OuDUi5gO3pvftbI2Iw9LhQGPWgQBNU65j0PQabBYDV6bXVwJLevw+s6ktDRDnOQZNr2M2c9JmOwB/AOZ0KDdT0lpgFFgZETe2KyRpBdku72hYR8yYNXWGlA7+iwMn+xZKs23W1sm+hVI9+sDvno6IvfOWr+lzmN2DjaSfAvu0uXRe85uICEnR4WvmRsRGSQcCt0laHxG/bi2UNmteBTBz7xkxb0m+vVjr4OYjLp/sWyjNyOK2aYxq6x1vOPO3ectmm2fVM9p0DTYRcUKna5Kek
rRvRGxKKSA2d/iOjenfxyXdARwO7BRszKa8AjfPqppee35rgEa2vOXAD1sLpH1Md0uvZwPHAA/3WK9ZLQkxTfmOQdNrsFkJvFvSY8AJ6T2SFkq6LJU5BFgr6X7gdrIxGwcbs3YKzq4gaZGkR1OK3Z1miyVdLGldOn4l6bmma9ubrq3p9af1NAIbEVuA49ucXwuckV7/Aji0l3rMpooix2wkDQFfA95NlmjuHklrmv9jHxGfair/j2RDHA1bI+KwQm4Gp3Ixq5xpmpbryOFIYCQiHo+Il4FryZardLIMuKaAn9CWg41ZlShfFypn6yd3Gl1Jc4EDgNuaTs+UtFbSXZKW7OIv2mHqLGQxGwAChodytwFmp/VrDRPN9d1sKXB9RGxvOpdryUpeDjZmFZJtnpU72DzdJdf3RNLoLqUlh1TRS1bcjTKrlEK7UfcA8yUdIGkGWUDZaVZJ0puBPYBfNp0rfMmKWzZmVVJg3qiIGJX0CbI83UPA6oh4SNIFwNqIaASepcC1KUNmwyHANyWNkTVKel6y4mBjViFFP64QETeR5fRuPveFlvf/0uZzhS9ZcbAxqxKJoaGhyb6LvnCwMauQKf0gppmVy8HGzPpOIu/q4IHjYGNWKfkfshw0DjZmFTOI20fk4WBjViESDA97NsrM+qyxeVYdOdiYVYk8G2VmJXGwMbO+E576NrMyyFPfZlYCIYaHpk/2bfRFIe21HDu47ybpunT9bknziqjXrG4EDGko1zFoeg42TTu4nwgsAJZJWtBS7HTg2Yg4CLgYuKjXes1qSWLatKFcx6ApomWTZwf3xcCV6fX1wPFSTRcTmPVomoZyHYOmiDGbdju4H9WpTNo97HlgL+DpAuo3qw2hiexBPFAqNUAsaQWwAmB498GL3Ga9ksT0oRmTfRt9UUSwybODe6PMBknDwOuBLa1flNJQrAKYufeMaL1uVn+q7TqbIn5Vnh3c1wDL0+tTgNtaNlc2MxqpXIobIM4xU3yapD825fQ+o+nackmPpWN562cnqueWTc4d3C8Hvi1pBHiGLCCZ2U5U2LR2nlzfyXUR8YmWz+4JnA8sBAK4N3322V29n0LGbLrt4B4RLwEfLKIuszor+HGFHTPFAJIaM8V5UrK8F7glIp5Jn70FWEQPucDr2Tk0G1gTWmczO+XibhwrWr4sb67vD0h6QNL1khrjr7nzhOdVqdkos6lugrNR3dLv5vEj4JqI2CbpH8jWwx3X43e25ZaNWYU0ulF5jhy6zhRHxJaI2JbeXgYckfezE+VgY1YlxT6u0HWmWNK+TW9PBh5Jr28G3pNyfu8BvCed22XuRplVigp7FCHnTPE/SToZGCWbKT4tffYZSV8kC1gAFzQGi3eVg41ZhWQZMYvrcOSYKT4XOLfDZ1cDq4u6Fwcbs0opbp1N1TjYmFWIJIb9bJSZ9Zv3IDazckgMDeDGWHk42JhVSNaycbAxs76r7xYTDjZmFSLE8DQPEJtZv0nI3SgzK4PHbMys74SYhoONmZXALRsz6zsV+CBm1TjYmFWKGJJno8yszySvszGzktS1G1VICO0lN42ZNZNzfXfSS24aM3st4UV94+klN42ZtfA6m87a5Zc5qk25D0h6J/Ar4FMR8WRrgZT3ZgXAfnvO4Y4jvlfA7Q2GY++dOjn8PnLwoZN9C5UlFftslKRFwCVkexBfFhErW66fDZxBtgfxH4G/j4jfpmvbgfWp6O8i4uRe7qWsYe8fAfMi4q+BW8hy0+wkIlZFxMKIWLjX7rNKujWzKiluzKZpiONEYAGwTNKClmL/CyxMf5vXA//WdG1rRByWjp4CDRQTbHrJTWNmr5GN2eQ5ctgxxBERLwONIY4dIuL2iHgxvb2L7O+3L4oINr3kpjGzJiIbs8lzUFz63YbTgZ80vZ+ZvvcuSUt6/W09j9n0kpvGzFpNaFFfEel3s1qljwALgXc1nZ4bERslHQjcJml9RPx6V+soZFFfL7lpzOxVBQ8Q50qhK
+kE4DzgXU3DHUTExvTv45LuAA4HdjnY1HNdtNkAE0O5jhzyDHEcDnwTODkiNjed30PSbun1bOAYelzO4scVzCqkyKe+cw5x/DuwO/A9SfDqFPchwDcljZE1Sla2Wag7IQ42ZpVS7BYTOYY4TujwuV8AhS6IcrAxqxjVdHTDwcascjTZN9AXDjZmFeI9iM2sRO5GmVkJ5G6UmfWfkLcFNbNyuGVjZn3nAWIzK427UWbWZ8IDxGZWCnkFsZmVxS0bMyuBWzZmVgLl3atm4DjYmFVINkDslo2ZlcCzUWbWfxL4cQUzK0NdWzaFhFBJqyVtlvRgh+uS9GVJI5IekPS2Iuo1q59snU2eI9e3SYskPZr+9j7X5vpukq5L1++WNK/p2rnp/KOS3tvrLyuqvXYFsGic6ycC89OxAvhGQfWa1U5R2RVypt89HXg2Ig4CLgYuSp9dQJaN4S1kf9tfV840nJ0UEmwi4k6y5HOdLAauisxdwKyWLJlmxquPK+T5Xw5d0++m91em19cDxytLs7AYuDYitkXEE8BI+r5dVtZIVK40oJJWNFKJbnnhuZJuzaxKNIGjkPS7O8pExCjwPLBXzs9OSKUGiCNiFbAK4K1z3xyTfDtm5Yt05FNY+t0ylNWyyZUG1MwCRb4jhzx/dzvKSBoGXg9syfnZCSkr2KwBTk2zUkcDz0fEppLqNhssYzmP7rqm303vl6fXpwC3RUSk80vTbNUBZJM7/9PDryqmGyXpGuBYsj7kBuB8YDpARFxKlpHvJLJBpheBjxVRr1kt5Wu15PiaXOl3Lwe+LWmEbJJnafrsQ5K+S5bfexQ4KyK293I/hQSbiFjW5XoAZxVRl1mtBajA0coc6XdfAj7Y4bMXAhcWdS+VGiA2MyYyQDxQHGzMqqagblTVONiYVU09Y42DjVmlBHmntQeOg41Z1dQz1jjYmFWOg42ZlcLdKDMrQ5HrbKrEwcasSib2IOZAcbAxq5SAsXpGGwcbswoR9e1G1XMbdzOrHLdszKrGs1Fm1nceIDazssgDxGZWinrGGgcbs0oJPPVtZmUIoqYDxJ76Nqua4jY870jSnpJukfRY+nePNmUOk/RLSQ+ltNkfbrp2haQnJK1Lx2Hd6nSwMauQCIixyHX06HPArRExH7g1vW/1InBqRDRS8P6HpFlN1z8bEYelY123Ct2NMquY2N5jsyWfxWQZUSBLv3sHcM5r7iPiV02vfy9pM7A38NyuVFhIy0bSakmbJT3Y4fqxkp5vanJ9oV05symvMUCc5+iefnc8c5pyt/0BmDNeYUlHAjOAXzedvjB1ry6WtFu3Cotq2VwBfBW4apwyP4uI9xVUn1lNTWiAeNz0u5J+CuzT5tJ5r6kxIqTOT2RJ2hf4NrA8IhrNrnPJgtQMspTZ5wAXjHezReWNulPSvCK+q2HbrK2MLG7bUKqljxx86GTfQmm+85/rJ/sWqq2gXlREnNDpmqSnJO0bEZtSMNncodzrgB8D50XEXU3f3WgVbZP0LeAz3e6nzAHit0u6X9JPJL2lXQFJKxpNwue2vFDirZlVR0TkOnrUnHZ3OfDD1gIpZe8PgKsi4vqWa/umfwUsAbq2DMoKNvcBcyPircBXgBvbFYqIVRGxMCIWztpr95JuzaxCJjZm04uVwLslPQackN4jaaGky1KZDwHvBE5rM8V9taT1wHpgNvCv3SosZTYqIv7U9PomSV+XNDsini6jfrNBUsZsVERsAY5vc34tcEZ6/R3gOx0+f9xE6ywl2EjaB3gqDUQdSdai2lJG3WaDJKKQNTSVVEiwkXQN2Zz9bEkbgPOB6QARcSlwCvBxSaPAVmBp1HVNtlmvSllmU76iZqOWdbn+VbKpcTProq7/HfYKYrMq8VPfZlaWkh5XKJ2DjVmVBB6zMbMyeDbKzMriAWIz67u0n00dOdiYVY2DjZn1W0R4NsrMyuFgY2b95zEbMyuHu1FmVoYAxhxszKzPIoKxV7ZP9m30hYONWcW4G
2Vm/VfjbpQzYppVSr5smL3OWOVJv5vKbW/af3hN0/kDJN0taUTSdWlz9HE52JhVSWTdqDxHj/Kk3wXY2pRi9+Sm8xcBF0fEQcCzwOndKnSwMauQAGJsLNfRo8VkaXdJ/y7J+8GUvuU4oJHeJdfnPWZjViURRP7ZqNmS1ja9XxURq3J+Nm/63ZmpjlFgZUTcCOwFPBcRo6nMBmC/bhU62JhVSUxoNqqM9LtzI2KjpAOB21KuqOfz3mCznoONpP3JcnzPIWsFroqIS1rKCLgEOAl4ETgtIu7rtW6z+okiukjZNxWQfjciNqZ/H5d0B3A48H1glqTh1Lp5I7Cx2/0UMWYzCnw6IhYARwNnSVrQUuZEYH46VgDfKKBes/oJYHvkO3qTJ/3uHpJ2S69nA8cAD6c0TLeTpWjq+PlWPQebiNjUaKVExJ+BR9i5/7aYLF9wpOTksxq5gs3stUoaIM6TfvcQYK2k+8mCy8qIeDhdOwc4W9II2RjO5d0qLHTMRtI8smbW3S2X9gOebHrfGFDahJntUNZ+NjnT7/4COLTD5x8HjpxInYUFG0m7k/XlPtmc23uC37GCrJvFnP32LOrWzAZHMJHZqIFSyDobSdPJAs3VEXFDmyIbgf2b3rcdUIqIVRGxMCIWztpr9yJuzWzARFndqNL1HGzSTNPlwCMR8aUOxdYApypzNPB80xy/mTWUt4K4dEV0o44BPgqsl7Qunfs88CaAiLgUuIls2nuEbOr7YwXUa1ZDxU19V03PwSYifg6oS5kAzuq1LrPaa0x915BXEJtVSEQw9vJo94IDyMHGrEom9rjCQHGwMauSgBh1sDGzvnN2BTMrQbhlY2aliHCwMbMSBIxtq+fjCg42ZlVS0oOYk8HBxqxCPGZjZuXwmI2ZlcXdKDPrP3ejzKwMEcHYtno+G+UkdWZVksZs8hy9yJN+V9LfNKXeXSfpJUlL0rUrJD3RdO2wbnU62JhVSYXS70bE7Y3Uu2QZMF8E/qupyGebUvOu61ahg41ZlaQxm363bJh4+t1TgJ9ExIu7WqGDjVmllNONIn/63YalwDUt5y6U9ICkixv5pcbjAWKzComxCT2uMG6u74LS75JyvB0K3Nx0+lyyIDUDWEWWR+qC8W7WwcasUib0uMK4ub6LSL+bfAj4QUS80vTdjVbRNknfAj7T7WbdjTKrkvLGbLqm322yjJYuVCOjbcqusgR4sFuFbtmYVUl5jyusBL4r6XTgt2StFyQtBM6MiDPS+3lkOd/+u+XzV0vamyzZwTrgzG4V9hxsJO0PXEU2wBRk/cZLWsocSxY5n0inboiIcft3ZlNRlLQHcZ70u+n9b8hSZbeWO26idRbRshkFPh0R90n6K+BeSbc0JSBv+FlEvK+A+sxqzc9GdZAGijal13+W9AhZJGwNNmbWRUQwOubNs7pK/bvDgbvbXH67pPuB3wOfiYiH2nx+BbAivX3hHW8489Ei7y+n2cDTk1DvZJlKv3eyfuvciRTeHm7ZjEvS7sD3gU9GxJ9aLt8HzI2IFySdBNwIzG/9jrRGYFXr+TJJWjvedGLdTKXfOwi/NQjGahpsCpn6ljSdLNBcHRE3tF6PiD9FxAvp9U3AdEmzi6jbrG7GInIdg6aI2SgBlwOPRMSXOpTZB3gqrVQ8kizIbem1brM6qmvLpohu1DHAR4H1ktalc58H3gQQEZeSPcT1cUmjwFZgaURlQ/OkduMmwVT6vZX/rRH17Uapun/zZlPPQcNz4kuv+3Cusouf/cq9VR+DauYVxGaVEp6NMrP+CxjIwd88/CBmE0mLJD0qaUTSTjuX1Ymk1ZI2S+r6AN2gk7S/pNslPSzpIUn/PNn31FFkA8R5jkHjYJNIGgK+BpwILACWSVowuXfVV1cAiyb7JkrSeKRmAXA0cFZ1/7+N2gYbd6NedSQwEhGPA0i6lmzrxFo+dhERd6YV37U3SI/UBPhxhSlgP+DJpvcbgKMm6V6sT7o8UjPpwgPEZoOvy
yM11RBe1DcVbCTbJKjhjemc1UC3R2qqos6zUQ42r7oHmC/pALIgsxT4u8m9JStCnkdqqqO+K4g9G5VExCjwCbId5B8BvttuG4y6kHQN8EvgYEkb0vaQddV4pOa4pgyOJ032TbWTtWw8G1V76Yn0myb7PsoQEcsm+x7KEhE/J9srt/IigldqOhvllo1ZxZTRspH0wbTAcSxtct6pXNuFrpIOkHR3On+dpBnd6nSwMauYkvazeRB4P3BnpwJdFrpeBFwcEQcBzwJdu+EONmYVEiWtII6IRyKi27a7Oxa6RsTLwLXA4jTgfhxwfSqXJ1e4x2zMqmQDz918dtyQdxfLmeOl3y1Ap4WuewHPpUmVxvmd0r20crAxq5CIKOx5tfFyfUfEeBkw+8LBxqymxsv1nVOnha5bgFmShlPrJtcCWI/ZmFknOxa6ptmmpcCatKXv7WTb/UL3XOGAg43ZlCTpbyVtAN4O/FjSzen8GyTdBF0Xup4DnC1phGwM5/KudXoPYjMrg1s2ZlYKBxszK4WDjZmVwsHGzErhYGNmpXCwMbNSONiYWSn+H4jRLO0e+4NOAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "corr_matrix_marginal = np.corrcoef(posterior_samples.T)\n", + "fig, ax = plt.subplots(1,1, figsize=(4, 4))\n", + "im = plt.imshow(corr_matrix_marginal, clim=[-1, 1], cmap='PiYG')\n", + "_ = fig.colorbar(im)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It might be tempting to conclude that the experimental data barely constrains our parameters and that almost all parameter combinations can reproduce the experimental data. As we will show below, this is not the case." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Because our toy posterior has only three parameters, we can plot posterior samples in a 3D plot:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "rc('animation', html='html5')\n", + "\n", + "# First set up the figure, the axis, and the plot element we want to animate\n", + "fig = plt.figure(figsize=(6,6))\n", + "ax = fig.add_subplot(111, projection='3d')\n", + "\n", + "ax.set_xlim((-2, 2))\n", + "ax.set_ylim((-2, 2))\n", + "\n", + "def init():\n", + " line, = ax.plot([], [], lw=2)\n", + " line.set_data([], [])\n", + " return (line,)\n", + "\n", + "def animate(angle):\n", + " num_samples_vis = 1000\n", + " line = ax.scatter(posterior_samples[:num_samples_vis, 0], posterior_samples[:num_samples_vis, 1], posterior_samples[:num_samples_vis, 2], zdir='z', s=15, c='#2171b5', depthshade=False)\n", + " ax.view_init(20, angle)\n", + " return (line,)\n", + "\n", + "anim = animation.FuncAnimation(fig, animate, init_func=init,\n", + " frames=range(0,360,5), interval=150, blit=True)\n", + "\n", + "plt.close()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + 
"output_type": "execute_result" + } + ], + "source": [ + "HTML(anim.to_html5_video())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Clearly, the range of admissible parameters is constrained to a narrow region in parameter space, which had not been evident from the marginals." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If the posterior has more than three dimensions, inspecting all dimensions at once will not be possible anymore. One way to still reveal structures in high-dimensional posteriors is to inspect 2D-slices through the posterior. In `sbi`, this can be done with the `conditional_pairplot()` function, which computes the conditional distributions within the posterior. We can slice (i.e. condition) the posterior at any location, given by the `condition`. In the plot below, for all upper diagonal plots, we keep all but two parameters constant at values sampled from the posterior, and inspect what combinations of the remaining two parameters can reproduce experimental data. For the plots on the diagonal (the 1D conditionals), we keep all but one parameter constant." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "condition = posterior.sample((1,))\n", + "\n", + "_ = conditional_pairplot(\n", + " density=posterior,\n", + " condition=condition,\n", + " limits=torch.tensor([[-2., 2.]]*3),\n", + " figsize=(5,5)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This plot looks completely different from the marginals obtained with `pairplot()`. As can be seen in the diagonal plots, if all parameters but one are kept constant, the remaining parameter has to be tuned to a narrow region in parameter space. In addition, the upper diagonal plots show strong correlations: deviations in one parameter can be compensated through changes in another parameter." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can summarize these correlations in a conditional correlation matrix, which computes the Pearson correlation coefficient of each of these pairwise plots. 
This matrix (below) shows strong correlations between many parameters, which can be interpreted as potential compensation mechansims:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAARsAAADxCAYAAAD7hRNxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWV0lEQVR4nO3df6xcZZ3H8ffn3tvSdVGLFAEBKYZmlxp2QRvAkCgLiIUYWhG13awUF3JXg7urqEEkkRWXWHYTu/gTG6iAEsCtiDViWOTHglFYKlsoPxapoNJaqZRSJIXC7f3uH+eZMkzv3Dm3c+bcM+d+XuakM+c8M88ZoV+eX+f5KiIwM+u1gcm+ATObGhxszKwUDjZmVgoHGzMrhYONmZXCwcbMSuFgY1ZTklZI2iTpwTbXJekrktZJekDS25quLZH0WDqWFHE/DjZm9XUlMH+c6ycDc9IxDHwTQNIbgAuBo4GjgAsl7dXtzTjYmNVURNwJPDNOkQXA1ZG5G5gpaX/gPcAtEfFMRGwBbmH8oJXLULdfYGbF+fODZsSOF0dzld3+9MsPAS82nVoeEcsnUN0BwJNN79enc+3Od8XBxqxCRl8c5ZDTZuUq+3/LN74YEfN6fEuFcTfKrEoEAwPKdRRgA3BQ0/sD07l257viYGNWMVK+owCrgDPSrNQxwNaI2AjcDJwkaa80MHxSOtcVd6PMKkTAQEFNAEnXAscBsyStJ5thmgYQEZcBNwGnAOuAbcBH0rVnJH0RuDd91UURMd5Acy4ONmZVIhgcKqbZEhGLO1wP4Jw211YAKwq5kcTBxqxiVNPBDQcbswqRYKCgAZmqcbAxqxi3bMysFEUNEFeNg41ZhUhu2ZhZSQYHPWZjZj0muRtlZqUQKuZRhMpxsDGrErdszKwsHiA2s56TPEA8prR94PXAbOA3wAfTzl6t5XYAa9Pb30XEqd3Ua1Znde1GdfuzPgvcGhFzgFvT+7G8EBFHpMOBxqwNARpQrqPfdBtsFgBXpddXAQu7/D6zqS0NEOc5+k23Yzb7ps12AP4A7Num3AxJq4ERYGlE3DhWIUnDZLu8oyG9ffrMqTOkdOCWmZN9C6V56o3PTvYtlOr5jS8/HRH75C1f0+cwOwcbST8F9hvj0gXNbyIiJEWbrzk4IjZIegtwm6S1EfHr1kJps+blADP2mR6zF+bbi7UOvrTytMm+hdIsG1412bdQqru+8ORv85bNNs+qZ7TpGGwi4sR21yQ9JWn/iNiYUkBsavMdG9Kfj0u6AzgS2CXYmE15BW6eVTXd9vxWAY1seUuAH7YWSPuY7pFezwKOBR7usl6zWhJiQPmOftNtsFkKvFvSY8CJ6T2S5km6PJU5DFgt6X7gdrIxGwcbs7EUnF1B0nxJj6YUu7vMFktaJmlNOn4l6dmmazuarnXd9+1qBDYiNgMnjHF+NXB2ev1z4PBu6jGbKoocs5E0CHwdeDdZorl7Ja1q/o99RHyyqfw/kg1xNLwQEUcUcjM4lYtZ5QxoINeRw1HAuoh4PCJeAq4jW67SzmLg2gJ+wpgcbMyqRPm6UDlbP7nT6Eo6GDgEuK3p9AxJqyXdLWnhbv6inabOQhazPiBgaDB3G2BWWr/WMNFc380WASsjYkfTuVxLVvJysDGrkGzzrNzB5ukOub4nkkZ3ES05pIpesuJulFmlFNqNuheYI+kQSdPJAsous0qS/hLYC/hF07nCl6y4ZWNWJQXmjYqIEUkfJ8vT
PQisiIiHJF0ErI6IRuBZBFyXMmQ2HAZ8S9IoWaOk6yUrDjZmFVL04woRcRNZTu/mc59vef8vY3yu8CUrDjZmVSIxODg42XfREw42ZhUypR/ENLNyOdiYWc9J5F0d3HccbMwqJf9Dlv3GwcasYvpx+4g8HGzMKkSCoSHPRplZjzU2z6ojBxuzKpFno8ysJA42ZtZzwlPfZlYGeerbzEogxNDgtMm+jZ4opL2WYwf3PSRdn67fI2l2EfWa1Y2AQQ3mOvpN18GmaQf3k4G5wGJJc1uKnQVsiYhDgWXAJd3Wa1ZLEgMDg7mOflNEyybPDu4LgKvS65XACVJNFxOYdWlAg7mOflPEmM1YO7gf3a5M2j1sK7A38HQB9ZvVhtBE9iDuK5UaIJY0DAwDDO3Zf5HbrFuSmDY4fbJvoyeKCDZ5dnBvlFkvaQh4PbC59YtSGorlADP2mR6t183qT7VdZ1PEr8qzg/sqYEl6fTpwW8vmymZGI5VLcQPEOWaKz5T0x6ac3mc3XVsi6bF0LGn97ER13bLJuYP7FcB3JK0DniELSGa2CxU2rZ0n13dyfUR8vOWzbwAuBOYBAfwyfXbL7t5PIWM2nXZwj4gXgQ8UUZdZnRX8uMLOmWIASY2Z4jwpWd4D3BIRz6TP3gLMp4tc4PXsHJr1rQmts5mVcnE3juGWL8ub6/v9kh6QtFJSY/w1d57wvCo1G2U21U1wNqpT+t08fgRcGxHbJf0D2Xq447v8zjG5ZWNWIY1uVJ4jh44zxRGxOSK2p7eXA2/P+9mJcrAxq5JiH1foOFMsaf+mt6cCj6TXNwMnpZzfewEnpXO7zd0os0pRYY8i5Jwp/idJpwIjZDPFZ6bPPiPpi2QBC+CixmDx7nKwMauQLCNmcR2OHDPF5wPnt/nsCmBFUffiYGNWKcWts6kaBxuzCpHEkJ+NMrNe8x7EZlYOicE+3BgrDwcbswrJWjYONmbWc/XdYsLBxqxChBga8ACxmfWahNyNMrMyeMzGzHpOiAEcbMysBG7ZmFnPqcAHMavGwcasUsSgPBtlZj0meZ2NmZWkrt2oQkJoN7lpzKyZnOu7nW5y05jZqwkv6htPN7lpzKyF19m0N1Z+maPHKPd+Se8EfgV8MiKebC2Q8t4MA+yjPfnSytMKuL3+cP7pN0z2LZTmNVvrOQBaBKnYZ6MkzQcuJduD+PKIWNpy/VzgbLI9iP8I/H1E/DZd2wGsTUV/FxGndnMvZf1T/xEwOyL+CriFLDfNLiJieUTMi4h5rxv4s5JuzaxKihuzaRriOBmYCyyWNLel2P8C89LfzZXAvzVdeyEijkhHV4EGigk23eSmMbNXycZs8hw57BziiIiXgMYQx04RcXtEbEtv7yb7+9sTRQSbbnLTmFkTkY3Z5DkoLv1uw1nAT5rez0jfe7ekhd3+tq7HbLrJTWNmrSa0qK+I9LtZrdLfAfOAdzWdPjgiNkh6C3CbpLUR8evdraOQRX3d5KYxs1cUPECcK4WupBOBC4B3NQ13EBEb0p+PS7oDOBLY7WDjaQGzihGDuY4c8gxxHAl8Czg1IjY1nd9L0h7p9SzgWLpczuLHFcwqpMinvnMOcfw7sCfwn5LglSnuw4BvSRola5QsHWOh7oQ42JhVSrFbTOQY4jixzed+Dhxe2I3gYGNWOarp6IaDjVnlaLJvoCccbMwqxHsQm1mJ3I0ysxLI3Sgz6z0hbwtqZuVwy8bMes4DxGZWGnejzKzHhAeIzawU8gpiMyuLWzZmVgK3bMysBMq7V03fcbAxq5BsgNgtGzMrgWejzKz3JPDjCmZWhrq2bAoJoZJWSNok6cE21yXpK5LWSXpA0tuKqNesfrJ1NnmOXN8mzZf0aPq799kxru8h6fp0/R5Js5uunZ/OPyrpPd3+sqLaa1cC88e5fjIwJx3DwDcLqtesdorKrpAz/e5ZwJaIOBRYBlySPjuXLBvDW8n+bn9DOdNwtlNIsImIO8mSz7WzALg6MncDM1uy
ZJoZrzyukOd/OXRMv5veX5VerwROUJZmYQFwXURsj4gngHXp+3ZbWSNRudKAShpupBJ9bvSFkm7NrEo0gaOQ9Ls7y0TECLAV2DvnZyekUgPEEbEcWA5w6NAbY5Jvx6x8kY58Cku/W4ayWja50oCaWaDId+SQ5+/dzjKShoDXA5tzfnZCygo2q4Az0qzUMcDWiNhYUt1m/WU059FZx/S76f2S9Pp04LaIiHR+UZqtOoRscud/uvhVxXSjJF0LHEfWh1wPXAhMA4iIy8gy8p1CNsi0DfhIEfWa1VK+VkuOr8mVfvcK4DuS1pFN8ixKn31I0vfI8nuPAOdExI5u7qeQYBMRiztcD+CcIuoyq7UAFThamSP97ovAB9p89mLg4qLupVIDxGbGRAaI+4qDjVnVFNSNqhoHG7OqqWescbAxq5Qg77R233GwMauaesYaBxuzynGwMbNSuBtlZmUocp1NlTjYmFXJxB7E7CsONmaVEjBaz2jjYGNWIaK+3ah6buNuZpXjlo1Z1Xg2ysx6zgPEZlYWeYDYzEpRz1jjYGNWKYGnvs2sDEHUdIDYU99mVVPchudtSXqDpFskPZb+3GuMMkdI+oWkh1La7A81XbtS0hOS1qTjiE51OtiYVUgExGjkOrr0WeDWiJgD3Jret9oGnBERjRS8/yFpZtP1z0TEEelY06lCd6PMKiZ2dNlsyWcBWUYUyNLv3gGc96r7iPhV0+vfS9oE7AM8uzsVFtKykbRC0iZJD7a5fpykrU1Nrs+PVc5symsMEOc5OqffHc++Tbnb/gDsO15hSUcB04FfN52+OHWvlknao1OFRbVsrgS+Blw9Tpm7IuK9BdVnVlMTGiAeN/2upJ8C+41x6YJX1RgRUvsnsiTtD3wHWBIRjWbX+WRBajpZyuzzgIvGu9mi8kbdKWl2Ed/V8NQbn2XZcGvyvvp6zdapM3y27flSugn9q6D/eyLixHbXJD0laf+I2JiCyaY25V4H/Bi4ICLubvruRqtou6RvA5/udD9l/hv+Dkn3S/qJpLeOVUDScKNJ+PI2/wtpU1NE5Dq61Jx2dwnww9YCKWXvD4CrI2Jly7X9058CFgJjDqE0KyvY3AccHBF/DXwVuHGsQhGxPCLmRcS8aa+ZOv+lN9tpYmM23VgKvFvSY8CJ6T2S5km6PJX5IPBO4MwxprivkbQWWAvMAv61U4WlzEZFxHNNr2+S9A1JsyLi6TLqN+snZcxGRcRm4IQxzq8Gzk6vvwt8t83nj59onaUEG0n7AU+lgaijyFpUm8uo26yfRBSyhqaSCgk2kq4lm7OfJWk9cCEwDSAiLgNOBz4maQR4AVgUdV2Tbdatmg5XFjUbtbjD9a+RTY2bWQd1/e+wVxCbVYmf+jazspT0uELpHGzMqiTwmI2ZlcGzUWZWFg8Qm1nPpf1s6sjBxqxqHGzMrNciwrNRZlYOBxsz6z2P2ZhZOdyNMrMyBDDqYGNmPRYRjL68Y7JvoyccbMwqxt0oM+u9GnejvNGvWaXky4bZ7YxVnvS7qdyOpv2HVzWdP0TSPZLWSbo+bY4+LgcbsyqJrBuV5+hSnvS7AC80pdg9ten8JcCyiDgU2AKc1alCBxuzCgkgRkdzHV1aQJZ2l/TnwrwfTOlbjgca6V1yfd5jNmZVEkHkn42aJWl10/vlEbE852fzpt+dkeoYAZZGxI3A3sCzETGSyqwHDuhUoYONWZXEhGajyki/e3BEbJD0FuC2lCtqa94bbNZ1sJF0EFmO733JWoHLI+LSljICLgVOAbYBZ0bEfd3WbVY/UUQXKfumAtLvRsSG9Ofjku4AjgS+D8yUNJRaNwcCGzrdTxFjNiPApyJiLnAMcI6kuS1lTgbmpGMY+GYB9ZrVTwA7It/RnTzpd/eStEd6PQs4Fng4pWG6nSxFU9vPt+o62ETExkYrJSL+BDzCrv23BWT5giMlJ5/ZyBVsZq9W0gBxnvS7hwGrJd1PFlyWRsTD6dp5wLmS1pGN4VzR
qcJCx2wkzSZrZt3TcukA4Mmm940BpY2Y2U5l7WeTM/3uz4HD23z+ceCoidRZWLCRtCdZX+4Tzbm9J/gdw2TdLPZ4/WBRt2bWP4KJzEb1laLS704jCzTXRMQNYxTZABzU9H7MAaU0bbcc4LVvml7PTT3MxlXcAHHVdD1mk2aargAeiYgvtym2CjhDmWOArU1z/GbWUN4K4tIV0bI5FvgwsFbSmnTuc8CbASLiMuAmsmnvdWRT3x8poF6zGqpvy6brYBMRPwPUoUwA53Rbl1ntNaa+a8griM0qJCIYfWmkc8E+5GBjViUTe1yhrzjYmFVJQIw42JhZzzm7gpmVINyyMbNSRDjYmFkJAka3+3EFM+u1kh7EnAwONmYV4jEbMyuHx2zMrCzuRplZ77kbZWZliAhGt9fz2SgnqTOrkjRmk+foRp70u5L+pin17hpJL0pamK5dKemJpmtHdKrTwcasSiqUfjcibm+k3iXLgLkN+K+mIp9pSs27plOFDjZmVZLGbHrdsmHi6XdPB34SEdt2t0IHG7NKKacbRf70uw2LgGtbzl0s6QFJyxr5pcbjAWKzConRCT2uMG6u74LS75JyvB0O3Nx0+nyyIDWdLEnBecBF492sg41ZpUzocYVxc30XkX43+SDwg4h4uem7G62i7ZK+DXy60826G2VWJeWN2XRMv9tkMS1dqEZG25RdZSHwYKcK3bIxq5LyHldYCnxP0lnAb8laL0iaB3w0Is5O72eT5Xz775bPXyNpH7JkB2uAj3aqsOtgI+kg4GqyAaYg6zde2lLmOLLI+UQ6dUNEjNu/M5uKoqQ9iPOk303vf0OWKru13PETrbOIls0I8KmIuE/Sa4FfSrqlKQF5w10R8d4C6jOrNT8b1UYaKNqYXv9J0iNkkbA12JhZBxHByKg3z+oo9e+OBO4Z4/I7JN0P/B74dEQ8NMbnh4Hh9Pb5u77w5KNF3l9Os4CnJ6HeyTKVfu9k/daDJ1J4R7hlMy5JewLfBz4REc+1XL4PODginpd0CnAjMKf1O9IageWt58skafV404l1M5V+bz/81iAYrWmwKWTqW9I0skBzTUTc0Ho9Ip6LiOfT65uAaZJmFVG3Wd2MRuQ6+k0Rs1ECrgAeiYgvtymzH/BUWql4FFmQ29xt3WZ1VNeWTRHdqGOBDwNrJa1J5z4HvBkgIi4je4jrY5JGgBeARRGVDc2T2o2bBFPp91b+t0bUtxul6v6dN5t6Dh3aN778ug/lKrtgy1d/WfUxqGZeQWxWKeHZKDPrvYC+HPzNww9iNpE0X9KjktZJ2mXnsjqRtELSJkkdH6Drd5IOknS7pIclPSTpnyf7ntqKbIA4z9FvHGwSSYPA14GTgbnAYklzJ/eueupKYP5k30RJGo/UzAWOAc6p7j/bqG2wcTfqFUcB6yLicQBJ15FtnVjLxy4i4s604rv2+umRmgA/rjAFHAA82fR+PXD0JN2L9UiHR2omXXiA2Kz/dXikphrCi/qmgg1kmwQ1HJjOWQ10eqSmKuo8G+Vg84p7gTmSDiELMouAv53cW7Ii5Hmkpjrqu4LYs1FJRIwAHyfbQf4R4HtjbYNRF5KuBX4B/IWk9Wl7yLpqPFJzfFMGx1Mm+6bGkrVsPBtVe+mJ9Jsm+z7KEBGLJ/seyhIRPyPbK7fyIoKXazob5ZaNWcWU0bKR9IG0wHE0bXLertyYC10lHSLpnnT+eknTO9XpYGNWMSXtZ/MgcBpwZ7sCHRa6XgIsi4hDgS1Ax264g41ZhURJK4gj4pGI6LTt7s6FrhHxEnAdsCANuB8PrEzl8uQK95iNWZWs59mbz40b8u5iOWO89LsFaLfQdW/g2TSp0ji/S7qXVg42ZhUSEYU9rzZeru+IGC8DZk842JjV1Hi5vnNqt9B1MzBT0lBq3eRaAOsxGzNrZ+dC1zTbtAhYlbb0vZ1su1/onCsccLAxm5IkvU/SeuAdwI8l3ZzOv0nSTdBxoet5wLmS1pGN4VzRsU7v
QWxmZXDLxsxK4WBjZqVwsDGzUjjYmFkpHGzMrBQONmZWCgcbMyvF/wPZqxE8c4CsDAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "cond_coeff_mat = conditional_corrcoeff(\n", + " density=posterior,\n", + " condition=condition,\n", + " limits=torch.tensor([[-2., 2.]]*3),\n", + ")\n", + "fig, ax = plt.subplots(1,1, figsize=(4,4))\n", + "im = plt.imshow(cond_coeff_mat, clim=[-1, 1], cmap='PiYG')\n", + "_ = fig.colorbar(im)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "So far, we have investigated the conditional distribution only at a specific `condition` sampled from the posterior. In many applications, it makes sense to repeat the above analyses with a different `condition` (another sample from the posterior), which can be interpreted as slicing the posterior at a different location. Note that `conditional_corrcoeff()` can directly compute the matrix for several `conditions` and then outputs the average over them. This can be done by passing a batch of $N$ conditions as the `condition` argument." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sampling conditional distributions\n", + "\n", + "So far, we have demonstrated how one can plot 2D conditional distributions with `conditional_pairplot()` and how one can compute the pairwise conditional correlation coefficient with `conditional_corrcoeff()`. In some cases, it can be useful to keep a subset of parameters fixed and to vary **more than two** parameters. This can be done by sampling the conditional posterior $p(\\theta_i | \\theta_{j \\neq i}, x_o)$. As of `sbi` `v0.18.0`, this functionality requires using the [sampler interface](https://www.mackelab.org/sbi/tutorial/11_sampler_interface/). In this tutorial, we demonstrate this functionality on a linear gaussian simulator with four parameters. We would like to fix the fourth parameter to $\\theta_4=0.2$ and sample the first three parameters given that value, i.e. 
we want to sample $p(\\theta_1, \\theta_2, \\theta_3 | \\theta_4 = 0.2, x_o)$. For an application in neuroscience, see [Deistler, Gonçalves, Macke, 2021](https://www.biorxiv.org/content/10.1101/2021.07.30.454484v4.abstract)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this tutorial, we will use SNPE, but the same also works for SNLE and SNRE. First, we define the prior and the simulator and train the deep neural density estimator:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4683b28ba0614e87b33854d207e30da2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Running 1000 simulations.: 0%| | 0/1000 [00:00" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "from sbi.analysis import pairplot\n", + "\n", + "_ = pairplot(cond_samples, limits=[[-2, 2], [-2, 2], [-2, 2], [-2, 2]], figsize=(4, 4))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From d14fec3e06ec283d17cc0cf3343d2fe05257a585 Mon Sep 17 00:00:00 2001 From: tbmiller-astro Date: Fri, 22 Apr 2022 10:46:15 -0400 Subject: [PATCH 10/15] Ran isort, made sure _round is updated correctly --- sbi/inference/snle/snle_base.py | 8 ++++---- sbi/inference/snpe/snpe_base.py | 7 +++++-- sbi/inference/snre/snre_base.py | 8 ++++---- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/sbi/inference/snle/snle_base.py b/sbi/inference/snle/snle_base.py index 33aefa8c4..2560c4e51 100644 --- 
a/sbi/inference/snle/snle_base.py +++ b/sbi/inference/snle/snle_base.py @@ -21,8 +21,8 @@ from sbi.utils import ( check_estimator_arg, check_prior, - mask_sims_from_prior, handle_invalid_x, + mask_sims_from_prior, validate_theta_and_x, warn_if_zscoring_changes_data, warn_on_invalid_x, @@ -182,11 +182,11 @@ def train( Returns: Density estimator that has learned the distribution $p(x|\theta)$. """ - - # Starting index for the training set (1 = discard round-0 samples). - start_idx = int(discard_prior_samples and self._round > 0) # Load data from most recent round. self._round = max(self._data_round_index) + # Starting index for the training set (1 = discard round-0 samples). + start_idx = int(discard_prior_samples and self._round > 0) + train_loader, val_loader = self.get_dataloaders( start_idx, diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index 6d7b6ccef..7ba845f4f 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -26,13 +26,13 @@ from sbi.utils import ( RestrictedPrior, check_estimator_arg, + handle_invalid_x, test_posterior_net_for_multi_d_x, validate_theta_and_x, - x_shape_from_simulation, - handle_invalid_x, warn_if_zscoring_changes_data, warn_on_invalid_x, warn_on_invalid_x_for_snpec_leakage, + x_shape_from_simulation, ) from sbi.utils.sbiutils import ImproperEmpirical, mask_sims_from_prior @@ -249,6 +249,9 @@ def train( Returns: Density estimator that approximates the distribution $p(\theta|x)$. """ + # Load data from most recent round. + self._round = max(self._data_round_index) + if self._round == 0 and self._neural_net is not None: assert force_first_round_loss, ( "You have already trained this neural network. 
After you had trained " diff --git a/sbi/inference/snre/snre_base.py b/sbi/inference/snre/snre_base.py index 698a57612..df1ef022f 100644 --- a/sbi/inference/snre/snre_base.py +++ b/sbi/inference/snre/snre_base.py @@ -19,9 +19,9 @@ clamp_and_warn, handle_invalid_x, validate_theta_and_x, - x_shape_from_simulation, warn_if_zscoring_changes_data, warn_on_invalid_x, + x_shape_from_simulation, ) from sbi.utils.sbiutils import mask_sims_from_prior @@ -186,11 +186,11 @@ def train( Returns: Classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$. """ - + # Load data from most recent round. + self._round = max(self._data_round_index) # Starting index for the training set (1 = discard round-0 samples). start_idx = int(discard_prior_samples and self._round > 0) - # Load data from most recent round. - self._round = max(self._data_round_index) + train_loader, val_loader = self.get_dataloaders( start_idx, From 2c87aa2080fbbdcb6b8c8b0602561812aae9095a Mon Sep 17 00:00:00 2001 From: tbmiller-astro Date: Fri, 22 Apr 2022 12:29:38 -0400 Subject: [PATCH 11/15] Fixed bug with indicies --- sbi/inference/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sbi/inference/base.py b/sbi/inference/base.py index daae107f1..844cf09ec 100644 --- a/sbi/inference/base.py +++ b/sbi/inference/base.py @@ -258,13 +258,13 @@ def get_dataloaders( train_loader_kwargs = { "batch_size": min(training_batch_size, num_training_examples), "drop_last": True, - "sampler": SubsetRandomSampler(torch.arange(len(self.train_indices)).tolist() ), + "sampler": SubsetRandomSampler(self.train_indices.tolist() ), } val_loader_kwargs = { "batch_size": min(training_batch_size, num_validation_examples), "shuffle": False, "drop_last": True, - "sampler": SubsetRandomSampler(torch.arange(len(self.val_indices)).tolist() ), + "sampler": SubsetRandomSampler(self.val_indices.tolist() ), } if dataloader_kwargs is not None: train_loader_kwargs = dict(train_loader_kwargs, **dataloader_kwargs) From 
7eb4551782b5650fa3cc0c7b5bfcaeb379458064 Mon Sep 17 00:00:00 2001 From: tbmiller-astro Date: Mon, 25 Apr 2022 15:00:46 -0400 Subject: [PATCH 12/15] Changes to tests to match new syntax --- tests/base_test.py | 7 ++----- tests/inference_with_NaN_simulator_test.py | 4 +--- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/base_test.py b/tests/base_test.py index ca78fce97..e4cc4078c 100644 --- a/tests/base_test.py +++ b/tests/base_test.py @@ -1,6 +1,5 @@ import pytest import torch -from torch.utils.data import TensorDataset from sbi.inference import SNPE @@ -11,12 +10,10 @@ def test_get_dataloaders(training_batch_size): N = 1000 validation_fraction = 0.1 - dataset = TensorDataset(torch.ones(N), torch.zeros(N)) - inferer = SNPE() - + inferer.append_simulations(torch.ones(N), torch.zeros(N),warn_if_zscoring=False) _, val_loader = inferer.get_dataloaders( - dataset, + 0, training_batch_size=training_batch_size, validation_fraction=validation_fraction, ) diff --git a/tests/inference_with_NaN_simulator_test.py b/tests/inference_with_NaN_simulator_test.py index 368c2be1d..4380285d6 100644 --- a/tests/inference_with_NaN_simulator_test.py +++ b/tests/inference_with_NaN_simulator_test.py @@ -102,9 +102,7 @@ def linear_gaussian_nan( inference = method(prior=prior) theta, x = simulate_for_sbi(simulator, prior, num_simulations) - _ = inference.append_simulations(theta, x).train( - exclude_invalid_x=exclude_invalid_x - ) + _ = inference.append_simulations(theta, x, exclude_invalid_x=exclude_invalid_x).train() posterior = inference.build_posterior() samples = posterior.sample((num_samples,), x=x_o) From 980a95bea4de6732b19e4ed0c5d2a8b5268a2f9e Mon Sep 17 00:00:00 2001 From: tbmiller-astro Date: Fri, 27 May 2022 13:27:03 -0400 Subject: [PATCH 13/15] Now compatible with black and pyright --- sbi/examples/minimal.py | 3 +- sbi/inference/base.py | 46 ++++++++++-------- sbi/inference/snle/snle_base.py | 42 +++++++++------- sbi/inference/snpe/snpe_base.py | 56 
++++++++++++---------- sbi/inference/snre/snre_base.py | 50 ++++++++++--------- sbi/utils/sbiutils.py | 12 +++-- tests/base_test.py | 2 +- tests/inference_with_NaN_simulator_test.py | 4 +- 8 files changed, 123 insertions(+), 92 deletions(-) diff --git a/sbi/examples/minimal.py b/sbi/examples/minimal.py index 408d6d4e6..31864be94 100644 --- a/sbi/examples/minimal.py +++ b/sbi/examples/minimal.py @@ -37,7 +37,8 @@ def flexible(): inference = SNPE(prior) theta, x = simulate_for_sbi(simulator, proposal=prior, num_simulations=500) - density_estimator = inference.append_simulations(theta, x).train() + inference.append_simulations(theta, x) + density_estimator = inference.train() posterior = inference.build_posterior(density_estimator) posterior.sample((100,), x=x_o) diff --git a/sbi/inference/base.py b/sbi/inference/base.py index 844cf09ec..5f4ffafd4 100644 --- a/sbi/inference/base.py +++ b/sbi/inference/base.py @@ -128,14 +128,13 @@ def __init__( # Initialize roundwise (theta, x, prior_masks) for storage of parameters, # simulations and masks indicating if simulations came from prior. - self._dataset = None + self._dataset = data.Dataset() self._num_sims_per_round = [] self._model_bank = [] # Initialize list that indicates the round from which simulations were drawn. self._data_round_index = [] - self._round = 0 self._val_log_prob = float("-Inf") @@ -177,19 +176,21 @@ def get_simulations( Returns: Parameters, simulation outputs, prior masks. 
""" - #This is a pretty clunky implementation but not sure this will be used much with - #new implementation of `get_dataloaders` - indicies = get_simulations_indcies(self._num_sims_per_round, self._data_round_index, starting_round) - theta,x,prior_masks = [],[],[] + # This is a pretty clunky implementation but not sure this will be used much with + # new implementation of `get_dataloaders` + indicies = get_simulations_indcies( + self._num_sims_per_round, self._data_round_index, starting_round + ) + theta, x, prior_masks = [], [], [] for ind in indicies: - theta_cur,x_cur,prior_mask_cur = self._dataset[ind] + theta_cur, x_cur, prior_mask_cur = self._dataset[ind] theta.append(theta_cur) x.append(x_cur) prior_masks.append(prior_mask_cur) - - theta = torch.stack(theta).squeeze() - x = torch.stack(x).squeeze() + + theta = torch.stack(theta) + x = torch.stack(x) prior_masks = torch.stack(prior_masks).squeeze() return theta, x, prior_masks @@ -218,7 +219,7 @@ def get_dataloaders( validation_fraction: float = 0.1, resume_training: bool = False, dataloader_kwargs: Optional[dict] = None, - ) -> Tuple[data.DataLoader, data.DataLoader]: + ) -> Tuple[data.DataLoader, data.DataLoader]: """Return dataloaders for training and validation. Args: @@ -233,9 +234,11 @@ def get_dataloaders( Tuple of dataloaders for training and validation. """ - - #Generate indicies to use based on starting_round - indices = get_simulations_indcies(self._num_sims_per_round, self._data_round_index, starting_round) + + # Generate indicies to use based on starting_round + indices = get_simulations_indcies( + self._num_sims_per_round, self._data_round_index, starting_round + ) # Get total number of training examples. 
num_examples = len(indices) @@ -258,22 +261,22 @@ def get_dataloaders( train_loader_kwargs = { "batch_size": min(training_batch_size, num_training_examples), "drop_last": True, - "sampler": SubsetRandomSampler(self.train_indices.tolist() ), + "sampler": SubsetRandomSampler(self.train_indices.tolist()), } val_loader_kwargs = { "batch_size": min(training_batch_size, num_validation_examples), "shuffle": False, "drop_last": True, - "sampler": SubsetRandomSampler(self.val_indices.tolist() ), + "sampler": SubsetRandomSampler(self.val_indices.tolist()), } if dataloader_kwargs is not None: train_loader_kwargs = dict(train_loader_kwargs, **dataloader_kwargs) val_loader_kwargs = dict(val_loader_kwargs, **dataloader_kwargs) train_loader = data.DataLoader(self._dataset, **train_loader_kwargs) - val_loader = data.DataLoader(self._dataset, **val_loader_kwargs) + val_loader = data.DataLoader(self._dataset, **val_loader_kwargs) - return train_loader,val_loader + return train_loader, val_loader def _converged(self, epoch: int, stop_after_epochs: int) -> bool: """Return whether the training converged yet and save best model state so far. @@ -356,8 +359,11 @@ def _report_convergence_at_end( ) def _summarize( - self, round_: int, x_o: Union[Tensor, None], theta_bank: Union[Tensor,None] - , x_bank: Union[Tensor,None] + self, + round_: int, + x_o: Union[Tensor, None], + theta_bank: Union[Tensor, None], + x_bank: Union[Tensor, None], ) -> None: """Update the summary_writer with statistics for a given round. 
diff --git a/sbi/inference/snle/snle_base.py b/sbi/inference/snle/snle_base.py index 2560c4e51..436125aac 100644 --- a/sbi/inference/snle/snle_base.py +++ b/sbi/inference/snle/snle_base.py @@ -90,8 +90,8 @@ def append_simulations( warn_on_invalid: bool = True, warn_if_zscoring: bool = True, return_self: bool = True, - data_device: str = None, - ) -> "LikelihoodEstimator": + data_device: Optional[str] = None, + ) -> Union["LikelihoodEstimator", None]: r"""Store parameters and simulation outputs to use them for later training. Data are stored as entries in lists for each type of variable (parameter/data). @@ -131,20 +131,30 @@ def append_simulations( x = x[is_valid_x] theta = theta[is_valid_x] - if data_device is None: data_device = self._device + if data_device is None: + data_device = self._device theta, x = validate_theta_and_x(theta, x, training_device=data_device) prior_masks = mask_sims_from_prior(int(from_round), theta.size(0)) - if self._dataset is None: - #If first round, set up ConcatDataset - self._dataset = data.ConcatDataset( [data.TensorDataset(theta,x,prior_masks),] ) + if len(self._num_sims_per_round) == 0: + # If first round, set up ConcatDataset + self._dataset = data.ConcatDataset( + [ + data.TensorDataset(theta, x, prior_masks), + ] + ) else: - #Otherwise append to Dataset - self._dataset = data.ConcatDataset( self._dataset.datasets + [data.TensorDataset(theta,x,prior_masks),] ) + # Otherwise append to Dataset + self._dataset = data.ConcatDataset( + self._dataset.datasets + + [ + data.TensorDataset(theta, x, prior_masks), + ] + ) self._num_sims_per_round.append(theta.size(0)) - self._data_round_index.append(int(from_round) ) - + self._data_round_index.append(int(from_round)) + if return_self: return self @@ -186,7 +196,6 @@ def train( self._round = max(self._data_round_index) # Starting index for the training set (1 = discard round-0 samples). 
start_idx = int(discard_prior_samples and self._round > 0) - train_loader, val_loader = self.get_dataloaders( start_idx, @@ -202,15 +211,14 @@ def train( # This is passed into NeuralPosterior, to create a neural posterior which # can `sample()` and `log_prob()`. The network is accessible via `.net`. if self._neural_net is None or retrain_from_scratch: - - #Get theta,x from dataset to initialize NN - test_theta = self._dataset.datasets[0].tensors[0][:100] - test_x = self._dataset.datasets[0].tensors[1][:100] + # Get theta,x from dataset to initialize NN + theta, x, _ = self.get_simulations() self._neural_net = self._build_neural_net( - test_theta, test_x + theta[:training_batch_size].to("cpu"), x[:training_batch_size].to("cpu") ) - self._x_shape = x_shape_from_simulation(test_x) + self._x_shape = x_shape_from_simulation(x[:training_batch_size].to("cpu")) + del theta, x assert ( len(self._x_shape) < 3 ), "SNLE cannot handle multi-dimensional simulator output." diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index 7ba845f4f..44ffb19e7 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -96,8 +96,8 @@ def append_simulations( warn_on_invalid: bool = True, warn_if_zscoring: bool = True, return_self: bool = True, - data_device: str = None, - ) -> "PosteriorEstimator": + data_device: Optional[str] = None, + ) -> Union["PosteriorEstimator", None]: r"""Store parameters and simulation outputs to use them for later training. Data are stored as entries in lists for each type of variable (parameter/data). 
@@ -127,9 +127,6 @@ def append_simulations( """ # Add ability to specify device data is saved on - if data_device is None: data_device = self._device - theta, x = validate_theta_and_x(theta, x, training_device=data_device) - is_valid_x, num_nans, num_infs = handle_invalid_x(x, exclude_invalid_x) @@ -145,6 +142,9 @@ def append_simulations( x = x[is_valid_x] theta = theta[is_valid_x] + if data_device is None: + data_device = self._device + theta, x = validate_theta_and_x(theta, x, training_device=data_device) self._check_proposal(proposal) @@ -169,13 +169,21 @@ def append_simulations( self._data_round_index.append(max(self._data_round_index) + 1) prior_masks = mask_sims_from_prior(1, theta.size(0)) - - if self._dataset is None: - #If first round, set up ConcatDataset - self._dataset = data.ConcatDataset( [data.TensorDataset(theta,x,prior_masks),] ) + if len(self._num_sims_per_round) == 0: + # If first round, set up ConcatDataset + self._dataset = data.ConcatDataset( + [ + data.TensorDataset(theta, x, prior_masks), + ] + ) else: - #Otherwise append to Dataset - self._dataset = data.ConcatDataset( self._dataset.datasets + [data.TensorDataset(theta,x,prior_masks),] ) + # Otherwise append to Dataset + self._dataset = data.ConcatDataset( + self._dataset.datasets + + [ + data.TensorDataset(theta, x, prior_masks), + ] + ) self._num_sims_per_round.append(theta.size(0)) self._proposal_roundwise.append(proposal) @@ -194,7 +202,7 @@ def append_simulations( theta_prior = self.get_simulations()[0] self._prior = ImproperEmpirical(theta_prior, ones(theta_prior.shape[0])) - #Add ability to not return self + # Add ability to not return self if return_self: return self @@ -301,22 +309,20 @@ def train( # can `sample()` and `log_prob()`. The network is accessible via `.net`. 
if self._neural_net is None or retrain_from_scratch: - #Get theta,x from dataset to initialize NN - test_theta = self._dataset.datasets[0].tensors[0][:100] - test_x = self._dataset.datasets[0].tensors[1][:100] - + # Get theta,x from dataset to initialize NN + theta, x, _ = self.get_simulations() self._neural_net = self._build_neural_net( - test_theta, test_x + theta[:training_batch_size].to("cpu"), x[:training_batch_size].to("cpu") + ) + self._x_shape = x_shape_from_simulation(x[:training_batch_size].to("cpu")) + + test_posterior_net_for_multi_d_x( + self._neural_net, + theta[:training_batch_size].to("cpu"), + x[:training_batch_size].to("cpu"), ) - # If data on training device already move net as well. - if ( - not self._device == "cpu" - and f"{test_x.device.type}:{test_x.device.index}" == self._device - ): - self._neural_net.to(self._device) - test_posterior_net_for_multi_d_x(self._neural_net, test_theta, test_x) - self._x_shape = x_shape_from_simulation(test_x) + del theta, x # Move entire net to device for training. self._neural_net.to(self._device) diff --git a/sbi/inference/snre/snre_base.py b/sbi/inference/snre/snre_base.py index df1ef022f..df5204045 100644 --- a/sbi/inference/snre/snre_base.py +++ b/sbi/inference/snre/snre_base.py @@ -89,8 +89,8 @@ def append_simulations( warn_on_invalid: bool = True, warn_if_zscoring: bool = True, return_self: bool = True, - data_device: str = None, - ) -> "RatioEstimator": + data_device: Optional[str] = None, + ) -> Union["RatioEstimator", None]: r"""Store parameters and simulation outputs to use them for later training. Data are stored as entries in lists for each type of variable (parameter/data). @@ -119,8 +119,6 @@ def append_simulations( NeuralInference object (returned so that this function is chainable). 
""" - theta, x = validate_theta_and_x(theta, x, training_device=self._device) - is_valid_x, num_nans, num_infs = handle_invalid_x(x, exclude_invalid_x) # Check for problematic z-scoring @@ -132,19 +130,30 @@ def append_simulations( x = x[is_valid_x] theta = theta[is_valid_x] - if data_device is None: data_device = self._device - theta, x = validate_theta_and_x(theta, x, training_device=data_device) + if data_device is None: + data_device = self._device + theta, x = validate_theta_and_x(theta, x, training_device=self._device) + prior_masks = mask_sims_from_prior(int(from_round), theta.size(0)) - if self._dataset is None: - #If first round, set up ConcatDataset - self._dataset = data.ConcatDataset( [data.TensorDataset(theta,x,prior_masks),] ) + if len(self._num_sims_per_round) == 0: + # If first round, set up ConcatDataset + self._dataset = data.ConcatDataset( + [ + data.TensorDataset(theta, x, prior_masks), + ] + ) else: - #Otherwise append to Dataset - self._dataset = data.ConcatDataset( self._dataset.datasets + [data.TensorDataset(theta,x,prior_masks),] ) - + # Otherwise append to Dataset + self._dataset = data.ConcatDataset( + self._dataset.datasets + + [ + data.TensorDataset(theta, x, prior_masks), + ] + ) + self._num_sims_per_round.append(theta.size(0)) - self._data_round_index.append(int(from_round) ) + self._data_round_index.append(int(from_round)) if return_self: return self @@ -158,7 +167,6 @@ def train( stop_after_epochs: int = 20, max_num_epochs: int = 2**31 - 1, clip_max_norm: Optional[float] = 5.0, - exclude_invalid_x: bool = True, resume_training: bool = False, discard_prior_samples: bool = False, retrain_from_scratch: bool = False, @@ -187,10 +195,9 @@ def train( Classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$. """ # Load data from most recent round. - self._round = max(self._data_round_index) + self._round = max(self._data_round_index) # Starting index for the training set (1 = discard round-0 samples). 
start_idx = int(discard_prior_samples and self._round > 0) - train_loader, val_loader = self.get_dataloaders( start_idx, @@ -215,14 +222,13 @@ def train( # can `sample()` and `log_prob()`. The network is accessible via `.net`. if self._neural_net is None or retrain_from_scratch: - #Get theta,x from dataset to initialize NN - test_theta = self._dataset.datasets[0].tensors[0][:100] - test_x = self._dataset.datasets[0].tensors[1][:100] - + # Get theta,x from dataset to initialize NN + theta, x, _ = self.get_simulations() self._neural_net = self._build_neural_net( - test_theta, test_x + theta[:training_batch_size].to("cpu"), x[:training_batch_size].to("cpu") ) - self._x_shape = x_shape_from_simulation(test_x) + self._x_shape = x_shape_from_simulation(x[:training_batch_size].to("cpu")) + del x, theta self._neural_net.to(self._device) if not resume_training: diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index 4f26cf848..acfc3fdf4 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -348,6 +348,7 @@ def get_simulations_since_round( [t for t, r in zip(data, data_round_indices) if r >= starting_round_index] ) + def get_simulations_indcies( num_sims_per_round: List, data_round_indices: List, starting_round_index: int ) -> Tensor: @@ -363,15 +364,16 @@ def get_simulations_indcies( counting from 0. """ inds = [] - for j, (n,r) in enumerate(zip(num_sims_per_round,data_round_indices)): - - #Where to start counting + for j, (n, r) in enumerate(zip(num_sims_per_round, data_round_indices)): + + # Where to start counting s_ind = sum(num_sims_per_round[:j]) - + if r >= starting_round_index: - inds.append(torch.arange(s_ind,s_ind + n) ) + inds.append(torch.arange(s_ind, s_ind + n)) return torch.cat(inds) + def mask_sims_from_prior(round_: int, num_simulations: int) -> Tensor: """Returns Tensor True where simulated from prior parameters. 
diff --git a/tests/base_test.py b/tests/base_test.py index e4cc4078c..b41c47728 100644 --- a/tests/base_test.py +++ b/tests/base_test.py @@ -11,7 +11,7 @@ def test_get_dataloaders(training_batch_size): validation_fraction = 0.1 inferer = SNPE() - inferer.append_simulations(torch.ones(N), torch.zeros(N),warn_if_zscoring=False) + inferer.append_simulations(torch.ones(N), torch.zeros(N), warn_if_zscoring=False) _, val_loader = inferer.get_dataloaders( 0, training_batch_size=training_batch_size, diff --git a/tests/inference_with_NaN_simulator_test.py b/tests/inference_with_NaN_simulator_test.py index 4380285d6..a9d5378cc 100644 --- a/tests/inference_with_NaN_simulator_test.py +++ b/tests/inference_with_NaN_simulator_test.py @@ -102,7 +102,9 @@ def linear_gaussian_nan( inference = method(prior=prior) theta, x = simulate_for_sbi(simulator, prior, num_simulations) - _ = inference.append_simulations(theta, x, exclude_invalid_x=exclude_invalid_x).train() + _ = inference.append_simulations( + theta, x, exclude_invalid_x=exclude_invalid_x + ).train() posterior = inference.build_posterior() samples = posterior.sample((num_samples,), x=x_o) From 4198a82ad3ca23b509feca23b877a29fc672191d Mon Sep 17 00:00:00 2001 From: tbmiller-astro Date: Thu, 23 Jun 2022 11:02:46 -0400 Subject: [PATCH 14/15] Reverted to storing data as tensors in list --- examples/HH_helper_functions.py | 2 +- sbi/examples/minimal.py | 3 +- sbi/inference/base.py | 55 ++++++++++--------------- sbi/inference/snle/snle_base.py | 51 +++++++---------------- sbi/inference/snpe/snpe_base.py | 72 ++++++++++----------------------- sbi/inference/snre/snre_base.py | 51 +++++++---------------- sbi/utils/sbiutils.py | 25 ------------ sbi/utils/user_input_checks.py | 22 +++++----- tests/base_test.py | 2 +- 9 files changed, 85 insertions(+), 198 deletions(-) diff --git a/examples/HH_helper_functions.py b/examples/HH_helper_functions.py index 38cb91b06..ecced908b 100644 --- a/examples/HH_helper_functions.py +++ 
b/examples/HH_helper_functions.py @@ -145,7 +145,7 @@ def tau_p(x): + g_leak * E_leak + gbar_M * p[i - 1] * E_K + I[i - 1] - + nois_fact * rng.randn() / (tstep ** 0.5) + + nois_fact * rng.randn() / (tstep**0.5) ) / (tau_V_inv * C) V[i] = V_inf + (V[i - 1] - V_inf) * np.exp(-tstep * tau_V_inv) n[i] = n_inf(V[i]) + (n[i - 1] - n_inf(V[i])) * np.exp(-tstep / tau_n(V[i])) diff --git a/sbi/examples/minimal.py b/sbi/examples/minimal.py index 31864be94..408d6d4e6 100644 --- a/sbi/examples/minimal.py +++ b/sbi/examples/minimal.py @@ -37,8 +37,7 @@ def flexible(): inference = SNPE(prior) theta, x = simulate_for_sbi(simulator, proposal=prior, num_simulations=500) - inference.append_simulations(theta, x) - density_estimator = inference.train() + density_estimator = inference.append_simulations(theta, x).train() posterior = inference.build_posterior(density_estimator) posterior.sample((100,), x=x_o) diff --git a/sbi/inference/base.py b/sbi/inference/base.py index 5f4ffafd4..880b12ef1 100644 --- a/sbi/inference/base.py +++ b/sbi/inference/base.py @@ -18,15 +18,8 @@ import sbi.inference from sbi.inference.posteriors.base_posterior import NeuralPosterior from sbi.simulators.simutils import simulate_in_batches -from sbi.utils import ( - check_prior, - get_log_root, - handle_invalid_x, - warn_if_zscoring_changes_data, - warn_on_invalid_x, - warn_on_invalid_x_for_snpec_leakage, -) -from sbi.utils.sbiutils import get_simulations_indcies +from sbi.utils import check_prior, get_log_root +from sbi.utils.sbiutils import get_simulations_since_round from sbi.utils.torchutils import check_if_prior_on_device, process_device from sbi.utils.user_input_checks import prepare_for_sbi @@ -128,8 +121,9 @@ def __init__( # Initialize roundwise (theta, x, prior_masks) for storage of parameters, # simulations and masks indicating if simulations came from prior. 
- self._dataset = data.Dataset() - self._num_sims_per_round = [] + self._theta_roundwise = [] + self._x_roundwise = [] + self._prior_masks = [] self._model_bank = [] # Initialize list that indicates the round from which simulations were drawn. @@ -176,22 +170,15 @@ def get_simulations( Returns: Parameters, simulation outputs, prior masks. """ - # This is a pretty clunky implementation but not sure this will be used much with - # new implementation of `get_dataloaders` - indicies = get_simulations_indcies( - self._num_sims_per_round, self._data_round_index, starting_round + theta = get_simulations_since_round( + self._theta_roundwise, self._data_round_index, starting_round + ) + x = get_simulations_since_round( + self._x_roundwise, self._data_round_index, starting_round + ) + prior_masks = get_simulations_since_round( + self._prior_masks, self._data_round_index, starting_round ) - theta, x, prior_masks = [], [], [] - - for ind in indicies: - theta_cur, x_cur, prior_mask_cur = self._dataset[ind] - theta.append(theta_cur) - x.append(x_cur) - prior_masks.append(prior_mask_cur) - - theta = torch.stack(theta) - x = torch.stack(x) - prior_masks = torch.stack(prior_masks).squeeze() return theta, x, prior_masks @@ -235,20 +222,20 @@ def get_dataloaders( """ - # Generate indicies to use based on starting_round - indices = get_simulations_indcies( - self._num_sims_per_round, self._data_round_index, starting_round - ) + # + theta, x, prior_masks = self.get_simulations(starting_round) + + dataset = data.TensorDataset(theta, x, prior_masks) # Get total number of training examples. - num_examples = len(indices) + num_examples = theta.size(0) # Select random train and validation splits from (theta, x) pairs. 
num_training_examples = int((1 - validation_fraction) * num_examples) num_validation_examples = num_examples - num_training_examples if not resume_training: # Seperate indicies for training and validation - permuted_indices = indices[torch.randperm(num_examples)] + permuted_indices = torch.randperm(num_examples) self.train_indices, self.val_indices = ( permuted_indices[:num_training_examples], permuted_indices[num_training_examples:], @@ -273,8 +260,8 @@ def get_dataloaders( train_loader_kwargs = dict(train_loader_kwargs, **dataloader_kwargs) val_loader_kwargs = dict(val_loader_kwargs, **dataloader_kwargs) - train_loader = data.DataLoader(self._dataset, **train_loader_kwargs) - val_loader = data.DataLoader(self._dataset, **val_loader_kwargs) + train_loader = data.DataLoader(dataset, **train_loader_kwargs) + val_loader = data.DataLoader(dataset, **val_loader_kwargs) return train_loader, val_loader diff --git a/sbi/inference/snle/snle_base.py b/sbi/inference/snle/snle_base.py index 436125aac..626690b51 100644 --- a/sbi/inference/snle/snle_base.py +++ b/sbi/inference/snle/snle_base.py @@ -86,12 +86,8 @@ def append_simulations( theta: Tensor, x: Tensor, from_round: int = 0, - exclude_invalid_x: bool = True, - warn_on_invalid: bool = True, - warn_if_zscoring: bool = True, - return_self: bool = True, data_device: Optional[str] = None, - ) -> Union["LikelihoodEstimator", None]: + ) -> "LikelihoodEstimator": r"""Store parameters and simulation outputs to use them for later training. Data are stored as entries in lists for each type of variable (parameter/data). @@ -107,12 +103,6 @@ def append_simulations( With default settings, this is not used at all for `SNLE`. Only when the user later on requests `.train(discard_prior_samples=True)`, we use these indices to find which training data stemmed from the prior. - exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞` - during training. Expect errors, silent or explicit, when `False`. 
- warn_on_invalid: Whether to warn if data is invalid - warn_if_zscoring: Whether to test if z-scoring causes duplicates - return_self: Whether to return a instance of the class, allows chaining - with `.train()`. Setting `False` decreases memory overhead. data_device: Where to store the data, default is on the same device where the training is happening. If training a large dataset on a GPU with not much VRAM can set to 'cpu' to store data on system memory instead. @@ -120,43 +110,30 @@ def append_simulations( NeuralInference object (returned so that this function is chainable). """ - is_valid_x, num_nans, num_infs = handle_invalid_x(x, exclude_invalid_x) - - # Check for problematic z-scoring - if warn_if_zscoring: - warn_if_zscoring_changes_data(x[is_valid_x]) - if warn_on_invalid: - warn_on_invalid_x(num_nans, num_infs, exclude_invalid_x) + is_valid_x, num_nans, num_infs = handle_invalid_x(x, True) # Hardcode to True x = x[is_valid_x] theta = theta[is_valid_x] + # Check for problematic z-scoring + warn_if_zscoring_changes_data(x) + warn_on_invalid_x(num_nans, num_infs, True) + if data_device is None: data_device = self._device - theta, x = validate_theta_and_x(theta, x, training_device=data_device) + theta, x = validate_theta_and_x( + theta, x, data_device=data_device, training_device=self._device + ) + prior_masks = mask_sims_from_prior(int(from_round), theta.size(0)) - if len(self._num_sims_per_round) == 0: - # If first round, set up ConcatDataset - self._dataset = data.ConcatDataset( - [ - data.TensorDataset(theta, x, prior_masks), - ] - ) - else: - # Otherwise append to Dataset - self._dataset = data.ConcatDataset( - self._dataset.datasets - + [ - data.TensorDataset(theta, x, prior_masks), - ] - ) + self._theta_roundwise.append(theta) + self._x_roundwise.append(x) + self._prior_masks.append(prior_masks) - self._num_sims_per_round.append(theta.size(0)) self._data_round_index.append(int(from_round)) - if return_self: - return self + return self def train( 
self, diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index b0c90730a..e2c45c0dc 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -92,12 +92,8 @@ def append_simulations( theta: Tensor, x: Tensor, proposal: Optional[DirectPosterior] = None, - exclude_invalid_x: bool = True, - warn_on_invalid: bool = True, - warn_if_zscoring: bool = True, - return_self: bool = True, data_device: Optional[str] = None, - ) -> Union["PosteriorEstimator", None]: + ) -> "PosteriorEstimator": r"""Store parameters and simulation outputs to use them for later training. Data are stored as entries in lists for each type of variable (parameter/data). @@ -112,12 +108,6 @@ def append_simulations( proposal: The distribution that the parameters $\theta$ were sampled from. Pass `None` if the parameters were sampled from the prior. If not `None`, it will trigger a different loss-function. - exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞` - during training. Expect errors, silent or explicit, when `False`. - warn_on_invalid: Whether to warn if data is invalid - warn_if_zscoring: Whether to test if z-scoring causes duplicates - return_self: Whether to return a instance of the class, allows chaining - with `.train()`. Setting `False` decreases memory overhead. data_device: Where to store the data, default is on the same device where the training is happening. If training a large dataset on a GPU with not much VRAM can set to 'cpu' to store data on system memory instead. @@ -126,26 +116,24 @@ def append_simulations( NeuralInference object (returned so that this function is chainable). 
""" - # Add ability to specify device data is saved on - - is_valid_x, num_nans, num_infs = handle_invalid_x(x, exclude_invalid_x) - - # Check for problematic z-scoring - if warn_if_zscoring: - warn_if_zscoring_changes_data(x[is_valid_x]) - if warn_on_invalid: - warn_on_invalid_x(num_nans, num_infs, exclude_invalid_x) - warn_on_invalid_x_for_snpec_leakage( - num_nans, num_infs, exclude_invalid_x, type(self).__name__, self._round - ) + is_valid_x, num_nans, num_infs = handle_invalid_x(x, True) # Hardcode to True x = x[is_valid_x] theta = theta[is_valid_x] + # Check for problematic z-scoring + warn_if_zscoring_changes_data(x) + warn_on_invalid_x(num_nans, num_infs, True) + warn_on_invalid_x_for_snpec_leakage( + num_nans, num_infs, True, type(self).__name__, self._round + ) + if data_device is None: data_device = self._device - theta, x = validate_theta_and_x(theta, x, training_device=data_device) + theta, x = validate_theta_and_x( + theta, x, data_device=data_device, training_device=self._device + ) self._check_proposal(proposal) if ( @@ -169,23 +157,10 @@ def append_simulations( self._data_round_index.append(max(self._data_round_index) + 1) prior_masks = mask_sims_from_prior(1, theta.size(0)) - if len(self._num_sims_per_round) == 0: - # If first round, set up ConcatDataset - self._dataset = data.ConcatDataset( - [ - data.TensorDataset(theta, x, prior_masks), - ] - ) - else: - # Otherwise append to Dataset - self._dataset = data.ConcatDataset( - self._dataset.datasets - + [ - data.TensorDataset(theta, x, prior_masks), - ] - ) + self._theta_roundwise.append(theta) + self._x_roundwise.append(x) + self._prior_masks.append(prior_masks) - self._num_sims_per_round.append(theta.size(0)) self._proposal_roundwise.append(proposal) if self._prior is None or isinstance(self._prior, ImproperEmpirical): @@ -202,9 +177,7 @@ def append_simulations( theta_prior = self.get_simulations()[0] self._prior = ImproperEmpirical(theta_prior, ones(theta_prior.shape[0])) - # Add ability to not 
return self - if return_self: - return self + return self def train( self, @@ -310,16 +283,15 @@ def train( if self._neural_net is None or retrain_from_scratch: # Get theta,x from dataset to initialize NN - theta, x, _ = self.get_simulations() - self._neural_net = self._build_neural_net( - theta[:training_batch_size].to("cpu"), x[:training_batch_size].to("cpu") - ) - self._x_shape = x_shape_from_simulation(x[:training_batch_size].to("cpu")) + x = self._x_roundwise[0][:training_batch_size] + theta = self._theta_roundwise[0][:training_batch_size] + self._neural_net = self._build_neural_net(theta.to("cpu"), x.to("cpu")) + self._x_shape = x_shape_from_simulation(x.to("cpu")) test_posterior_net_for_multi_d_x( self._neural_net, - theta[:training_batch_size].to("cpu"), - x[:training_batch_size].to("cpu"), + theta.to("cpu"), + x.to("cpu"), ) del theta, x diff --git a/sbi/inference/snre/snre_base.py b/sbi/inference/snre/snre_base.py index df5204045..b50a1278d 100644 --- a/sbi/inference/snre/snre_base.py +++ b/sbi/inference/snre/snre_base.py @@ -85,12 +85,8 @@ def append_simulations( theta: Tensor, x: Tensor, from_round: int = 0, - exclude_invalid_x: bool = True, - warn_on_invalid: bool = True, - warn_if_zscoring: bool = True, - return_self: bool = True, data_device: Optional[str] = None, - ) -> Union["RatioEstimator", None]: + ) -> "RatioEstimator": r"""Store parameters and simulation outputs to use them for later training. Data are stored as entries in lists for each type of variable (parameter/data). @@ -106,12 +102,6 @@ def append_simulations( With default settings, this is not used at all for `SNRE`. Only when the user later on requests `.train(discard_prior_samples=True)`, we use these indices to find which training data stemmed from the prior. - exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞` - during training. Expect errors, silent or explicit, when `False`. 
- warn_on_invalid: Whether to warn if data is invalid - warn_if_zscoring: Whether to test if z-scoring causes duplicates - return_self: Whether to return a instance of the class, allows chaining - with `.train()`. Setting `False` decreases memory overhead. data_device: Where to store the data, default is on the same device where the training is happening. If training a large dataset on a GPU with not much VRAM can set to 'cpu' to store data on system memory instead. @@ -119,44 +109,31 @@ def append_simulations( NeuralInference object (returned so that this function is chainable). """ - is_valid_x, num_nans, num_infs = handle_invalid_x(x, exclude_invalid_x) - - # Check for problematic z-scoring - if warn_if_zscoring: - warn_if_zscoring_changes_data(x[is_valid_x]) - if warn_on_invalid: - warn_on_invalid_x(num_nans, num_infs, exclude_invalid_x) + is_valid_x, num_nans, num_infs = handle_invalid_x(x, True) # Hardcode to True x = x[is_valid_x] theta = theta[is_valid_x] + # Check for problematic z-scoring + warn_if_zscoring_changes_data(x) + warn_on_invalid_x(num_nans, num_infs, True) + if data_device is None: data_device = self._device - theta, x = validate_theta_and_x(theta, x, training_device=self._device) + + theta, x = validate_theta_and_x( + theta, x, data_device=data_device, training_device=self._device + ) prior_masks = mask_sims_from_prior(int(from_round), theta.size(0)) - if len(self._num_sims_per_round) == 0: - # If first round, set up ConcatDataset - self._dataset = data.ConcatDataset( - [ - data.TensorDataset(theta, x, prior_masks), - ] - ) - else: - # Otherwise append to Dataset - self._dataset = data.ConcatDataset( - self._dataset.datasets - + [ - data.TensorDataset(theta, x, prior_masks), - ] - ) + self._theta_roundwise.append(theta) + self._x_roundwise.append(x) + self._prior_masks.append(prior_masks) - self._num_sims_per_round.append(theta.size(0)) self._data_round_index.append(int(from_round)) - if return_self: - return self + return self def train( 
self, diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index acfc3fdf4..bde9d7375 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -349,31 +349,6 @@ def get_simulations_since_round( ) -def get_simulations_indcies( - num_sims_per_round: List, data_round_indices: List, starting_round_index: int -) -> Tensor: - """ - Returns indicies for for all simulations round >= `starting_round`. Used in - `get_dataloaders` and `get_simulations` - - Args: - num_sims_per_round: Number of simulations per round - data_round_indices: List with same length as data, each entry is an integer that - indicates which round the data is from. - starting_round_index: From which round onwards to return the data. We start - counting from 0. - """ - inds = [] - for j, (n, r) in enumerate(zip(num_sims_per_round, data_round_indices)): - - # Where to start counting - s_ind = sum(num_sims_per_round[:j]) - - if r >= starting_round_index: - inds.append(torch.arange(s_ind, s_ind + n)) - return torch.cat(inds) - - def mask_sims_from_prior(round_: int, num_simulations: int) -> Tensor: """Returns Tensor True where simulated from prior parameters. diff --git a/sbi/utils/user_input_checks.py b/sbi/utils/user_input_checks.py index 66f5a5ec7..443808c78 100644 --- a/sbi/utils/user_input_checks.py +++ b/sbi/utils/user_input_checks.py @@ -647,7 +647,7 @@ def check_estimator_arg(estimator: Union[str, Callable]) -> None: def validate_theta_and_x( - theta: Any, x: Any, training_device: str = "cpu" + theta: Any, x: Any, data_device: str = "cpu", training_device: str = "cpu" ) -> Tuple[Tensor, Tensor]: r""" Checks if the passed $(\theta, x)$ are valid. @@ -679,21 +679,21 @@ def validate_theta_and_x( assert theta.dtype == float32, "Type of parameters must be float32." assert x.dtype == float32, "Type of simulator outputs must be float32." 
- if str(x.device) != training_device: + if str(x.device) != data_device: warnings.warn( - f"Data x has device '{x.device}' " - f"different from the training_device '{training_device}', " - f"moving x to the training_device '{training_device}'." + f"Data x has device '{x.device}'." + f"Moving x to the data_device '{data_device}'." + f"Training will proceed on device '{training_device}'." ) - x = x.to(training_device) + x = x.to(data_device) - if str(theta.device) != training_device: + if str(theta.device) != data_device: warnings.warn( - f"Parameters theta has device '{theta.device}' " - f"different from the training_device '{training_device}', " - f"moving theta to the training_device '{training_device}'." + f"Parameters theta has device '{x.device}'. " + f"Moving theta to the data_device '{data_device}'." + f"Training will proceed on device '{training_device}'." ) - theta = theta.to(training_device) + theta = theta.to(data_device) return theta, x diff --git a/tests/base_test.py b/tests/base_test.py index b41c47728..c1cef1131 100644 --- a/tests/base_test.py +++ b/tests/base_test.py @@ -11,7 +11,7 @@ def test_get_dataloaders(training_batch_size): validation_fraction = 0.1 inferer = SNPE() - inferer.append_simulations(torch.ones(N), torch.zeros(N), warn_if_zscoring=False) + inferer.append_simulations(torch.ones(N), torch.zeros(N)) _, val_loader = inferer.get_dataloaders( 0, training_batch_size=training_batch_size, From 0cbc4fada376193fd3c0a89a540d3db9013e4360 Mon Sep 17 00:00:00 2001 From: tbmiller-astro Date: Mon, 27 Jun 2022 12:19:34 -0400 Subject: [PATCH 15/15] Use all data to initialize NN --- sbi/inference/snle/snle_base.py | 10 ++++------ sbi/inference/snpe/snpe_base.py | 5 ++--- sbi/inference/snre/snre_base.py | 10 ++++------ sbi/utils/user_input_checks.py | 5 +++++ 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/sbi/inference/snle/snle_base.py b/sbi/inference/snle/snle_base.py index 626690b51..1c04584b8 100644 --- 
a/sbi/inference/snle/snle_base.py +++ b/sbi/inference/snle/snle_base.py @@ -189,12 +189,10 @@ def train( # can `sample()` and `log_prob()`. The network is accessible via `.net`. if self._neural_net is None or retrain_from_scratch: - # Get theta,x from dataset to initialize NN - theta, x, _ = self.get_simulations() - self._neural_net = self._build_neural_net( - theta[:training_batch_size].to("cpu"), x[:training_batch_size].to("cpu") - ) - self._x_shape = x_shape_from_simulation(x[:training_batch_size].to("cpu")) + # Get theta,x to initialize NN + theta, x, _ = self.get_simulations(starting_round=start_idx) + self._neural_net = self._build_neural_net(theta.to("cpu"), x.to("cpu")) + self._x_shape = x_shape_from_simulation(x.to("cpu")) del theta, x assert ( len(self._x_shape) < 3 diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index e2c45c0dc..481b2de95 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -282,9 +282,8 @@ def train( # can `sample()` and `log_prob()`. The network is accessible via `.net`. if self._neural_net is None or retrain_from_scratch: - # Get theta,x from dataset to initialize NN - x = self._x_roundwise[0][:training_batch_size] - theta = self._theta_roundwise[0][:training_batch_size] + # Get theta,x to initialize NN + theta, x, _ = self.get_simulations(starting_round=start_idx) self._neural_net = self._build_neural_net(theta.to("cpu"), x.to("cpu")) self._x_shape = x_shape_from_simulation(x.to("cpu")) diff --git a/sbi/inference/snre/snre_base.py b/sbi/inference/snre/snre_base.py index b50a1278d..f448eedd4 100644 --- a/sbi/inference/snre/snre_base.py +++ b/sbi/inference/snre/snre_base.py @@ -199,12 +199,10 @@ def train( # can `sample()` and `log_prob()`. The network is accessible via `.net`. 
if self._neural_net is None or retrain_from_scratch: - # Get theta,x from dataset to initialize NN - theta, x, _ = self.get_simulations() - self._neural_net = self._build_neural_net( - theta[:training_batch_size].to("cpu"), x[:training_batch_size].to("cpu") - ) - self._x_shape = x_shape_from_simulation(x[:training_batch_size].to("cpu")) + # Get theta,x to initialize NN + theta, x, _ = self.get_simulations(starting_round=start_idx) + self._neural_net = self._build_neural_net(theta.to("cpu"), x.to("cpu")) + self._x_shape = x_shape_from_simulation(x.to("cpu")) del x, theta self._neural_net.to(self._device) diff --git a/sbi/utils/user_input_checks.py b/sbi/utils/user_input_checks.py index 443808c78..8bd46c4e2 100644 --- a/sbi/utils/user_input_checks.py +++ b/sbi/utils/user_input_checks.py @@ -657,6 +657,10 @@ def validate_theta_and_x( 2) If they have the same batchsize. 3) If they are of `dtype=float32`. + Additionally, we move the data to the specified `data_device`. This is where the + data is stored and can be separate from `training_device`, where the + computations for training are performed. + Raises: AssertionError: If theta or x are not torch.Tensor-like, do not yield the same batchsize and do not have dtype==float32. @@ -664,6 +668,7 @@ def validate_theta_and_x( Args: theta: Parameters. x: Simulation outputs. + data_device: Device where data is stored. training_device: Training device for net. """ assert isinstance(theta, Tensor), "Parameters theta must be a `torch.Tensor`."