From 96d3a3b0a141663c88afe03015d20fd7eddfa3e0 Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Mon, 3 May 2021 17:04:26 +0200 Subject: [PATCH] SNPE-A review changes --- sbi/inference/snpe/snpe_a.py | 87 +++++++++++++------ sbi/inference/snpe/snpe_c.py | 11 +-- sbi/neural_nets/mdn.py | 8 +- sbi/neural_nets/mdn_snpe_a.py | 70 --------------- ...g_flow_snpe_a.py => mdn_wrapper_snpe_a.py} | 78 +++++++++-------- sbi/utils/__init__.py | 2 +- sbi/utils/conditional_density.py | 44 ---------- sbi/utils/get_nn_models.py | 3 +- sbi/utils/sbiutils.py | 44 ++++++++++ tests/linearGaussian_snpe_test.py | 27 ++++-- 10 files changed, 177 insertions(+), 197 deletions(-) delete mode 100644 sbi/neural_nets/mdn_snpe_a.py rename sbi/neural_nets/{mog_flow_snpe_a.py => mdn_wrapper_snpe_a.py} (90%) diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py index 04e93cb41..d8b270e26 100644 --- a/sbi/inference/snpe/snpe_a.py +++ b/sbi/inference/snpe/snpe_a.py @@ -14,6 +14,7 @@ import sbi.utils as utils from sbi.inference.posteriors.direct_posterior import DirectPosterior from sbi.inference.snpe.snpe_base import PosteriorEstimator +from sbi.neural_nets.mdn_wrapper_snpe_a import MDNWrapper_SNPE_A from sbi.types import TensorboardSummaryWriter, TorchModule @@ -32,9 +33,6 @@ def __init__( ): r"""SNPE-A [1]. - https://github.com/mackelab/sbi/blob/main/sbi/inference/snpe/snpe_c.py - https://github.com/mackelab/sbi/blob/main/sbi/neural_nets/mdn.py - [1] _Fast epsilon-free Inference of Simulation Models with Bayesian Conditional Density Estimation_, Papamakarios et al., NeurIPS 2016, https://arxiv.org/abs/1605.06376. @@ -44,18 +42,24 @@ def __init__( parameters, e.g. which ranges are meaningful for them. Any object with `.log_prob()`and `.sample()` (for example, a PyTorch distribution) can be used. - density_estimator: If it is a string, use a pre-configured network of the - provided type (one of nsf, maf, mdn, made). Alternatively, a function + density_estimator: If it is a string (only "mdn_snpe_a" is valid), use a + pre-configured mixture of densities network. Alternatively, a function that builds a custom neural network can be provided. The function will be called with the first batch of simulations (theta, x), which can thus be used for shape inference and potentially for z-scoring. It needs to return a PyTorch `nn.Module` implementing the density estimator. The density estimator needs to provide the methods `.log_prob` and `.sample()`. + Note that until the last round only a single (multivariate) Gaussian + component is used for training (see Algorithm 1 in [1]). In the last + round, this component is replicated `num_components` times, its parameters + are perturbed with a very small noise, and then the last training round + is done with the expanded Gaussian mixture as estimator for the + proposal posterior. num_components: - Number of components of the mixture of Gaussians. This number is set to 1 before - running Algorithm 1, and then later set to the specified value before running - Algorithm 2. + Number of components of the mixture of Gaussians. This number is set to + 1 before running Algorithm 1, and then later set to the specified value + before running Algorithm 2. num_rounds: Total number of training rounds. For all but the last round, Algorithm 1 from [1] is executed. For last round, Algorithm 2 from [1] is executed once. By default, `num_rounds` is set to 1, i.e. only Algorithm 2 is executed once @@ -71,8 +75,14 @@ def __init__( 0.14.0 is more mature, we will remove this argument. 
""" + # Catch invalid inputs. + if not ((density_estimator == "mdn_snpe_a") or callable(density_estimator)): + raise TypeError( + "The `density_estimator` passed to SNPE_A needs to be a " + "callable or the string 'mdn_snpe_a'!" + ) + self._num_rounds = num_rounds - # TODO How to extract the number of components from density_estimator if that is a callable? self._num_components = num_components # WARNING: sneaky trick ahead. We proxy the parent's `train` here, @@ -80,7 +90,14 @@ def __init__( # continue. It's sneaky because we are using the object (self) as a namespace # to pass arguments between functions, and that's implicit state management. kwargs = utils.del_entries( - locals(), entries=("self", "__class__", "unused_args", "num_rounds", "num_components") + locals(), + entries=( + "self", + "__class__", + "unused_args", + "num_rounds", + "num_components", + ), ) super().__init__(**kwargs) @@ -100,7 +117,9 @@ def train( dataloader_kwargs: Optional[Dict] = None, ) -> DirectPosterior: r""" - Return density estimator that approximates the distribution $p(\theta|x)$. + Return density estimator that approximates the proposal posterior's + distribution $\tilde{p}(\theta|x)$. + Args: training_batch_size: Training batch size. learning_rate: Learning rate for Adam optimizer. @@ -170,7 +189,9 @@ def train( def build_posterior( self, - proposal: Union[MultivariateNormal, utils.BoxUniform, DirectPosterior], + proposal: Optional[ + Union[MultivariateNormal, utils.BoxUniform, DirectPosterior] + ] = None, density_estimator: Optional[TorchModule] = None, rejection_sampling_parameters: Optional[Dict[str, Any]] = None, sample_with_mcmc: bool = False, @@ -180,7 +201,7 @@ def build_posterior( r""" Build posterior from the neural density estimator. - For SNPE, the posterior distribution that is returned here implements the TODO + For SNPE, the posterior distribution that is returned here implements the following functionality over the raw neural density estimator: - correct the calculation of the log probability such that it compensates for @@ -189,8 +210,14 @@ def build_posterior( - alternatively, if leakage is very high (which can happen for multi-round SNPE), sample from the posterior with MCMC. + The DirectPosterior class assumes that the density estimator approximates the posterior. + In SNPE-A, the density estimator is an approximation of the proposal posterior. Hence, importance reweigthing + is needed during evaluation. + Args: - proposal: The distribution that the parameters $\theta$ were sampled from. + proposal: The proposal prior obtained from the from maximum-likelihood training. + If None, the posterior of the previous previous round is used. In the first round + the prior is used as a proposal. density_estimator: The density estimator that the posterior is based on. If `None`, use the latest neural density estimator that was trained. rejection_sampling_parameters: Dictionary overriding the default parameters @@ -229,20 +256,30 @@ def build_posterior( # Set proposal of the density estimator. # This also evokes the z-scoring correction is necessary. - if isinstance(proposal, (MultivariateNormal, utils.BoxUniform)): - density_estimator.set_proposal(proposal) + if proposal is None: + if self._model_bank: + proposal = self._model_bank[-1].net + else: + proposal = self._prior + elif isinstance(proposal, (MultivariateNormal, utils.BoxUniform)): + pass elif isinstance(proposal, DirectPosterior): - # Extract the MoGFlow_SNPE_A from the DirectPosterior. 
- density_estimator.set_proposal(proposal.net) + # Extract the MDNWrapper_SNPE_A from the DirectPosterior. + proposal = proposal.net else: raise TypeError( "So far, only MultivariateNormal, BoxUniform, and DirectPosterior are" "supported for the `proposal` arg in SNPE_A.build_posterior()." ) + # Create the MDNWrapper_SNPE_A + wrapped_density_estimator = MDNWrapper_SNPE_A( + flow=density_estimator, proposal=proposal + ) + self._posterior = DirectPosterior( method_family="snpe", - neural_net=density_estimator, + neural_net=wrapped_density_estimator, prior=self._prior, x_shape=self._x_shape, rejection_sampling_parameters=rejection_sampling_parameters, @@ -266,13 +303,8 @@ def _log_prob_proposal_posterior( """ Return the log-probability of the proposal posterior. - .. note:: - For SNPE-A this is the same as `self._neural_net.log_prob(theta, x)` in - `_loss()` to be found in `snpe_base.py`. - - If the proposal is a MoG, the density estimator is a MoG, and the prior is - either Gaussian or uniform, we use non-atomic loss. Else, use atomic loss (which - suffers from leakage). + For SNPE-A this is the same as `self._neural_net.log_prob(theta, x)` in + `_loss()` to be found in `snpe_base.py`. Args: theta: Batch of parameters θ. @@ -293,7 +325,8 @@ def _expand_mog(self, eps: float = 1e-5): symmetry such that the gradients in the subsequent training are not all identical. - :param eps: Standard deviation for the random perturbation. + Args: + eps: Standard deviation for the random perturbation. """ assert isinstance(self._neural_net._distribution, MultivariateGaussianMDN) diff --git a/sbi/inference/snpe/snpe_c.py b/sbi/inference/snpe/snpe_c.py index b298a5155..775e0a06b 100644 --- a/sbi/inference/snpe/snpe_c.py +++ b/sbi/inference/snpe/snpe_c.py @@ -2,7 +2,6 @@ # under the Affero General Public License v3, see . -from math import pi from typing import Any, Callable, Dict, Optional, Union import torch @@ -252,7 +251,7 @@ def _set_maybe_z_scored_prior(self) -> None: if isinstance(self._prior, MultivariateNormal): self._maybe_z_scored_prior = MultivariateNormal( - almost_zero_mean, torch.diag(almost_one_std), + almost_zero_mean, torch.diag(almost_one_std) ) else: range_ = torch.sqrt(almost_one_std * 3.0) @@ -417,7 +416,9 @@ def _log_prob_proposal_posterior_mog( ) # Compute the log_prob of theta under the product. - log_prob_proposal_posterior = utils.mog_log_prob(theta, logits_pp, m_pp, prec_pp) + log_prob_proposal_posterior = utils.mog_log_prob( + theta, logits_pp, m_pp, prec_pp + ) utils.assert_all_finite(log_prob_proposal_posterior, "proposal posterior eval") return log_prob_proposal_posterior @@ -471,7 +472,7 @@ def _automatic_posterior_transformation( ) means_pp = self._means_proposal_posterior( - covariances_pp, means_p, precisions_p, means_d, precisions_d, + covariances_pp, means_p, precisions_p, means_d, precisions_d ) logits_pp = self._logits_proposal_posterior( @@ -489,7 +490,7 @@ def _automatic_posterior_transformation( return logits_pp, means_pp, precisions_pp, covariances_pp def _precisions_proposal_posterior( - self, precisions_p: Tensor, precisions_d: Tensor, + self, precisions_p: Tensor, precisions_d: Tensor ): """ Return the precisions and covariances of the proposal posterior. 
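For intuition about `_expand_mog` above: before the last round, the single Gaussian fitted by Algorithm 1 is replicated `num_components` times and perturbed with very small noise so that Algorithm 2 can break the symmetry during training. Below is a simplified stand-alone sketch of that idea, with made-up tensors and the perturbation applied directly to the means rather than to the network parameters as in the actual code.

import torch

# Hypothetical single component produced by Algorithm 1 (values made up).
mean = torch.tensor([0.3, -1.2])      # shape (theta_dim,)
cov = 0.5 * torch.eye(2)              # shape (theta_dim, theta_dim)

num_components = 5
eps = 1e-5  # cf. the default of `_expand_mog(eps=1e-5)`

# Replicate the component and nudge each copy's mean by tiny noise.
means = mean.repeat(num_components, 1) + eps * torch.randn(num_components, 2)
covs = cov.repeat(num_components, 1, 1)
logits = torch.zeros(num_components)  # equal (unnormalized) mixture weights

# All components start out (almost) identical, but their gradients now differ.
print(means)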
diff --git a/sbi/neural_nets/mdn.py b/sbi/neural_nets/mdn.py index 900bc5796..2de2da657 100644 --- a/sbi/neural_nets/mdn.py +++ b/sbi/neural_nets/mdn.py @@ -6,7 +6,7 @@ from pyknos.nflows import flows, transforms from torch import Tensor, nn -from sbi.utils.sbiutils import standardizing_net, standardizing_transform +import sbi.utils as utils def build_mdn( @@ -17,7 +17,7 @@ def build_mdn( hidden_features: int = 50, num_components: int = 10, embedding_net: nn.Module = nn.Identity(), - **kwargs + **kwargs, ) -> nn.Module: """Builds MDN p(x|y). @@ -42,11 +42,11 @@ def build_mdn( transform = transforms.IdentityTransform() if z_score_x: - transform_zx = standardizing_transform(batch_x) + transform_zx = utils.standardizing_transform(batch_x) transform = transforms.CompositeTransform([transform_zx, transform]) if z_score_y: - embedding_net = nn.Sequential(standardizing_net(batch_y), embedding_net) + embedding_net = nn.Sequential(utils.standardizing_net(batch_y), embedding_net) distribution = MultivariateGaussianMDN( features=x_numel, diff --git a/sbi/neural_nets/mdn_snpe_a.py b/sbi/neural_nets/mdn_snpe_a.py deleted file mode 100644 index 044aa0bf1..000000000 --- a/sbi/neural_nets/mdn_snpe_a.py +++ /dev/null @@ -1,70 +0,0 @@ -# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed -# under the Affero General Public License v3, see . - -from pyknos.mdn.mdn import MultivariateGaussianMDN -from pyknos.nflows import transforms -from torch import Tensor, nn - -from sbi.neural_nets.mog_flow_snpe_a import MoGFlow_SNPE_A -from sbi.utils.sbiutils import standardizing_net, standardizing_transform - - -def build_mdn_snpe_a( - batch_x: Tensor = None, - batch_y: Tensor = None, - z_score_x: bool = True, - z_score_y: bool = True, - hidden_features: int = 50, - num_components: int = 10, - embedding_net: nn.Module = nn.Identity(), - **kwargs, -) -> nn.Module: - """Builds MDN p(x|y) with different sampling in training and evaluation mode. - - Args: - batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. - batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. - z_score_x: Whether to z-score xs passing into the network. - z_score_y: Whether to z-score ys passing into the network. - hidden_features: Number of hidden features. - num_components: Number of components. - embedding_net: Optional embedding network for y. - kwargs: Additional arguments that are passed by the build function but are not - relevant for MDNs and are therefore ignored. - - Returns: - Neural network. - """ - x_numel = batch_x[0].numel() - # Infer the output dimensionality of the embedding_net by making a forward pass. 
- y_numel = embedding_net(batch_y[:1]).numel() - - transform = transforms.IdentityTransform() - - if z_score_x: - transform_zx = standardizing_transform(batch_x) - transform = transforms.CompositeTransform([transform_zx, transform]) - - if z_score_y: - embedding_net = nn.Sequential(standardizing_net(batch_y), embedding_net) - - distribution = MultivariateGaussianMDN( - features=x_numel, - context_features=y_numel, - hidden_features=hidden_features, - hidden_net=nn.Sequential( - nn.Linear(y_numel, hidden_features), - nn.ReLU(), - nn.Dropout(p=0.0), - nn.Linear(hidden_features, hidden_features), - nn.ReLU(), - nn.Linear(hidden_features, hidden_features), - nn.ReLU(), - ), - num_components=num_components, - custom_initialization=True, - ) - - neural_net = MoGFlow_SNPE_A(transform, distribution, embedding_net) - - return neural_net diff --git a/sbi/neural_nets/mog_flow_snpe_a.py b/sbi/neural_nets/mdn_wrapper_snpe_a.py similarity index 90% rename from sbi/neural_nets/mog_flow_snpe_a.py rename to sbi/neural_nets/mdn_wrapper_snpe_a.py index 1445d57a7..0d9488de0 100644 --- a/sbi/neural_nets/mog_flow_snpe_a.py +++ b/sbi/neural_nets/mdn_wrapper_snpe_a.py @@ -5,6 +5,7 @@ from warnings import warn import torch +import torch.nn as nn from pyknos.mdn.mdn import MultivariateGaussianMDN from pyknos.nflows import flows from pyknos.nflows.transforms import CompositeTransform @@ -12,13 +13,16 @@ from torch.distributions import MultivariateNormal import sbi.utils as utils +import sbi.utils.sbiutils from sbi.utils import torchutils -class MoGFlow_SNPE_A(flows.Flow): +class MDNWrapper_SNPE_A(nn.Module): """ - A wrapper for nflow's `Flow` class to enable a different log prob calculation - sampling strategy for training and testing, tailored to SNPE-A [1] + A wrapper containing a `flows.Flow` class to enable a different log prob calculation + sampling strategy for training and testing, tailored to SNPE-A [1]. + This class inherits from `nn.Module` since the constructor of `DirectPosterior` + expects the argument `neural_net` to be a `nn.Module`. [1] _Fast epsilon-free Inference of Simulation Models with Bayesian Conditional Density Estimation_, Papamakarios et al., NeurIPS 2016, @@ -29,38 +33,40 @@ class MoGFlow_SNPE_A(flows.Flow): def __init__( self, - transform, - distribution, - embedding_net=None, + flow: flows.Flow, + proposal: Union["utils.BoxUniform", MultivariateNormal, "MDNWrapper_SNPE_A"], allow_precision_correction: bool = False, ): """Constructor. Args: - transform: A `Transform` object, it transforms data into noise. - distribution: A `Distribution` object, the base distribution of the flow that - generates the noise. - embedding_net: A `nn.Module` which has trainable parameters to encode the - context (condition). It is trained jointly with the flow. + flow: The trained normalizing flow, passed when building the posterior. + proposal: The proposal distribution. allow_precision_correction: Add a diagonal with the smallest eigenvalue in every entry in case the precision matrix becomes ill-conditioned. """ - # Construct the flow. - super().__init__(transform, distribution, embedding_net) + # Call nn.Module's constructor. + super().__init__() - self._proposal = None + self._neural_net = flow self._allow_precision_correction = allow_precision_correction + # Set the proposal. + self._proposal = proposal + # Take care of z-scoring, pre-compute and store prior terms. 
+ self._set_state_for_mog_proposal() + @property def proposal( self, - ) -> Union["utils.BoxUniform", MultivariateNormal, "MoGFlow_SNPE_A"]: + ) -> Union["utils.BoxUniform", MultivariateNormal, "MDNWrapper_SNPE_A"]: """Get the proposal of the previous round.""" return self._proposal def set_proposal( - self, proposal: Union["utils.BoxUniform", MultivariateNormal, "MoGFlow_SNPE_A"] + self, + proposal: Union["utils.BoxUniform", MultivariateNormal, "MDNWrapper_SNPE_A"], ): """Set the proposal of the previous round.""" self._proposal = proposal @@ -70,7 +76,7 @@ def set_proposal( def _get_first_prior_from_proposal( self, - ) -> Union["utils.BoxUniform", MultivariateNormal, "MoGFlow_SNPE_A"]: + ) -> Union["utils.BoxUniform", MultivariateNormal, "MDNWrapper_SNPE_A"]: """Iterate a possible chain of proposals.""" curr_prior = self._proposal @@ -87,13 +93,15 @@ def log_prob(self, inputs, context=None): if self._proposal is None: # Use Flow.lob_prob() if there has been no previous proposal memorized # in this instance. This is the case if we are in the training - # loop, i.e. this MoGFlow_SNPE_A instance is not an attribute of a + # loop, i.e. this MDNWrapper_SNPE_A instance is not an attribute of a # DirectPosterior instance. - return super().log_prob(inputs, context) # q_phi from eq (3) in [1] + return self._neural_net.log_prob( + inputs, context + ) # q_phi from eq (3) in [1] elif isinstance(self._proposal, (utils.BoxUniform, MultivariateNormal)): # No importance re-weighting is needed if the proposal prior is the prior - return super().log_prob(inputs, context) + return self._neural_net.log_prob(inputs, context) else: # When we want to compute the approx. posterior, a proposal prior \tilde{p} @@ -107,7 +115,7 @@ def log_prob(self, inputs, context=None): theta = self._maybe_z_score_theta(inputs) # Compute the log_prob of theta under the product. - log_prob_proposal_posterior = utils.mog_log_prob( + log_prob_proposal_posterior = sbi.utils.sbiutils.mog_log_prob( theta, logits_pp, m_pp, @@ -122,14 +130,14 @@ def sample(self, num_samples, context=None, batch_size=None) -> Tensor: if self._proposal is None: # Use Flow.sample() if there has been no previous proposal memorized # in this instance. This is the case if we are in the training - # loop, i.e. this MoGFlow_SNPE_A instance is not an attribute of a + # loop, i.e. this MDNWrapper_SNPE_A instance is not an attribute of a # DirectPosterior instance. - return super().sample(num_samples, context, batch_size) + return self._neural_net.sample(num_samples, context, batch_size) else: # No importance re-weighting is needed if the proposal prior is the prior if isinstance(self._proposal, (utils.BoxUniform, MultivariateNormal)): - return super().sample(num_samples, context, batch_size) + return self._neural_net.sample(num_samples, context, batch_size) # When we want to sample from the approx. posterior, a proposal prior \tilde{p} # has already been observed. To analytically calculate the log-prob of the @@ -174,7 +182,7 @@ def _sample_approx_posterior_mog( num_samples, logits_pp, m_pp, prec_factors_pp ) - embedded_context = self._embedding_net(x) + embedded_context = self._neural_net._embedding_net(x) if embedded_context is not None: # Merge the context dimension with sample dimension in order to # apply the transform. 
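The closed-form MoG manipulations in this class implement the SNPE-A correction p(\theta|x) \propto q_\phi(\theta|x) * p(\theta) / \tilde{p}(\theta), where \tilde{p} is the proposal prior. Below is a 1-D sanity check of that relation using plain torch distributions; all numbers are made up, and the wrapper itself never evaluates densities on a grid but combines the Gaussian parameters analytically.

import torch
from torch.distributions import Normal

prior = Normal(loc=0.0, scale=2.0)     # p(theta)
proposal = Normal(loc=1.0, scale=1.0)  # \tilde{p}(theta), the proposal prior
q = Normal(loc=0.8, scale=0.5)         # q_phi(theta|x), the learned proposal posterior

# For Gaussians, precisions and precision-weighted means combine as
#   P_post = P_q + P_prior - P_prop
#   P_post * m_post = P_q * m_q + P_prior * m_prior - P_prop * m_prop
p_q, p_prior, p_prop = q.scale ** -2, prior.scale ** -2, proposal.scale ** -2
p_post = p_q + p_prior - p_prop
m_post = (p_q * q.loc + p_prior * prior.loc - p_prop * proposal.loc) / p_post

# Numerical check: normalize exp(log q + log prior - log proposal) on a grid.
theta = torch.linspace(-6.0, 6.0, 10001)
density = torch.exp(q.log_prob(theta) + prior.log_prob(theta) - proposal.log_prob(theta))
density = density / torch.trapz(density, theta)
grid_mean = torch.trapz(theta * density, theta)

print(float(m_post), float(grid_mean))  # both approx. 0.68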
@@ -183,7 +191,7 @@ def _sample_approx_posterior_mog( embedded_context, num_reps=num_samples ) - theta, _ = self._transform.inverse(theta, context=embedded_context) + theta, _ = self._neural_net._transform.inverse(theta, context=embedded_context) if embedded_context is not None: # Split the context dimension from sample dimension. @@ -204,8 +212,8 @@ def _get_mixture_components(self, x: Tensor): """ # Evaluate the density estimator. - encoded_x = self._embedding_net(x) - dist = self._distribution # defined to avoid black formatting. + encoded_x = self._neural_net._embedding_net(x) + dist = self._neural_net._distribution # defined to avoid black formatting. logits_d, m_d, prec_d, _, _ = dist.get_mixture_components(encoded_x) norm_logits_d = logits_d - torch.logsumexp(logits_d, dim=-1, keepdim=True) @@ -284,7 +292,7 @@ def _proposal_posterior_transformation( precisions_d, ) - logits_post = MoGFlow_SNPE_A._logits_posterior( + logits_post = MDNWrapper_SNPE_A._logits_posterior( means_post, precisions_post, covariances_post, @@ -300,7 +308,7 @@ def _proposal_posterior_transformation( def _set_state_for_mog_proposal(self) -> None: """ - Set state variables of the MoGFlow_SNPE_A instance evevy time `set_proposal()` + Set state variables of the MDNWrapper_SNPE_A instance evevy time `set_proposal()` is called, i.e. every time a posterior is build using `SNPE_A.build_posterior()`. This function is almost identical to `SNPE_C._set_state_for_mog_proposal()`. @@ -314,7 +322,7 @@ def _set_state_for_mog_proposal(self) -> None: training step if the prior is Gaussian. """ - self.z_score_theta = isinstance(self._transform, CompositeTransform) + self.z_score_theta = isinstance(self._neural_net._transform, CompositeTransform) self._set_maybe_z_scored_prior() @@ -348,8 +356,8 @@ def _set_maybe_z_scored_prior(self) -> None: prior = self._get_first_prior_from_proposal() if self.z_score_theta: - scale = self._transform._transforms[0]._scale - shift = self._transform._transforms[0]._shift + scale = self._neural_net._transform._transforms[0]._scale + shift = self._neural_net._transform._transforms[0]._shift # Following the definition of the linear transform in # `standardizing_transform` in `sbiutils.py`: @@ -383,7 +391,7 @@ def _maybe_z_score_theta(self, theta: Tensor) -> Tensor: """Return potentially standardized theta if z-scoring was requested.""" if self.z_score_theta: - theta, _ = self._transform(theta) + theta, _ = self._neural_net._transform(theta) return theta @@ -452,7 +460,7 @@ def _precisions_posterior(self, precisions_pprior: Tensor, precisions_d: Tensor) "The precision matrix of a posterior is not positive definite! " "This is a known issue for SNPE-A. Either try a different parameter " "setting or pass `allow_precision_correction=True` when constructing " - "the `MoGFlow_SNPE_A` density estimator." + "the `MDNWrapper_SNPE_A` density estimator." 
) covariances_p = torch.inverse(precisions_p) diff --git a/sbi/utils/__init__.py b/sbi/utils/__init__.py index 799b22448..dc6e9f77c 100644 --- a/sbi/utils/__init__.py +++ b/sbi/utils/__init__.py @@ -2,7 +2,6 @@ from sbi.utils.conditional_density import ( conditional_corrcoeff, eval_conditional_density, - mog_log_prob, ) from sbi.utils.get_nn_models import classifier_nn, likelihood_nn, posterior_nn from sbi.utils.io import get_data_root, get_log_root, get_project_root @@ -19,6 +18,7 @@ handle_invalid_x, logit, mask_sims_from_prior, + mog_log_prob, sample_posterior_within_prior, standardizing_net, standardizing_transform, diff --git a/sbi/utils/conditional_density.py b/sbi/utils/conditional_density.py index 3675ad018..07b194ef8 100644 --- a/sbi/utils/conditional_density.py +++ b/sbi/utils/conditional_density.py @@ -1,14 +1,12 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Affero General Public License v3, see . -from math import pi from typing import Any, Callable, List, Optional, Tuple, Union from warnings import warn import torch from torch import Tensor -import sbi.utils as utils from sbi.utils.torchutils import ensure_theta_batched @@ -332,45 +330,3 @@ def _normalize_probs(probs: Tensor, limits: Tensor) -> Tensor: """ limits_diff = torch.prod(limits[:, 1] - limits[:, 0]) return probs * probs.numel() / limits_diff / torch.sum(probs) - - -def mog_log_prob( - theta: Tensor, - logits_pp: Tensor, - means_pp: Tensor, - precisions_pp: Tensor, -) -> Tensor: - r""" - Returns the log-probability of parameter sets $\theta$ under a mixture of Gaussians. - - Note that the mixture can have different logits, means, covariances for any theta in - the batch. This is because these values were computed from a batch of $x$ (and the - $x$ in the batch are not the same). - - This code is similar to the code of mdn.py in pyknos, but it does not use - log(det(Cov)) = -2*sum(log(diag(L))), L being Cholesky of Precision. Instead, it - just computes log(det(Cov)). Also, it uses the above-defined helper - `_batched_vmv()`. - - Args: - theta: Parameters at which to evaluate the mixture. - logits_pp: (Unnormalized) mixture components. - means_pp: Means of all mixture components. Shape - (batch_dim, num_components, theta_dim). - precisions_pp: Precisions of all mixtures. Shape - (batch_dim, num_components, theta_dim, theta_dim). - - Returns: The log-probability. - """ - - _, _, output_dim = means_pp.size() - theta = theta.view(-1, 1, output_dim) - - # Split up evaluation into parts. - weights = logits_pp - torch.logsumexp(logits_pp, dim=-1, keepdim=True) - constant = -(output_dim / 2.0) * torch.log(torch.tensor([2 * pi])) - log_det = 0.5 * torch.log(torch.det(precisions_pp)) - theta_minus_mean = theta.expand_as(means_pp) - means_pp - exponent = -0.5 * utils.batched_mixture_vmv(precisions_pp, theta_minus_mean) - - return torch.logsumexp(weights + constant + log_det + exponent, dim=-1) diff --git a/sbi/utils/get_nn_models.py b/sbi/utils/get_nn_models.py index a3bf6744a..70e26af52 100644 --- a/sbi/utils/get_nn_models.py +++ b/sbi/utils/get_nn_models.py @@ -13,7 +13,6 @@ ) from sbi.neural_nets.flow import build_made, build_maf, build_nsf from sbi.neural_nets.mdn import build_mdn -from sbi.neural_nets.mdn_snpe_a import build_mdn_snpe_a def classifier_nn( @@ -222,7 +221,7 @@ def build_fn(batch_theta, batch_x, num_components): # override this kwarg with functools.partial. 
This is necessary # in order to make sure that the MDN in SNPE-A only has one # component when running the Algorithm 1 part. - return build_mdn_snpe_a( + return build_mdn( batch_x=batch_theta, batch_y=batch_x, num_components=num_components, diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index 528344b9b..5e22e61e2 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -3,6 +3,7 @@ import logging import warnings +from math import pi from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import torch @@ -15,6 +16,7 @@ from torch.distributions.distribution import Distribution from tqdm.auto import tqdm +from sbi import utils as utils from sbi.utils.torchutils import atleast_2d @@ -547,3 +549,45 @@ def log_prob(self, value: Tensor) -> Tensor: """ value = atleast_2d(value) return zeros(value.shape[0]) + + +def mog_log_prob( + theta: Tensor, + logits_pp: Tensor, + means_pp: Tensor, + precisions_pp: Tensor, +) -> Tensor: + r""" + Returns the log-probability of parameter sets $\theta$ under a mixture of Gaussians. + + Note that the mixture can have different logits, means, covariances for any theta in + the batch. This is because these values were computed from a batch of $x$ (and the + $x$ in the batch are not the same). + + This code is similar to the code of mdn.py in pyknos, but it does not use + log(det(Cov)) = -2*sum(log(diag(L))), L being Cholesky of Precision. Instead, it + just computes log(det(Cov)). Also, it uses the above-defined helper + `_batched_vmv()`. + + Args: + theta: Parameters at which to evaluate the mixture. + logits_pp: (Unnormalized) mixture components. + means_pp: Means of all mixture components. Shape + (batch_dim, num_components, theta_dim). + precisions_pp: Precisions of all mixtures. Shape + (batch_dim, num_components, theta_dim, theta_dim). + + Returns: The log-probability. + """ + + _, _, output_dim = means_pp.size() + theta = theta.view(-1, 1, output_dim) + + # Split up evaluation into parts. + weights = logits_pp - torch.logsumexp(logits_pp, dim=-1, keepdim=True) + constant = -(output_dim / 2.0) * torch.log(torch.tensor([2 * pi])) + log_det = 0.5 * torch.log(torch.det(precisions_pp)) + theta_minus_mean = theta.expand_as(means_pp) - means_pp + exponent = -0.5 * utils.batched_mixture_vmv(precisions_pp, theta_minus_mean) + + return torch.logsumexp(weights + constant + log_det + exponent, dim=-1) diff --git a/tests/linearGaussian_snpe_test.py b/tests/linearGaussian_snpe_test.py index a2d067d5b..e087afce8 100644 --- a/tests/linearGaussian_snpe_test.py +++ b/tests/linearGaussian_snpe_test.py @@ -288,7 +288,7 @@ def simulator(theta): ), ) def test_api_snpe_c_posterior_correction( - snpe_method, sample_with_mcmc, mcmc_method, prior_str, set_seed + snpe_method: type, sample_with_mcmc, mcmc_method, prior_str, set_seed ): """Test that leakage correction applied to sampling works, with both MCMC and rejection. @@ -342,8 +342,10 @@ def simulator(theta): @pytest.mark.slow -@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) -def test_sample_conditional(snpe_method, set_seed): +@pytest.mark.parametrize( + "snpe_method, num_rounds", [(SNPE_A, 2), (SNPE_A, 5), (SNPE_C, None)] +) +def test_sample_conditional(snpe_method: type, num_rounds: int, set_seed): """ Test whether sampling from the conditional gives the same results as evaluating. 
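Since `mog_log_prob` is now re-exported from `sbi.utils` (see the `__init__.py` hunk above), it can be sanity-checked against torch's own mixture distributions. A small sketch with made-up shapes following the docstring, i.e. (batch_dim, num_components, theta_dim):

import torch
from torch.distributions import Categorical, MixtureSameFamily, MultivariateNormal

from sbi.utils import mog_log_prob

batch, K, D = 3, 2, 2
logits = torch.randn(batch, K)
means = torch.randn(batch, K, D)
# Diagonal precisions keep the example simple and guarantee positive definiteness.
precisions = torch.diag_embed(torch.rand(batch, K, D) + 0.5)
theta = torch.randn(batch, D)

log_probs = mog_log_prob(theta, logits, means, precisions)

# Reference: evaluate each batch entry with torch.distributions.
expected = torch.stack(
    [
        MixtureSameFamily(
            Categorical(logits=logits[b]),
            MultivariateNormal(means[b], precision_matrix=precisions[b]),
        ).log_prob(theta[b])
        for b in range(batch)
    ]
)
assert torch.allclose(log_probs, expected, atol=1e-4)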
@@ -373,14 +375,17 @@ def simulator(theta): if snpe_method == SNPE_A: net = utils.posterior_nn("mdn_snpe_a", num_components=5, hidden_features=20) + extra_kwargs = dict(num_rounds=num_rounds) else: net = utils.posterior_nn("maf", hidden_features=20) + extra_kwargs = dict() simulator, prior = prepare_for_sbi(simulator, prior) inference = snpe_method( prior, density_estimator=net, show_progress_bars=False, + **extra_kwargs, ) # We need a pretty big dataset to properly model the bimodality. @@ -444,8 +449,10 @@ def simulator(theta): assert max_err < 0.0025 -@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) -def test_example_posterior(snpe_method): +@pytest.mark.parametrize( + "snpe_method, num_rounds", [(SNPE_A, 2), (SNPE_A, 10), (SNPE_C, None)] +) +def test_example_posterior(snpe_method: type, num_rounds: int): """Return an inferred `NeuralPosterior` for interactive examination.""" num_dim = 2 x_o = zeros(1, num_dim) @@ -461,11 +468,13 @@ def test_example_posterior(snpe_method): def simulator(theta): return linear_gaussian(theta, likelihood_shift, likelihood_cov) + if snpe_method == SNPE_A: + extra_kwargs = dict(num_rounds=num_rounds) + else: + extra_kwargs = dict() + simulator, prior = prepare_for_sbi(simulator, prior) - inference = snpe_method( - prior, - show_progress_bars=False, - ) + inference = snpe_method(prior, show_progress_bars=False, **extra_kwargs) theta, x = simulate_for_sbi( simulator, prior, 1000, simulation_batch_size=10, num_workers=6 )
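Putting the pieces together, here is a rough end-to-end sketch of how SNPE-A is meant to be driven after this patch, assembled from the constructor docstring and the tests above. The simulator, observation, simulation budget, and the multi-round bookkeeping (`append_simulations(..., from_round=...)`, `set_default_x`) are placeholders that follow the general sbi interface of this period rather than anything introduced here.

import torch

from sbi import utils
from sbi.inference import SNPE_A, prepare_for_sbi, simulate_for_sbi
from sbi.simulators.linear_gaussian import linear_gaussian

num_dim, num_rounds = 2, 3
prior = utils.BoxUniform(-2 * torch.ones(num_dim), 2 * torch.ones(num_dim))
x_o = torch.zeros(1, num_dim)

def simulator(theta):
    # Toy linear-Gaussian simulator; any callable mapping theta to x would do.
    return linear_gaussian(theta, -1.0 * torch.ones(num_dim), 0.3 * torch.eye(num_dim))

simulator, prior = prepare_for_sbi(simulator, prior)

# Only the string "mdn_snpe_a" (or a callable) is accepted as density estimator.
net = utils.posterior_nn("mdn_snpe_a", num_components=5, hidden_features=20)
inference = SNPE_A(prior, density_estimator=net, num_rounds=num_rounds)

proposal = prior
for r in range(num_rounds):
    theta, x = simulate_for_sbi(simulator, proposal, num_simulations=500)
    density_estimator = inference.append_simulations(theta, x, from_round=r).train()
    posterior = inference.build_posterior(density_estimator=density_estimator)
    proposal = posterior.set_default_x(x_o)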