From 93152e1413b98b7cf6436676b3c0d2a9d6579a67 Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Fri, 30 Apr 2021 16:41:11 +0200 Subject: [PATCH] Implementation of SNPE-A --- examples/sb_sbi.py | 179 +++++++ sbi/inference/__init__.py | 2 +- sbi/inference/snpe/snpe_a.py | 307 ++++++++++- sbi/neural_nets/mdn_snpe_a.py | 70 +++ sbi/neural_nets/mog_flow_snpe_a.py | 586 +++++++++++++++++++++ sbi/utils/get_nn_models.py | 41 +- tests/inference_with_NaN_simulator_test.py | 28 +- tests/linearGaussian_snpe_test.py | 42 +- tests/posterior_nn_test.py | 12 +- tests/user_input_checks_test.py | 41 +- 10 files changed, 1249 insertions(+), 59 deletions(-) create mode 100644 examples/sb_sbi.py create mode 100644 sbi/neural_nets/mdn_snpe_a.py create mode 100644 sbi/neural_nets/mog_flow_snpe_a.py diff --git a/examples/sb_sbi.py b/examples/sb_sbi.py new file mode 100644 index 000000000..37d3bea87 --- /dev/null +++ b/examples/sb_sbi.py @@ -0,0 +1,179 @@ +import math + +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import torch +from HH_helper_functions import HHsimulator, calculate_summary_statistics, syn_current +from torch.distributions.multivariate_normal import MultivariateNormal + +import sbi.utils as utils +from sbi.inference import simulate_for_sbi +from sbi.inference.snpe.snpe_a import SNPE_A +from sbi.inference.snpe.snpe_c import SNPE_C +from sbi.utils import pairplot, posterior_nn +from sbi.utils.user_input_checks import prepare_for_sbi + +I, t_on, t_off, dt, t, A_soma = syn_current() + + +def run_HH_model(params): + params = np.asarray(params) + + # input current, time step + I, t_on, t_off, dt, t, A_soma = syn_current() + + t = np.arange(0, len(I), 1) * dt + + # initial voltage + V0 = -70 + + states = HHsimulator(V0, params.reshape(1, -1), dt, t, I) + + return dict(data=states.reshape(-1), time=t, dt=dt, I=I.reshape(-1)) + + +def simulation_wrapper(params): + """ + Returns summary statistics from conductance values in `params`. + + Summarizes the output of the HH simulator and converts it to `torch.Tensor`. + """ + obs = run_HH_model(params) + summstats = torch.as_tensor(calculate_summary_statistics(obs)) + return summstats + + +if __name__ == "__main__": + # Configure. + torch.manual_seed(0) + num_sim = 300 + true_params = np.array([50.0, 5.0]) + labels_params = [r"$g_{Na}$", r"$g_{K}$"] + observation_trace = run_HH_model(true_params) + observation_summary_statistics = calculate_summary_statistics(observation_trace) + method = "SNPE_A" + num_rounds = 2 + num_components = 4 + prior_min = [0.5, 1e-4] + prior_max = [80.0, 15.0] + prior = utils.torchutils.BoxUniform( + low=torch.as_tensor(prior_min), high=torch.as_tensor(prior_max) + ) + + # mean = torch.tensor([45, 6.5]) + # cov = torch.tensor([[3 * math.sqrt(45), 0], [0, 3 * math.sqrt(6.5)]]) + # prior = MultivariateNormal(loc=mean, covariance_matrix=cov) + + if method == "SNPE_A": + density_estimator = "mdn_snpe_a" + density_estimator = posterior_nn( + model=density_estimator, num_components=num_components + ) + snpe = SNPE_A(prior, density_estimator, num_components, num_rounds) + + else: + density_estimator = "maf" + density_estimator = posterior_nn( + model=density_estimator, num_components=num_components + ) + snpe = SNPE_C(prior, density_estimator) + + simulator, prior = prepare_for_sbi(simulation_wrapper, prior) + proposal = prior + + fig_th, ax_th = plt.subplots(1) + + # Start multi-round training. + for r in range(num_rounds + 1): + # Simulate and append. 
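        # Note on the loop structure (editor's annotation, derived from the code
        # below): with `num_rounds = 2`, the first `num_rounds` iterations simulate,
        # train (Algorithm 1 of Papamakarios et al. 2016 in all but the last training
        # round, Algorithm 2 in the last one, see `SNPE_A.train()`), and update the
        # proposal; the extra final iteration only draws and plots one more batch of
        # parameters from the final posterior and breaks before training.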
+ thetas, data_sim = simulate_for_sbi( + simulator=simulator, + proposal=proposal, + num_simulations=num_sim, + num_workers=24, + ) + snpe.append_simulations(thetas, data_sim, proposal) + + # Plot the sampled thetas. + ax_th.scatter( + x=thetas[:, 0].numpy(), y=thetas[:, 1].numpy(), label=f"round {r}", s=10 + ) + if r == num_rounds: + break + + # Train. + density_estimator = snpe.train(retrain_from_scratch_each_round=False) + + if method == "SNPE_A": + posterior = snpe.build_posterior( + proposal=proposal, + density_estimator=density_estimator, + sample_with_mcmc=False, + ) + else: + posterior = snpe.build_posterior( + density_estimator=density_estimator, sample_with_mcmc=False + ) + + # Pretend we obtained the perfect + posterior.set_default_x(observation_summary_statistics) + proposal = posterior + + fig = plt.figure(figsize=(7, 5)) + gs = mpl.gridspec.GridSpec(2, 1, height_ratios=[4, 1]) + ax = plt.subplot(gs[0]) + plt.plot(observation_trace["time"], observation_trace["data"]) + plt.ylabel("voltage (mV)") + plt.title("observed data") + plt.setp(ax, xticks=[], yticks=[-80, -20, 40]) + + ax = plt.subplot(gs[1]) + plt.plot(observation_trace["time"], I * A_soma * 1e3, "k", lw=2) + plt.xlabel("time (ms)") + plt.ylabel("input (nA)") + + ax.set_xticks( + [0, max(observation_trace["time"]) / 2, max(observation_trace["time"])] + ) + ax.set_yticks([0, 1.1 * np.max(I * A_soma * 1e3)]) + ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter("%.2f")) + + # Analysis of the posterior given the observed data + samples = posterior.sample((10000,), x=observation_summary_statistics) + + fig, axes = pairplot( + samples, + limits=[[0.5, 80], [1e-4, 15.0]], + ticks=[[0.5, 80], [1e-4, 15.0]], + figsize=(5, 5), + points=true_params, + points_offdiag={"markersize": 6}, + points_colors="r", + ) + + # Draw a sample from the posterior and convert to numpy for plotting. + posterior_sample = posterior.sample((1,), x=observation_summary_statistics).numpy() + + fig = plt.figure(figsize=(7, 5)) + + # plot observation + t = observation_trace["time"] + y_obs = observation_trace["data"] + plt.plot(t, y_obs, lw=2, label="observation") + + # simulate and plot samples + x = run_HH_model(posterior_sample) + plt.plot(t, x["data"], "--", lw=2, label="posterior sample") + + plt.xlabel("time (ms)") + plt.ylabel("voltage (mV)") + + ax = plt.gca() + handles, labels = ax.get_legend_handles_labels() + ax.legend(handles[::-1], labels[::-1], bbox_to_anchor=(1.3, 1), loc="upper right") + + ax.set_xticks([0, 60, 120]) + ax.set_yticks([-80, -20, 40]) + + plt.show() diff --git a/sbi/inference/__init__.py b/sbi/inference/__init__.py index c07dd195b..5d09a704b 100644 --- a/sbi/inference/__init__.py +++ b/sbi/inference/__init__.py @@ -21,7 +21,7 @@ ) from sbi.inference.snle.snle_a import SNLE_A # Unimplemented: don't export -# from sbi.inference.snpe.snpe_a import SNPE_A +from sbi.inference.snpe.snpe_a import SNPE_A from sbi.inference.snpe.snpe_b import SNPE_B from sbi.inference.snpe.snpe_c import SNPE_C # noqa: F401 from sbi.inference.snre import SNRE, SNRE_A, SNRE_B # noqa: F401 diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py index ae9083d64..04e93cb41 100644 --- a/sbi/inference/snpe/snpe_a.py +++ b/sbi/inference/snpe/snpe_a.py @@ -1,36 +1,315 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Affero General Public License v3, see . 
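# Overview of the implementation below (editor's summary of the docstrings in this
# patch): for all but the last of `num_rounds` rounds, SNPE-A trains a one-component
# MDN on samples from the current proposal (Algorithm 1 of Papamakarios et al. 2016);
# in the last round it expands the MDN to `num_components` components (Algorithm 2)
# and corrects the learned proposal posterior analytically (multiply by the prior,
# divide by the proposal), which stays in closed form because every factor is a
# mixture of Gaussians, a Gaussian, or a uniform distribution.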
+import warnings +from copy import deepcopy +from functools import partial +from typing import Any, Callable, Dict, Optional, Union -from typing import Callable, Optional, Union - -from torch.utils.tensorboard import SummaryWriter +import torch +from pyknos.mdn.mdn import MultivariateGaussianMDN +from torch import Tensor +from torch.distributions import MultivariateNormal +import sbi.utils as utils +from sbi.inference.posteriors.direct_posterior import DirectPosterior from sbi.inference.snpe.snpe_base import PosteriorEstimator +from sbi.types import TensorboardSummaryWriter, TorchModule class SNPE_A(PosteriorEstimator): def __init__( self, - simulator: Callable, - prior, - num_workers: int = 1, - simulation_batch_size: int = 1, - density_estimator: Union[str, Callable] = "mdn", - calibration_kernel: Optional[Callable] = None, - exclude_invalid_x: bool = True, + prior: Optional[Any] = None, + density_estimator: Union[str, Callable] = "mdn_snpe_a", + num_components: int = 10, + num_rounds: int = 1, device: str = "cpu", logging_level: Union[int, str] = "WARNING", - summary_writer: Optional[SummaryWriter] = None, + summary_writer: Optional[TensorboardSummaryWriter] = None, show_progress_bars: bool = True, - show_round_summary: bool = False, + **unused_args, ): - """SNPE-A [1]. CURRENTLY NOT IMPLEMENTED. + r"""SNPE-A [1]. + + https://github.com/mackelab/sbi/blob/main/sbi/inference/snpe/snpe_c.py + https://github.com/mackelab/sbi/blob/main/sbi/neural_nets/mdn.py [1] _Fast epsilon-free Inference of Simulation Models with Bayesian Conditional Density Estimation_, Papamakarios et al., NeurIPS 2016, https://arxiv.org/abs/1605.06376. + Args: + prior: A probability distribution that expresses prior knowledge about the + parameters, e.g. which ranges are meaningful for them. Any + object with `.log_prob()`and `.sample()` (for example, a PyTorch + distribution) can be used. + density_estimator: If it is a string, use a pre-configured network of the + provided type (one of nsf, maf, mdn, made). Alternatively, a function + that builds a custom neural network can be provided. The function will + be called with the first batch of simulations (theta, x), which can + thus be used for shape inference and potentially for z-scoring. It + needs to return a PyTorch `nn.Module` implementing the density + estimator. The density estimator needs to provide the methods + `.log_prob` and `.sample()`. + num_components: + Number of components of the mixture of Gaussians. This number is set to 1 before + running Algorithm 1, and then later set to the specified value before running + Algorithm 2. + num_rounds: Total number of training rounds. For all but the last round, Algorithm 1 + from [1] is executed. For last round, Algorithm 2 from [1] is executed once. + By default, `num_rounds` is set to 1, i.e. only Algorithm 2 is executed once + without training the proposal prior using Algorithm 1. + device: torch device on which to compute, e.g. gpu, cpu. + logging_level: Minimum severity of messages to log. One of the strings + INFO, WARNING, DEBUG, ERROR and CRITICAL. + summary_writer: A tensorboard `SummaryWriter` to control, among others, log + file location (default is `/logs`.) + show_progress_bars: Whether to show a progressbar during training. + unused_args: Absorbs additional arguments. No entries will be used. If it + is not empty, we warn. In future versions, when the new interface of + 0.14.0 is more mature, we will remove this argument. 
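        Example:
            A minimal construction sketch (editor's illustration mirroring
            `examples/sb_sbi.py`; the 2-parameter `BoxUniform` prior and the
            numbers are made up):

            >>> import torch
            >>> from sbi.utils import BoxUniform, posterior_nn
            >>> prior = BoxUniform(torch.zeros(2), torch.ones(2))
            >>> estimator = posterior_nn(model="mdn_snpe_a", num_components=4)
            >>> snpe = SNPE_A(prior, estimator, num_components=4, num_rounds=2)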
+ """ + + self._num_rounds = num_rounds + # TODO How to extract the number of components from density_estimator if that is a callable? + self._num_components = num_components + + # WARNING: sneaky trick ahead. We proxy the parent's `train` here, + # requiring the signature to have `num_atoms`, save it for use below, and + # continue. It's sneaky because we are using the object (self) as a namespace + # to pass arguments between functions, and that's implicit state management. + kwargs = utils.del_entries( + locals(), entries=("self", "__class__", "unused_args", "num_rounds", "num_components") + ) + super().__init__(**kwargs) + + def train( + self, + training_batch_size: int = 50, + learning_rate: float = 5e-4, + validation_fraction: float = 0.1, + stop_after_epochs: int = 20, + max_num_epochs: Optional[int] = None, + clip_max_norm: Optional[float] = 5.0, + calibration_kernel: Optional[Callable] = None, + exclude_invalid_x: bool = True, + resume_training: bool = False, + retrain_from_scratch_each_round: bool = False, + show_train_summary: bool = False, + dataloader_kwargs: Optional[Dict] = None, + ) -> DirectPosterior: + r""" + Return density estimator that approximates the distribution $p(\theta|x)$. + Args: + training_batch_size: Training batch size. + learning_rate: Learning rate for Adam optimizer. + validation_fraction: The fraction of data to use for validation. + stop_after_epochs: The number of epochs to wait for improvement on the + validation set before terminating training. + max_num_epochs: Maximum number of epochs to run. If reached, we stop + training even when the validation loss is still decreasing. If None, we + train until validation loss increases (see also `stop_after_epochs`). + clip_max_norm: Value at which to clip the total gradient norm in order to + prevent exploding gradients. Use None for no clipping. + calibration_kernel: A function to calibrate the loss with respect to the + simulations `x`. See Lueckmann, Gonçalves et al., NeurIPS 2017. + exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞` + during training. Expect errors, silent or explicit, when `False`. + resume_training: Can be used in case training time is limited, e.g. on a + cluster. If `True`, the split between train and validation set, the + optimizer, the number of epochs, and the best validation log-prob will + be restored from the last time `.train()` was called. + retrain_from_scratch_each_round: Whether to retrain the conditional density + estimator for the posterior from scratch each round. + show_train_summary: Whether to print the number of epochs and validation + loss and leakage after the training. + dataloader_kwargs: Additional or updated kwargs to be passed to the training + and validation dataloaders (like, e.g., a collate_fn) + Returns: + Density estimator that approximates the distribution $p(\theta|x)$. + """ + kwargs = utils.del_entries(locals(), entries=("self", "__class__")) + + # SNPE-A always discards the prior samples. + kwargs["discard_prior_samples"] = True + + self._round = max(self._data_round_index) + + # In case there is will only be one round, train with Algorithm 2 from [1]. + if self._num_rounds == 1: + self._build_neural_net = partial( + self._build_neural_net, num_components=self._num_components + ) + + # Run Algorithm 1 from [1]. + elif self._round + 1 < self._num_rounds: + # Wrap the function that builds the MDN such that we can make + # sure that there is only one component when running. 
+ self._build_neural_net = partial(self._build_neural_net, num_components=1) + + # Run Algorithm 2 from [1]. + elif self._round + 1 == self._num_rounds: + # Now switch to the specified number of components. + self._build_neural_net = partial( + self._build_neural_net, num_components=self._num_components + ) + + # Extend the MDN to the originally desired number of components. + self._expand_mog() + + else: + warnings.warn( + f"Running SNPE-A for more than the specified number of rounds {self._num_rounds} implies running" + f"Algorithm 2 from [1] multiple times, which can lead to numerical issues. Moreover, the number of " + f"components in the mixture of Gaussian increases with every round after {self._num_rounds}.", + UserWarning, + ) + + return super().train(**kwargs) + + def build_posterior( + self, + proposal: Union[MultivariateNormal, utils.BoxUniform, DirectPosterior], + density_estimator: Optional[TorchModule] = None, + rejection_sampling_parameters: Optional[Dict[str, Any]] = None, + sample_with_mcmc: bool = False, + mcmc_method: str = "slice_np", + mcmc_parameters: Optional[Dict[str, Any]] = None, + ) -> DirectPosterior: + r""" + Build posterior from the neural density estimator. + + For SNPE, the posterior distribution that is returned here implements the TODO + following functionality over the raw neural density estimator: + + - correct the calculation of the log probability such that it compensates for + the leakage. + - reject samples that lie outside of the prior bounds. + - alternatively, if leakage is very high (which can happen for multi-round + SNPE), sample from the posterior with MCMC. + + Args: + proposal: The distribution that the parameters $\theta$ were sampled from. + density_estimator: The density estimator that the posterior is based on. + If `None`, use the latest neural density estimator that was trained. + rejection_sampling_parameters: Dictionary overriding the default parameters + for rejection sampling. The following parameters are supported: + `max_sampling_batch_size` to set the batch size for drawing new + samples from the candidate distribution, e.g., the posterior. Larger + batch size speeds up sampling. + sample_with_mcmc: Whether to sample with MCMC. MCMC can be used to deal + with high leakage. + mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`, + `hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy + implementation of slice sampling; select `hmc`, `nuts` or `slice` for + Pyro-based sampling. + mcmc_parameters: Dictionary overriding the default parameters for MCMC. + The following parameters are supported: `thin` to set the thinning + factor for the chain, `warmup_steps` to set the initial number of + samples to discard, `num_chains` for the number of chains, + `init_strategy` for the initialisation strategy for chains; `prior` will + draw init locations from prior, whereas `sir` will use + Sequential-Importance-Resampling using `init_strategy_num_candidates` + to find init locations. + + Returns: + Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods. + """ + + if density_estimator is None: + density_estimator = deepcopy( + self._neural_net + ) # PosteriorEstimator.train() also returns a deepcopy, mimic this here + # If internal net is used device is defined. + device = self._device + else: + # Otherwise, infer it from the device of the net parameters. + device = next(density_estimator.parameters()).device + + # Set proposal of the density estimator. 
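        # Editor's annotation: over multiple rounds this creates a chain of proposals,
        # e.g. prior -> MoGFlow_SNPE_A (round 1) -> MoGFlow_SNPE_A (round 2). The
        # density estimator walks this chain via `_get_first_prior_from_proposal()`
        # to recover the original prior and unrolls it recursively in
        # `_get_mixture_components()` when evaluating or sampling the corrected
        # posterior.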
+        # This also invokes the z-scoring correction if necessary.
+        if isinstance(proposal, (MultivariateNormal, utils.BoxUniform)):
+            density_estimator.set_proposal(proposal)
+        elif isinstance(proposal, DirectPosterior):
+            # Extract the MoGFlow_SNPE_A from the DirectPosterior.
+            density_estimator.set_proposal(proposal.net)
+        else:
+            raise TypeError(
+                "So far, only MultivariateNormal, BoxUniform, and DirectPosterior are "
+                "supported for the `proposal` arg in SNPE_A.build_posterior()."
+            )
+
+        self._posterior = DirectPosterior(
+            method_family="snpe",
+            neural_net=density_estimator,
+            prior=self._prior,
+            x_shape=self._x_shape,
+            rejection_sampling_parameters=rejection_sampling_parameters,
+            sample_with_mcmc=sample_with_mcmc,
+            mcmc_method=mcmc_method,
+            mcmc_parameters=mcmc_parameters,
+            device=device,
+        )
+
+        self._posterior._num_trained_rounds = self._round + 1
+
+        # Store models at end of each round.
+        self._model_bank.append(deepcopy(self._posterior))
+        self._model_bank[-1].net.eval()
+
+        return deepcopy(self._posterior)
+
+    def _log_prob_proposal_posterior(
+        self, theta: Tensor, x: Tensor, masks: Tensor, proposal: Optional[Any]
+    ) -> Tensor:
+        """
+        Return the log-probability of the proposal posterior.
+
+        .. note::
+            For SNPE-A this is the same as `self._neural_net.log_prob(theta, x)` in
+            `_loss()` to be found in `snpe_base.py`.
+
+        If the proposal is a MoG, the density estimator is a MoG, and the prior is
+        either Gaussian or uniform, we use non-atomic loss. Else, use atomic loss
+        (which suffers from leakage).
+
+        Args:
+            theta: Batch of parameters θ.
+            x: Batch of data.
+            masks: Mask that is True for prior samples in the batch in order to train
+                them with prior loss.
+            proposal: Proposal distribution.
+
+        Returns: Log-probability of the proposal posterior.
+        """
+        return self._neural_net.log_prob(theta, x)
+
+    def _expand_mog(self, eps: float = 1e-5):
+        """
+        Replicate a single Gaussian trained with Algorithm 1 before continuing
+        with Algorithm 2. The weights and biases of the associated MDN layers
+        are repeated `num_components` times and slightly perturbed to break the
+        symmetry, such that the gradients in the subsequent training are not
+        all identical.
+
+        Args:
+            eps: Standard deviation for the random perturbation.
+        """
+        assert isinstance(self._neural_net._distribution, MultivariateGaussianMDN)
+
+        # Increase the number of components
+        self._neural_net._distribution._num_components = self._num_components
-        raise NotImplementedError
+        # Expand the 1-dim Gaussian.
+        for name, param in self._neural_net.named_parameters():
+            if any(
+                key in name for key in ["logits", "means", "unconstrained", "upper"]
+            ):
+                if "bias" in name:
+                    param.data = param.data.repeat(self._num_components)
+                    param.data.add_(torch.randn_like(param.data) * eps)
+                    param.grad = None  # let autograd construct a new gradient
+                elif "weight" in name:
+                    param.data = param.data.repeat(self._num_components, 1)
+                    param.data.add_(torch.randn_like(param.data) * eps)
+                    param.grad = None  # let autograd construct a new gradient
diff --git a/sbi/neural_nets/mdn_snpe_a.py b/sbi/neural_nets/mdn_snpe_a.py
new file mode 100644
index 000000000..044aa0bf1
--- /dev/null
+++ b/sbi/neural_nets/mdn_snpe_a.py
@@ -0,0 +1,70 @@
+# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
+# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.
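# What this module provides (editor's summary): `build_mdn_snpe_a` mirrors
# `sbi.neural_nets.mdn.build_mdn` but returns a `MoGFlow_SNPE_A` instead of a plain
# `flows.Flow`, so SNPE-A can later attach a proposal to the trained MDN and correct
# its density analytically. Illustrative call (hypothetical tensors):
#     net = build_mdn_snpe_a(batch_x=theta_batch, batch_y=x_batch, num_components=1)
#     log_q = net.log_prob(theta_batch, context=x_batch)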
+ +from pyknos.mdn.mdn import MultivariateGaussianMDN +from pyknos.nflows import transforms +from torch import Tensor, nn + +from sbi.neural_nets.mog_flow_snpe_a import MoGFlow_SNPE_A +from sbi.utils.sbiutils import standardizing_net, standardizing_transform + + +def build_mdn_snpe_a( + batch_x: Tensor = None, + batch_y: Tensor = None, + z_score_x: bool = True, + z_score_y: bool = True, + hidden_features: int = 50, + num_components: int = 10, + embedding_net: nn.Module = nn.Identity(), + **kwargs, +) -> nn.Module: + """Builds MDN p(x|y) with different sampling in training and evaluation mode. + + Args: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network. + z_score_y: Whether to z-score ys passing into the network. + hidden_features: Number of hidden features. + num_components: Number of components. + embedding_net: Optional embedding network for y. + kwargs: Additional arguments that are passed by the build function but are not + relevant for MDNs and are therefore ignored. + + Returns: + Neural network. + """ + x_numel = batch_x[0].numel() + # Infer the output dimensionality of the embedding_net by making a forward pass. + y_numel = embedding_net(batch_y[:1]).numel() + + transform = transforms.IdentityTransform() + + if z_score_x: + transform_zx = standardizing_transform(batch_x) + transform = transforms.CompositeTransform([transform_zx, transform]) + + if z_score_y: + embedding_net = nn.Sequential(standardizing_net(batch_y), embedding_net) + + distribution = MultivariateGaussianMDN( + features=x_numel, + context_features=y_numel, + hidden_features=hidden_features, + hidden_net=nn.Sequential( + nn.Linear(y_numel, hidden_features), + nn.ReLU(), + nn.Dropout(p=0.0), + nn.Linear(hidden_features, hidden_features), + nn.ReLU(), + nn.Linear(hidden_features, hidden_features), + nn.ReLU(), + ), + num_components=num_components, + custom_initialization=True, + ) + + neural_net = MoGFlow_SNPE_A(transform, distribution, embedding_net) + + return neural_net diff --git a/sbi/neural_nets/mog_flow_snpe_a.py b/sbi/neural_nets/mog_flow_snpe_a.py new file mode 100644 index 000000000..1445d57a7 --- /dev/null +++ b/sbi/neural_nets/mog_flow_snpe_a.py @@ -0,0 +1,586 @@ +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Affero General Public License v3, see . + +from typing import Union +from warnings import warn + +import torch +from pyknos.mdn.mdn import MultivariateGaussianMDN +from pyknos.nflows import flows +from pyknos.nflows.transforms import CompositeTransform +from torch import Tensor +from torch.distributions import MultivariateNormal + +import sbi.utils as utils +from sbi.utils import torchutils + + +class MoGFlow_SNPE_A(flows.Flow): + """ + A wrapper for nflow's `Flow` class to enable a different log prob calculation + sampling strategy for training and testing, tailored to SNPE-A [1] + + [1] _Fast epsilon-free Inference of Simulation Models with Bayesian Conditional + Density Estimation_, Papamakarios et al., NeurIPS 2016, + https://arxiv.org/abs/1605.06376. + [2] _Automatic Posterior Transformation for Likelihood-free Inference_, + Greenberg et al., ICML 2019, https://arxiv.org/abs/1905.07488. + """ + + def __init__( + self, + transform, + distribution, + embedding_net=None, + allow_precision_correction: bool = False, + ): + """Constructor. 
+ + Args: + transform: A `Transform` object, it transforms data into noise. + distribution: A `Distribution` object, the base distribution of the flow that + generates the noise. + embedding_net: A `nn.Module` which has trainable parameters to encode the + context (condition). It is trained jointly with the flow. + allow_precision_correction: + Add a diagonal with the smallest eigenvalue in every entry in case + the precision matrix becomes ill-conditioned. + """ + # Construct the flow. + super().__init__(transform, distribution, embedding_net) + + self._proposal = None + self._allow_precision_correction = allow_precision_correction + + @property + def proposal( + self, + ) -> Union["utils.BoxUniform", MultivariateNormal, "MoGFlow_SNPE_A"]: + """Get the proposal of the previous round.""" + return self._proposal + + def set_proposal( + self, proposal: Union["utils.BoxUniform", MultivariateNormal, "MoGFlow_SNPE_A"] + ): + """Set the proposal of the previous round.""" + self._proposal = proposal + + # Take care of z-scoring, pre-compute and store prior terms. + self._set_state_for_mog_proposal() + + def _get_first_prior_from_proposal( + self, + ) -> Union["utils.BoxUniform", MultivariateNormal, "MoGFlow_SNPE_A"]: + """Iterate a possible chain of proposals.""" + curr_prior = self._proposal + + while curr_prior: + if isinstance(curr_prior, (utils.BoxUniform, MultivariateNormal)): + break + else: + curr_prior = curr_prior.proposal + + assert curr_prior is not None + return curr_prior + + def log_prob(self, inputs, context=None): + if self._proposal is None: + # Use Flow.lob_prob() if there has been no previous proposal memorized + # in this instance. This is the case if we are in the training + # loop, i.e. this MoGFlow_SNPE_A instance is not an attribute of a + # DirectPosterior instance. + return super().log_prob(inputs, context) # q_phi from eq (3) in [1] + + elif isinstance(self._proposal, (utils.BoxUniform, MultivariateNormal)): + # No importance re-weighting is needed if the proposal prior is the prior + return super().log_prob(inputs, context) + + else: + # When we want to compute the approx. posterior, a proposal prior \tilde{p} + # has already been observed. To analytically calculate the log-prob of the + # Gaussian, we first need to compute the mixture components. + + # Compute the mixture components of the proposal posterior. + logits_pp, m_pp, prec_pp = self._get_mixture_components(context) + + # z-score theta if it z-scoring had been requested. + theta = self._maybe_z_score_theta(inputs) + + # Compute the log_prob of theta under the product. + log_prob_proposal_posterior = utils.mog_log_prob( + theta, + logits_pp, + m_pp, + prec_pp, + ) + utils.assert_all_finite( + log_prob_proposal_posterior, "proposal posterior eval" + ) + return log_prob_proposal_posterior # \hat{p} from eq (3) in [1] + + def sample(self, num_samples, context=None, batch_size=None) -> Tensor: + if self._proposal is None: + # Use Flow.sample() if there has been no previous proposal memorized + # in this instance. This is the case if we are in the training + # loop, i.e. this MoGFlow_SNPE_A instance is not an attribute of a + # DirectPosterior instance. + return super().sample(num_samples, context, batch_size) + + else: + # No importance re-weighting is needed if the proposal prior is the prior + if isinstance(self._proposal, (utils.BoxUniform, MultivariateNormal)): + return super().sample(num_samples, context, batch_size) + + # When we want to sample from the approx. 
posterior, a proposal prior \tilde{p} + # has already been observed. To analytically calculate the log-prob of the + # Gaussian, we first need to compute the mixture components. + return self._sample_approx_posterior_mog(num_samples, context, batch_size) + + def _sample_approx_posterior_mog( + self, num_samples, x: Tensor, batch_size: int + ) -> Tensor: + r""" + Sample from the approximate posterior. + + Args: + num_samples: Desired number of samples. + x: Conditioning context for posterior $p(\theta|x)$. + batch_size: Batch size for sampling. + + Returns: + Samples from the approximate mixture of Gaussians posterior. + """ + + # Compute the mixture components of the proposal posterior. + logits_pp, m_pp, prec_pp = self._get_mixture_components(x) + + # Compute the precision factors which represent the upper triangular matrix + # of the cholesky decomposition of the prec_pp. + prec_factors_pp = torch.cholesky(prec_pp, upper=True) + + assert logits_pp.ndim == 2 + assert m_pp.ndim == 3 + assert prec_pp.ndim == 4 + assert prec_factors_pp.ndim == 4 + + # Replicate to use batched sampling from pyknos. + if batch_size is not None and batch_size > 1: + logits_pp = logits_pp.repeat(batch_size, 1) + m_pp = m_pp.repeat(batch_size, 1, 1) + prec_factors_pp = prec_factors_pp.repeat(batch_size, 1, 1, 1) + + # Get (optionally z-scored) MoG samples. + theta = MultivariateGaussianMDN.sample_mog( + num_samples, logits_pp, m_pp, prec_factors_pp + ) + + embedded_context = self._embedding_net(x) + if embedded_context is not None: + # Merge the context dimension with sample dimension in order to + # apply the transform. + theta = torchutils.merge_leading_dims(theta, num_dims=2) + embedded_context = torchutils.repeat_rows( + embedded_context, num_reps=num_samples + ) + + theta, _ = self._transform.inverse(theta, context=embedded_context) + + if embedded_context is not None: + # Split the context dimension from sample dimension. + theta = torchutils.split_leading_dim(theta, shape=[-1, num_samples]) + + return theta + + def _get_mixture_components(self, x: Tensor): + """ + Compute the mixture components of the posterior given the current density + estimator and the proposal. + + Args: + x: Conditioning context for posterior. + + Returns: + Mixture components of the posterior. + """ + + # Evaluate the density estimator. + encoded_x = self._embedding_net(x) + dist = self._distribution # defined to avoid black formatting. + logits_d, m_d, prec_d, _, _ = dist.get_mixture_components(encoded_x) + norm_logits_d = logits_d - torch.logsumexp(logits_d, dim=-1, keepdim=True) + + if isinstance(self._proposal, (utils.BoxUniform, MultivariateNormal)): + # Uniform prior is uninformative. + return norm_logits_d, m_d, prec_d + + else: + # Recursive ask for the mixture components until the prior is yielded. + logits_p, m_p, prec_p = self._proposal._get_mixture_components(x) + + # Compute the MoG parameters of the proposal posterior. + logits_pp, m_pp, prec_pp, cov_pp = self._proposal_posterior_transformation( + logits_p, + m_p, + prec_p, + norm_logits_d, + m_d, + prec_d, + ) + return logits_pp, m_pp, prec_pp + + def _proposal_posterior_transformation( + self, + logits_pprior: Tensor, + means_pprior: Tensor, + precisions_pprior: Tensor, + logits_d: Tensor, + means_d: Tensor, + precisions_d: Tensor, + ): + r""" + Transforms the proposal posterior (the MDN) into the posterior. 
+ + The proposal posterior is: + $p(\theta|x) = 1/Z * q(\theta|x) * p(\theta) / prop(\theta)$ + In words: posterior = proposal posterior estimate * prior / proposal. + + Since the proposal posterior estimate and the proposal are MoG, and the + prior is either Gaussian or uniform, we can solve this in closed-form. + + This function implements Appendix C from [1], and is highly similar to + `SNPE_C._automatic_posterior_transformation()`. + + We have to build L*K components. How do we do this? + Example: proposal has two components, density estimator has three components. + Let's call the two components of the proposal i,j and the three components + of the density estimator x,y,z. We have to multiply every component of the + proposal with every component of the density estimator. So, what we do is: + 1) for the proposal, build: i,i,i,j,j,j. Done with torch.repeat_interleave() + 2) for the density estimator, build: x,y,z,x,y,z. Done with torch.repeat() + 3) Multiply them with simple matrix operations. + + Args: + logits_pprior: Component weight of each Gaussian of the proposal prior. + means_pprior: Mean of each Gaussian of the proposal prior. + precisions_pprior: Precision matrix of each Gaussian of the proposal prior. + logits_d: Component weight for each Gaussian of the density estimator. + means_d: Mean of each Gaussian of the density estimator. + precisions_d: Precision matrix of each Gaussian of the density estimator. + + Returns: (Component weight, mean, precision matrix, covariance matrix) of each + Gaussian of the proposal posterior. Has L*K terms (proposal has L terms, + density estimator has K terms). + """ + + precisions_post, covariances_post = self._precisions_posterior( + precisions_pprior, precisions_d + ) + + means_post = self._means_posterior( + covariances_post, + means_pprior, + precisions_pprior, + means_d, + precisions_d, + ) + + logits_post = MoGFlow_SNPE_A._logits_posterior( + means_post, + precisions_post, + covariances_post, + logits_pprior, + means_pprior, + precisions_pprior, + logits_d, + means_d, + precisions_d, + ) + + return logits_post, means_post, precisions_post, covariances_post + + def _set_state_for_mog_proposal(self) -> None: + """ + Set state variables of the MoGFlow_SNPE_A instance evevy time `set_proposal()` + is called, i.e. every time a posterior is build using `SNPE_A.build_posterior()`. + + This function is almost identical to `SNPE_C._set_state_for_mog_proposal()`. + + Three things are computed: + 1) Check if z-scoring was requested. To do so, we check if the `_transform` + argument of the net had been a `CompositeTransform`. See pyknos mdn.py. + 2) Define a (potentially standardized) prior. It's standardized if z-scoring + had been requested. + 3) Compute (Precision * mean) for the prior. This quantity is used at every + training step if the prior is Gaussian. + """ + + self.z_score_theta = isinstance(self._transform, CompositeTransform) + + self._set_maybe_z_scored_prior() + + if isinstance(self._maybe_z_scored_prior, MultivariateNormal): + self.prec_m_prod_prior = torch.mv( + self._maybe_z_scored_prior.precision_matrix, + self._maybe_z_scored_prior.loc, + ) + + def _set_maybe_z_scored_prior(self) -> None: + r""" + Compute and store potentially standardized prior (if z-scoring was requested). + + This function is highly similar to `SNPE_C._set_maybe_z_scored_prior()`. 
+ + The proposal posterior is: + $p(\theta|x) = 1/Z * q(\theta|x) * prop(\theta) / p(\theta)$ + + Let's denote z-scored theta by `a`: a = (theta - mean) / std + Then $p'(a|x) = 1/Z_2 * q'(a|x) * prop'(a) / p'(a)$ + + The ' indicates that the evaluation occurs in standardized space. The constant + scaling factor has been absorbed into $Z_2$. + From the above equation, we see that we need to evaluate the prior **in + standardized space**. We build the standardized prior in this function. + + The standardize transform that is applied to the samples theta does not use + the exact prior mean and std (due to implementation issues). Hence, the z-scored + prior will not be exactly have mean=0 and std=1. + """ + prior = self._get_first_prior_from_proposal() + + if self.z_score_theta: + scale = self._transform._transforms[0]._scale + shift = self._transform._transforms[0]._shift + + # Following the definition of the linear transform in + # `standardizing_transform` in `sbiutils.py`: + # shift=-mean / std + # scale=1 / std + # Solving these equations for mean and std: + estim_prior_std = 1 / scale + estim_prior_mean = -shift * estim_prior_std + + # Compute the discrepancy of the true prior mean and std and the mean and + # std that was empirically estimated from samples. + # N(theta|m,s) = N((theta-m_e)/s_e|(m-m_e)/s_e, s/s_e) + # Above: m,s are true prior mean and std. m_e,s_e are estimated prior mean + # and std (estimated from samples and used to build standardize transform). + almost_zero_mean = (prior.mean - estim_prior_mean) / estim_prior_std + almost_one_std = torch.sqrt(prior.variance) / estim_prior_std + + if isinstance(prior, MultivariateNormal): + self._maybe_z_scored_prior = MultivariateNormal( + almost_zero_mean, torch.diag(almost_one_std) + ) + else: + range_ = torch.sqrt(almost_one_std * 3.0) + self._maybe_z_scored_prior = utils.BoxUniform( + almost_zero_mean - range_, almost_zero_mean + range_ + ) + else: + self._maybe_z_scored_prior = prior + + def _maybe_z_score_theta(self, theta: Tensor) -> Tensor: + """Return potentially standardized theta if z-scoring was requested.""" + + if self.z_score_theta: + theta, _ = self._transform(theta) + + return theta + + def _precisions_posterior(self, precisions_pprior: Tensor, precisions_d: Tensor): + r""" + Return the precisions and covariances of the MoG posterior. + + As described at the end of Appendix C in [1], it can happen that the + proposal's precision matrix is not positive definite. + + $S_k^\prime = ( S_k^{-1} - S_0^{-1} )^{-1}$ + (see eq (23) in Appendix C of [1]) + + Args: + precisions_pprior: Precision matrices of the proposal prior. + precisions_d: Precision matrices of the density estimator. + + Returns: (Precisions, Covariances) of the MoG posterior. L*K terms. + """ + + num_comps_p = precisions_pprior.shape[1] + num_comps_d = precisions_d.shape[1] + + # Check if precision matrices are positive definite. + for batches in precisions_pprior: + for pprior in batches: + eig_pprior = torch.symeig(pprior, eigenvectors=False).eigenvalues + assert ( + eig_pprior > 0 + ).all(), ( + "The precision matrix of the proposal is not positive definite!" + ) + for batches in precisions_d: + for d in batches: + eig_d = torch.symeig(d, eigenvectors=False).eigenvalues + assert ( + eig_d > 0 + ).all(), "The precision matrix of the density estimator is not positive definite!" 
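        # Pairing trick (editor's illustration, not part of the patch): to enumerate
        # all L*K component combinations, proposal terms are repeated element-wise and
        # density-estimator terms block-wise, e.g. for L=2 proposal and K=3 estimator
        # components:
        #     torch.tensor([1, 2]).repeat_interleave(3)  ->  tensor([1, 1, 1, 2, 2, 2])
        #     torch.tensor([7, 8, 9]).repeat(2)          ->  tensor([7, 8, 9, 7, 8, 9])
        # so the element-wise operations below touch every (proposal, estimator) pair.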
+ + precisions_pprior_rep = precisions_pprior.repeat_interleave(num_comps_d, dim=1) + precisions_d_rep = precisions_d.repeat(1, num_comps_p, 1, 1) + + precisions_p = precisions_d_rep - precisions_pprior_rep + if isinstance(self._maybe_z_scored_prior, MultivariateNormal): + precisions_p += self._maybe_z_scored_prior.precision_matrix + + # Check if precision matrix is positive definite. + for idx_batch, batches in enumerate(precisions_p): + for idx_comp, pp in enumerate(batches): + eig_pp = torch.symeig(pp, eigenvectors=False).eigenvalues + if not (eig_pp > 0).all(): + if self._allow_precision_correction: + # Shift the eigenvalues to be at minimum 1e-6. + precisions_p[idx_batch, idx_comp] = pp - torch.eye( + pp.shape[0] + ) * (min(eig_pp) - 1e-6) + warn( + "The precision matrix of a posterior has not been positive " + "definite at least once. Added diagonal entries with the " + "smallest eigenvalue to 1e-6." + ) + + else: + # Fail when encountering an ill-conditioned precision matrix. + raise AssertionError( + "The precision matrix of a posterior is not positive definite! " + "This is a known issue for SNPE-A. Either try a different parameter " + "setting or pass `allow_precision_correction=True` when constructing " + "the `MoGFlow_SNPE_A` density estimator." + ) + + covariances_p = torch.inverse(precisions_p) + return precisions_p, covariances_p + + def _means_posterior( + self, + covariances_post: Tensor, + means_pprior: Tensor, + precisions_pprior: Tensor, + means_d: Tensor, + precisions_d: Tensor, + ): + r""" + Return the means of the MoG posterior. + + $m_k^\prime = S_k^\prime ( S_k^{-1} m_k - S_0^{-1} m_0 )$ + (see eq (24) in Appendix C of [1]) + + Args: + covariances_post: Covariance matrices of the MoG posterior. + means_pprior: Means of the proposal prior. + precisions_pprior: Precision matrices of the proposal prior. + means_d: Means of the density estimator. + precisions_d: Precision matrices of the density estimator. + + Returns: Means of the MoG posterior. L*K terms. + """ + + num_comps_pprior = precisions_pprior.shape[1] + num_comps_d = precisions_d.shape[1] + + # Compute the products P_k * m_k and P_0 * m_0. + prec_m_prod_pprior = utils.batched_mixture_mv(precisions_pprior, means_pprior) + prec_m_prod_d = utils.batched_mixture_mv(precisions_d, means_d) + + # Repeat them to allow for matrix operations: same trick as for the precisions. + prec_m_prod_pprior_rep = prec_m_prod_pprior.repeat_interleave( + num_comps_d, dim=1 + ) + prec_m_prod_d_rep = prec_m_prod_d.repeat(1, num_comps_pprior, 1) + + # Compute the means P_k^prime * (P_k * m_k - P_0 * m_0). + summed_cov_m_prod_rep = prec_m_prod_d_rep - prec_m_prod_pprior_rep + if isinstance(self._maybe_z_scored_prior, MultivariateNormal): + summed_cov_m_prod_rep += self.prec_m_prod_prior + + means_p = utils.batched_mixture_mv(covariances_post, summed_cov_m_prod_rep) + return means_p + + @staticmethod + def _logits_posterior( + means_post: Tensor, + precisions_post: Tensor, + covariances_post: Tensor, + logits_pprior: Tensor, + means_pprior: Tensor, + precisions_pprior: Tensor, + logits_d: Tensor, + means_d: Tensor, + precisions_d: Tensor, + ): + r""" + Return the component weights (i.e. logits) of the MoG posterior. + + $\alpha_k^\prime = \frac{ \alpha_k exp(-0.5 c_k) }{ \sum{j} \alpha_j exp(-0.5 c_j) } $ + with + $c_k = logdet(S_k) - logdet(S_0) - logdet(S_k^\prime) + + + m_k^T P_k m_k - m_0^T P_0 m_0 - m_k^\prime^T P_k^\prime m_k^\prime$ + (see eqs. (25, 26) in Appendix C of [1]) + + Args: + means_post: Means of the posterior. 
+ precisions_post: Precision matrices of the posterior. + covariances_post: Covariance matrices of the posterior. + logits_pprior: Component weights (i.e. logits) of the proposal prior. + means_pprior: Means of the proposal prior. + precisions_pprior: Precision matrices of the proposal prior. + logits_d: Component weights (i.e. logits) of the density estimator. + means_d: Means of the density estimator. + precisions_d: Precision matrices of the density estimator. + + Returns: Component weights of the proposal posterior. L*K terms. + """ + + num_comps_pprior = precisions_pprior.shape[1] + num_comps_d = precisions_d.shape[1] + + # Compute the ratio of the logits similar to eq (10) in Appendix A.1 of [2] + logits_pprior_rep = logits_pprior.repeat_interleave(num_comps_d, dim=1) + logits_d_rep = logits_d.repeat(1, num_comps_pprior) + logit_factors = logits_d_rep - logits_pprior_rep + + # Compute the log-determinants + logdet_covariances_post = torch.logdet(covariances_post) + logdet_covariances_pprior = -torch.logdet(precisions_pprior) + logdet_covariances_d = -torch.logdet(precisions_d) + + # Repeat the proposal and density estimator terms such that there are LK terms. + # Same trick as has been used above. + logdet_covariances_pprior_rep = logdet_covariances_pprior.repeat_interleave( + num_comps_d, dim=1 + ) + logdet_covariances_d_rep = logdet_covariances_d.repeat(1, num_comps_pprior) + + log_sqrt_det_ratio = 0.5 * ( # similar to eq (14) in Appendix A.1 of [2] + logdet_covariances_post + + logdet_covariances_pprior_rep + - logdet_covariances_d_rep + ) + + # Compute for proposal, density estimator, and proposal posterior: + exponent_pprior = utils.batched_mixture_vmv( + precisions_pprior, means_pprior # m_0 in eq (26) in Appendix C of [1] + ) + exponent_d = utils.batched_mixture_vmv( + precisions_d, means_d # m_k in eq (26) in Appendix C of [1] + ) + exponent_post = utils.batched_mixture_vmv( + precisions_post, means_post # m_k^\prime in eq (26) in Appendix C of [1] + ) + + # Extend proposal and density estimator exponents to get LK terms. + exponent_prior_rep = exponent_pprior.repeat_interleave(num_comps_d, dim=1) + exponent_d_rep = exponent_d.repeat(1, num_comps_pprior) + exponent = -0.5 * ( + exponent_prior_rep - exponent_d_rep - exponent_post # eq (26) in [1] + ) + + logits_post = logit_factors + log_sqrt_det_ratio + exponent + return logits_post diff --git a/sbi/utils/get_nn_models.py b/sbi/utils/get_nn_models.py index 34f99b891..a3bf6744a 100644 --- a/sbi/utils/get_nn_models.py +++ b/sbi/utils/get_nn_models.py @@ -13,6 +13,7 @@ ) from sbi.neural_nets.flow import build_made, build_maf, build_nsf from sbi.neural_nets.mdn import build_mdn +from sbi.neural_nets.mdn_snpe_a import build_mdn_snpe_a def classifier_nn( @@ -212,16 +213,34 @@ def posterior_nn( ) ) - def build_fn(batch_theta, batch_x): - if model == "mdn": - return build_mdn(batch_x=batch_theta, batch_y=batch_x, **kwargs) - if model == "made": - return build_made(batch_x=batch_theta, batch_y=batch_x, **kwargs) - if model == "maf": - return build_maf(batch_x=batch_theta, batch_y=batch_x, **kwargs) - elif model == "nsf": - return build_nsf(batch_x=batch_theta, batch_y=batch_x, **kwargs) - else: - raise NotImplementedError + if model == "mdn_snpe_a": + kwargs.pop("num_components") + + def build_fn(batch_theta, batch_x, num_components): + # Extract the number of components from the kwargs, such that + # they are exposed as a kwargs, offering the possibility to later + # override this kwarg with functools.partial. 
This is necessary + # in order to make sure that the MDN in SNPE-A only has one + # component when running the Algorithm 1 part. + return build_mdn_snpe_a( + batch_x=batch_theta, + batch_y=batch_x, + num_components=num_components, + **kwargs + ) + + else: + + def build_fn(batch_theta, batch_x): + if model == "mdn": + return build_mdn(batch_x=batch_theta, batch_y=batch_x, **kwargs) + elif model == "made": + return build_made(batch_x=batch_theta, batch_y=batch_x, **kwargs) + elif model == "maf": + return build_maf(batch_x=batch_theta, batch_y=batch_x, **kwargs) + elif model == "nsf": + return build_nsf(batch_x=batch_theta, batch_y=batch_x, **kwargs) + else: + raise NotImplementedError return build_fn diff --git a/tests/inference_with_NaN_simulator_test.py b/tests/inference_with_NaN_simulator_test.py index 7c29cb045..40872285e 100644 --- a/tests/inference_with_NaN_simulator_test.py +++ b/tests/inference_with_NaN_simulator_test.py @@ -6,7 +6,7 @@ from torch import eye, ones, zeros from sbi import utils as utils -from sbi.inference import SNL, SNPE_C, SRE, prepare_for_sbi, simulate_for_sbi +from sbi.inference import SNL, SNPE_A, SNPE_C, SRE, prepare_for_sbi, simulate_for_sbi from sbi.simulators.linear_gaussian import ( linear_gaussian, samples_true_posterior_linear_gaussian_uniform_prior, @@ -36,7 +36,8 @@ def test_handle_invalid_x(x_shape, set_seed): assert torch.isfinite(x[x_is_valid]).all() -def test_z_scoring_warning(): +@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) +def test_z_scoring_warning(snpe_method: type): # Create data with large variance. num_dim = 2 @@ -47,7 +48,7 @@ def test_z_scoring_warning(): # Make sure a warning is raised because z-scoring will map these data to duplicate # data points. with pytest.warns(UserWarning, match="Z-scoring these simulation outputs"): - SNPE_C(utils.BoxUniform(zeros(num_dim), ones(num_dim))).append_simulations( + snpe_method(utils.BoxUniform(zeros(num_dim), ones(num_dim))).append_simulations( theta, x ).train(max_num_epochs=1) @@ -56,13 +57,14 @@ def test_z_scoring_warning(): @pytest.mark.parametrize( ("method", "exclude_invalid_x", "percent_nans"), ( + (SNPE_A, True, 0.05), (SNPE_C, True, 0.05), (SNL, True, 0.05), (SRE, True, 0.05), ), ) def test_inference_with_nan_simulator( - method, exclude_invalid_x, percent_nans, set_seed + method: type, exclude_invalid_x: bool, percent_nans: float, set_seed ): # likelihood_mean will be likelihood_shift+theta @@ -98,7 +100,11 @@ def linear_gaussian_nan( _ = inference.append_simulations(theta, x).train( exclude_invalid_x=exclude_invalid_x ) - posterior = inference.build_posterior().set_default_x(x_o) + + if method == SNPE_A: + posterior = inference.build_posterior(proposal=prior).set_default_x(x_o) + else: + posterior = inference.build_posterior().set_default_x(x_o) samples = posterior.sample((num_samples,)) @@ -107,7 +113,8 @@ def linear_gaussian_nan( @pytest.mark.slow -def test_inference_with_restriction_estimator(set_seed): +@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) +def test_inference_with_restriction_estimator(snpe_method: type, set_seed): # likelihood_mean will be likelihood_shift+theta num_dim = 3 @@ -149,9 +156,14 @@ def linear_gaussian_nan( all_theta, all_x, _ = rejection_estimator.get_simulations() # Any method can be used in combination with the `RejectionEstimator`. 
- inference = SNPE_C(prior=prior) + inference = snpe_method(prior=prior) _ = inference.append_simulations(all_theta, all_x).train() - posterior = inference.build_posterior().set_default_x(x_o) + + # Build posterior. + if snpe_method == SNPE_A: + posterior = inference.build_posterior(proposal=prior).set_default_x(x_o) + else: + posterior = inference.build_posterior().set_default_x(x_o) samples = posterior.sample((num_samples,)) diff --git a/tests/linearGaussian_snpe_test.py b/tests/linearGaussian_snpe_test.py index 8c0ea9f57..a2d067d5b 100644 --- a/tests/linearGaussian_snpe_test.py +++ b/tests/linearGaussian_snpe_test.py @@ -12,7 +12,7 @@ from sbi import analysis as analysis from sbi import utils as utils -from sbi.inference import SNPE_B, SNPE_C, prepare_for_sbi, simulate_for_sbi +from sbi.inference import SNPE_A, SNPE_B, SNPE_C, prepare_for_sbi, simulate_for_sbi from sbi.simulators.linear_gaussian import ( linear_gaussian, samples_true_posterior_linear_gaussian_mvn_prior_different_dims, @@ -276,6 +276,7 @@ def simulator(theta): # Testing rejection and mcmc sampling methods. @pytest.mark.slow +@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) @pytest.mark.parametrize( "sample_with_mcmc, mcmc_method, prior_str", ( @@ -287,7 +288,7 @@ def simulator(theta): ), ) def test_api_snpe_c_posterior_correction( - sample_with_mcmc, mcmc_method, prior_str, set_seed + snpe_method, sample_with_mcmc, mcmc_method, prior_str, set_seed ): """Test that leakage correction applied to sampling works, with both MCMC and rejection. @@ -314,9 +315,8 @@ def simulator(theta): return linear_gaussian(theta, likelihood_shift, likelihood_cov) simulator, prior = prepare_for_sbi(simulator, prior) - inference = SNPE_C( + inference = snpe_method( prior, - density_estimator="maf", simulation_batch_size=50, sample_with_mcmc=sample_with_mcmc, mcmc_method=mcmc_method, @@ -325,7 +325,11 @@ def simulator(theta): theta, x = simulate_for_sbi(simulator, prior, 1000) _ = inference.append_simulations(theta, x).train(max_num_epochs=5) - posterior = inference.build_posterior() + + if snpe_method == SNPE_A: + posterior = inference.build_posterior(proposal=prior) + else: + posterior = inference.build_posterior() posterior = posterior.set_sample_with_mcmc(sample_with_mcmc).set_mcmc_method( mcmc_method ) @@ -338,7 +342,8 @@ def simulator(theta): @pytest.mark.slow -def test_sample_conditional(set_seed): +@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) +def test_sample_conditional(snpe_method, set_seed): """ Test whether sampling from the conditional gives the same results as evaluating. @@ -366,10 +371,13 @@ def simulator(theta): else: return linear_gaussian(theta, -likelihood_shift, likelihood_cov) - net = utils.posterior_nn("maf", hidden_features=20) + if snpe_method == SNPE_A: + net = utils.posterior_nn("mdn_snpe_a", num_components=5, hidden_features=20) + else: + net = utils.posterior_nn("maf", hidden_features=20) simulator, prior = prepare_for_sbi(simulator, prior) - inference = SNPE_C( + inference = snpe_method( prior, density_estimator=net, show_progress_bars=False, @@ -378,7 +386,11 @@ def simulator(theta): # We need a pretty big dataset to properly model the bimodality. 
theta, x = simulate_for_sbi(simulator, prior, 10000) _ = inference.append_simulations(theta, x).train(max_num_epochs=50) - posterior = inference.build_posterior().set_default_x(x_o) + + if snpe_method == SNPE_A: + posterior = inference.build_posterior(proposal=prior).set_default_x(x_o) + else: + posterior = inference.build_posterior().set_default_x(x_o) samples = posterior.sample((50,)) # Evaluate the conditional density be drawing samples and smoothing with a Gaussian @@ -432,7 +444,8 @@ def simulator(theta): assert max_err < 0.0025 -def example_posterior(): +@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) +def test_example_posterior(snpe_method): """Return an inferred `NeuralPosterior` for interactive examination.""" num_dim = 2 x_o = zeros(1, num_dim) @@ -449,7 +462,7 @@ def simulator(theta): return linear_gaussian(theta, likelihood_shift, likelihood_cov) simulator, prior = prepare_for_sbi(simulator, prior) - inference = SNPE_C( + inference = snpe_method( prior, show_progress_bars=False, ) @@ -457,4 +470,9 @@ def simulator(theta): simulator, prior, 1000, simulation_batch_size=10, num_workers=6 ) _ = inference.append_simulations(theta, x).train() - return inference.build_posterior().set_default_x(x_o) + + if snpe_method == SNPE_A: + posterior = inference.build_posterior(proposal=prior).set_default_x(x_o) + else: + posterior = inference.build_posterior().set_default_x(x_o) + assert posterior is not None diff --git a/tests/posterior_nn_test.py b/tests/posterior_nn_test.py index b53f783a5..763400d07 100644 --- a/tests/posterior_nn_test.py +++ b/tests/posterior_nn_test.py @@ -7,20 +7,24 @@ from torch import eye, ones, zeros from torch.distributions import MultivariateNormal -from sbi.inference import SNPE_C, prepare_for_sbi, simulate_for_sbi +from sbi.inference import SNPE_A, SNPE_C, prepare_for_sbi, simulate_for_sbi from sbi.simulators.linear_gaussian import diagonal_linear_gaussian -def test_log_prob_with_different_x(): +@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) +def test_log_prob_with_different_x(snpe_method: type): num_dim = 2 prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim)) simulator, prior = prepare_for_sbi(diagonal_linear_gaussian, prior) - inference = SNPE_C(prior) + inference = snpe_method(prior) theta, x = simulate_for_sbi(simulator, prior, 1000) _ = inference.append_simulations(theta, x).train() - posterior = inference.build_posterior() + if snpe_method == SNPE_A: + posterior = inference.build_posterior(proposal=prior) + else: + posterior = inference.build_posterior() _ = posterior.sample((10,), x=ones(1, num_dim)) theta = posterior.sample((10,), ones(1, num_dim)) diff --git a/tests/user_input_checks_test.py b/tests/user_input_checks_test.py index eaae3c555..f8803cc50 100644 --- a/tests/user_input_checks_test.py +++ b/tests/user_input_checks_test.py @@ -23,7 +23,8 @@ from torch import Tensor, eye, nn, ones, zeros from torch.distributions import Beta, Distribution, Gamma, MultivariateNormal, Uniform -from sbi.inference import SNPE_C, simulate_for_sbi +from sbi.inference import SNPE_A, SNPE_C, simulate_for_sbi +from sbi.inference.posteriors.direct_posterior import DirectPosterior from sbi.simulators.linear_gaussian import diagonal_linear_gaussian from sbi.utils.get_nn_models import posterior_nn from sbi.utils.torchutils import BoxUniform @@ -261,6 +262,7 @@ def test_prepare_sbi_problem(simulator: Callable, prior): assert prior.sample().dtype == torch.float32 +@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) 
@pytest.mark.parametrize( "user_simulator, user_prior", ( @@ -290,22 +292,34 @@ def test_prepare_sbi_problem(simulator: Callable, prior): ), ), ) -def test_inference_with_user_sbi_problems(user_simulator: Callable, user_prior): +def test_inference_with_user_sbi_problems( + snpe_method: type, user_simulator: Callable, user_prior +): """ Test inference with combinations of user defined simulators, priors and x_os. """ simulator, prior = prepare_for_sbi(user_simulator, user_prior) - inference = SNPE_C( + inference = snpe_method( prior, - density_estimator="maf", + density_estimator="mdn_snpe_a" if snpe_method == SNPE_A else "maf", show_progress_bars=False, ) # Run inference. theta, x = simulate_for_sbi(simulator, prior, 100) _ = inference.append_simulations(theta, x).train(max_num_epochs=2) - _ = inference.build_posterior() + + # Build posterior. + if snpe_method == SNPE_A: + if not isinstance(prior, (MultivariateNormal, BoxUniform, DirectPosterior)): + with pytest.raises(TypeError): + # SNPE-A does not support priors yet. + _ = inference.build_posterior(proposal=prior) + else: + _ = inference.build_posterior(proposal=prior) + else: + _ = inference.build_posterior() @pytest.mark.parametrize( @@ -514,9 +528,12 @@ def test_validate_theta_and_x_gpu(): @pytest.mark.gpu +@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) @pytest.mark.parametrize("data_device", ("cpu", "cuda:0")) @pytest.mark.parametrize("training_device", ("cpu", "cuda:0")) -def test_train_with_different_data_and_training_device(data_device, training_device): +def test_train_with_different_data_and_training_device( + snpe_method: type, data_device, training_device +): assert torch.cuda.is_available(), "gpu geared test has no GPU available" @@ -528,8 +545,11 @@ def test_train_with_different_data_and_training_device(data_device, training_dev ) simulator, prior = prepare_for_sbi(diagonal_linear_gaussian, prior_) - inference = SNPE_C( - prior, density_estimator="maf", show_progress_bars=False, device=training_device + inference = snpe_method( + prior, + density_estimator="mdn_snpe_a" if snpe_method == SNPE_A else "maf", + show_progress_bars=False, + device=training_device, ) # Run inference. @@ -543,4 +563,7 @@ def test_train_with_different_data_and_training_device(data_device, training_dev weights_device = next(inference._neural_net.parameters()).device assert torch.device(training_device) == weights_device - _ = inference.build_posterior() + if snpe_method == SNPE_A: + _ = inference.build_posterior(proposal=prior) + else: + _ = inference.build_posterior()
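For reference, a minimal end-to-end sketch of the workflow this patch enables (an
editor's illustration, not part of the patch: the toy Gaussian simulator, the 2-D
prior, the observation `x_o`, and all numeric settings are made up; the `sbi` calls
follow `examples/sb_sbi.py` and the tests above):

    import torch

    from sbi import utils
    from sbi.inference import SNPE_A, prepare_for_sbi, simulate_for_sbi

    def toy_simulator(theta):
        # Toy stand-in for a real simulator: identity plus Gaussian observation noise.
        return theta + 0.1 * torch.randn_like(theta)

    prior = utils.BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
    simulator, prior = prepare_for_sbi(toy_simulator, prior)

    num_rounds = 2
    inference = SNPE_A(
        prior, density_estimator="mdn_snpe_a", num_components=4, num_rounds=num_rounds
    )

    x_o = torch.zeros(1, 2)
    proposal = prior
    for _ in range(num_rounds):
        # Simulate from the current proposal and train (Algorithm 1, then Algorithm 2).
        theta, x = simulate_for_sbi(simulator, proposal, num_simulations=500)
        density_estimator = inference.append_simulations(theta, x, proposal).train()
        # The proposal that generated this round's data is needed for the correction.
        posterior = inference.build_posterior(
            proposal=proposal, density_estimator=density_estimator
        )
        proposal = posterior.set_default_x(x_o)

    samples = posterior.sample((1000,), x=x_o)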