diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py
index 04e93cb41..d8b270e26 100644
--- a/sbi/inference/snpe/snpe_a.py
+++ b/sbi/inference/snpe/snpe_a.py
@@ -14,6 +14,7 @@
import sbi.utils as utils
from sbi.inference.posteriors.direct_posterior import DirectPosterior
from sbi.inference.snpe.snpe_base import PosteriorEstimator
+from sbi.neural_nets.mdn_wrapper_snpe_a import MDNWrapper_SNPE_A
from sbi.types import TensorboardSummaryWriter, TorchModule
@@ -32,9 +33,6 @@ def __init__(
):
r"""SNPE-A [1].
- https://github.com/mackelab/sbi/blob/main/sbi/inference/snpe/snpe_c.py
- https://github.com/mackelab/sbi/blob/main/sbi/neural_nets/mdn.py
-
[1] _Fast epsilon-free Inference of Simulation Models with Bayesian Conditional
Density Estimation_, Papamakarios et al., NeurIPS 2016,
https://arxiv.org/abs/1605.06376.
@@ -44,18 +42,24 @@ def __init__(
parameters, e.g. which ranges are meaningful for them. Any
object with `.log_prob()`and `.sample()` (for example, a PyTorch
distribution) can be used.
- density_estimator: If it is a string, use a pre-configured network of the
- provided type (one of nsf, maf, mdn, made). Alternatively, a function
+ density_estimator: If it is a string (only "mdn_snpe_a" is valid), use a
+            pre-configured mixture density network. Alternatively, a function
that builds a custom neural network can be provided. The function will
be called with the first batch of simulations (theta, x), which can
thus be used for shape inference and potentially for z-scoring. It
needs to return a PyTorch `nn.Module` implementing the density
estimator. The density estimator needs to provide the methods
`.log_prob` and `.sample()`.
+ Note that until the last round only a single (multivariate) Gaussian
+ component is used for training (see Algorithm 1 in [1]). In the last
+                round, this component is replicated `num_components` times, its
+                parameters are perturbed with very small noise, and the last training
+                round is run with the expanded Gaussian mixture as the estimator of
+                the proposal posterior.
num_components:
- Number of components of the mixture of Gaussians. This number is set to 1 before
- running Algorithm 1, and then later set to the specified value before running
- Algorithm 2.
+ Number of components of the mixture of Gaussians. This number is set to
+ 1 before running Algorithm 1, and then later set to the specified value
+ before running Algorithm 2.
num_rounds: Total number of training rounds. For all but the last round, Algorithm 1
from [1] is executed. For last round, Algorithm 2 from [1] is executed once.
By default, `num_rounds` is set to 1, i.e. only Algorithm 2 is executed once
@@ -71,8 +75,14 @@ def __init__(
0.14.0 is more mature, we will remove this argument.
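+
+        Example (a minimal sketch of constructing the inference object; `prior` is
+        assumed to be set up as described above):
+
+            inference = SNPE_A(
+                prior, density_estimator="mdn_snpe_a", num_components=5, num_rounds=2
+            )
+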
"""
+ # Catch invalid inputs.
+ if not ((density_estimator == "mdn_snpe_a") or callable(density_estimator)):
+ raise TypeError(
+ "The `density_estimator` passed to SNPE_A needs to be a "
+ "callable or the string 'mdn_snpe_a'!"
+ )
+
self._num_rounds = num_rounds
- # TODO How to extract the number of components from density_estimator if that is a callable?
self._num_components = num_components
# WARNING: sneaky trick ahead. We proxy the parent's `train` here,
@@ -80,7 +90,14 @@ def __init__(
# continue. It's sneaky because we are using the object (self) as a namespace
# to pass arguments between functions, and that's implicit state management.
kwargs = utils.del_entries(
- locals(), entries=("self", "__class__", "unused_args", "num_rounds", "num_components")
+ locals(),
+ entries=(
+ "self",
+ "__class__",
+ "unused_args",
+ "num_rounds",
+ "num_components",
+ ),
)
super().__init__(**kwargs)
@@ -100,7 +117,9 @@ def train(
dataloader_kwargs: Optional[Dict] = None,
) -> DirectPosterior:
r"""
- Return density estimator that approximates the distribution $p(\theta|x)$.
+        Return density estimator that approximates the proposal posterior
+        $\tilde{p}(\theta|x)$.
+
Args:
training_batch_size: Training batch size.
learning_rate: Learning rate for Adam optimizer.
@@ -170,7 +189,9 @@ def train(
def build_posterior(
self,
- proposal: Union[MultivariateNormal, utils.BoxUniform, DirectPosterior],
+ proposal: Optional[
+ Union[MultivariateNormal, utils.BoxUniform, DirectPosterior]
+ ] = None,
density_estimator: Optional[TorchModule] = None,
rejection_sampling_parameters: Optional[Dict[str, Any]] = None,
sample_with_mcmc: bool = False,
@@ -180,7 +201,7 @@ def build_posterior(
r"""
Build posterior from the neural density estimator.
- For SNPE, the posterior distribution that is returned here implements the TODO
+ For SNPE, the posterior distribution that is returned here implements the
following functionality over the raw neural density estimator:
- correct the calculation of the log probability such that it compensates for
@@ -189,8 +210,14 @@ def build_posterior(
- alternatively, if leakage is very high (which can happen for multi-round
SNPE), sample from the posterior with MCMC.
+        The `DirectPosterior` class assumes that the density estimator approximates
+        the posterior. In SNPE-A, however, the density estimator approximates the
+        proposal posterior, so importance reweighting is needed during evaluation.
+
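+        A minimal usage sketch, assuming an already-trained `SNPE_A` instance named
+        `inference` and relying only on the documented defaults:
+
+            # Uses the previous round's posterior (or the prior) as the proposal.
+            posterior = inference.build_posterior()
+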
Args:
- proposal: The distribution that the parameters $\theta$ were sampled from.
+            proposal: The proposal prior obtained from maximum-likelihood training.
+                If `None`, the posterior of the previous round is used; in the first
+                round, the prior is used as the proposal.
density_estimator: The density estimator that the posterior is based on.
If `None`, use the latest neural density estimator that was trained.
rejection_sampling_parameters: Dictionary overriding the default parameters
@@ -229,20 +256,30 @@ def build_posterior(
# Set proposal of the density estimator.
         # This also invokes the z-scoring correction if necessary.
- if isinstance(proposal, (MultivariateNormal, utils.BoxUniform)):
- density_estimator.set_proposal(proposal)
+ if proposal is None:
+ if self._model_bank:
+ proposal = self._model_bank[-1].net
+ else:
+ proposal = self._prior
+ elif isinstance(proposal, (MultivariateNormal, utils.BoxUniform)):
+ pass
elif isinstance(proposal, DirectPosterior):
- # Extract the MoGFlow_SNPE_A from the DirectPosterior.
- density_estimator.set_proposal(proposal.net)
+ # Extract the MDNWrapper_SNPE_A from the DirectPosterior.
+ proposal = proposal.net
else:
raise TypeError(
"So far, only MultivariateNormal, BoxUniform, and DirectPosterior are"
"supported for the `proposal` arg in SNPE_A.build_posterior()."
)
+ # Create the MDNWrapper_SNPE_A
+ wrapped_density_estimator = MDNWrapper_SNPE_A(
+ flow=density_estimator, proposal=proposal
+ )
+
self._posterior = DirectPosterior(
method_family="snpe",
- neural_net=density_estimator,
+ neural_net=wrapped_density_estimator,
prior=self._prior,
x_shape=self._x_shape,
rejection_sampling_parameters=rejection_sampling_parameters,
@@ -266,13 +303,8 @@ def _log_prob_proposal_posterior(
"""
Return the log-probability of the proposal posterior.
- .. note::
- For SNPE-A this is the same as `self._neural_net.log_prob(theta, x)` in
- `_loss()` to be found in `snpe_base.py`.
-
- If the proposal is a MoG, the density estimator is a MoG, and the prior is
- either Gaussian or uniform, we use non-atomic loss. Else, use atomic loss (which
- suffers from leakage).
+        For SNPE-A, this is the same as `self._neural_net.log_prob(theta, x)` in
+        `_loss()`, defined in `snpe_base.py`.
Args:
theta: Batch of parameters θ.
@@ -293,7 +325,8 @@ def _expand_mog(self, eps: float = 1e-5):
symmetry such that the gradients in the subsequent training are not
all identical.
- :param eps: Standard deviation for the random perturbation.
+ Args:
+ eps: Standard deviation for the random perturbation.
"""
assert isinstance(self._neural_net._distribution, MultivariateGaussianMDN)
diff --git a/sbi/inference/snpe/snpe_c.py b/sbi/inference/snpe/snpe_c.py
index b298a5155..775e0a06b 100644
--- a/sbi/inference/snpe/snpe_c.py
+++ b/sbi/inference/snpe/snpe_c.py
@@ -2,7 +2,6 @@
# under the Affero General Public License v3, see .
-from math import pi
from typing import Any, Callable, Dict, Optional, Union
import torch
@@ -252,7 +251,7 @@ def _set_maybe_z_scored_prior(self) -> None:
if isinstance(self._prior, MultivariateNormal):
self._maybe_z_scored_prior = MultivariateNormal(
- almost_zero_mean, torch.diag(almost_one_std),
+ almost_zero_mean, torch.diag(almost_one_std)
)
else:
range_ = torch.sqrt(almost_one_std * 3.0)
@@ -417,7 +416,9 @@ def _log_prob_proposal_posterior_mog(
)
# Compute the log_prob of theta under the product.
- log_prob_proposal_posterior = utils.mog_log_prob(theta, logits_pp, m_pp, prec_pp)
+ log_prob_proposal_posterior = utils.mog_log_prob(
+ theta, logits_pp, m_pp, prec_pp
+ )
utils.assert_all_finite(log_prob_proposal_posterior, "proposal posterior eval")
return log_prob_proposal_posterior
@@ -471,7 +472,7 @@ def _automatic_posterior_transformation(
)
means_pp = self._means_proposal_posterior(
- covariances_pp, means_p, precisions_p, means_d, precisions_d,
+ covariances_pp, means_p, precisions_p, means_d, precisions_d
)
logits_pp = self._logits_proposal_posterior(
@@ -489,7 +490,7 @@ def _automatic_posterior_transformation(
return logits_pp, means_pp, precisions_pp, covariances_pp
def _precisions_proposal_posterior(
- self, precisions_p: Tensor, precisions_d: Tensor,
+ self, precisions_p: Tensor, precisions_d: Tensor
):
"""
Return the precisions and covariances of the proposal posterior.
diff --git a/sbi/neural_nets/mdn.py b/sbi/neural_nets/mdn.py
index 900bc5796..2de2da657 100644
--- a/sbi/neural_nets/mdn.py
+++ b/sbi/neural_nets/mdn.py
@@ -6,7 +6,7 @@
from pyknos.nflows import flows, transforms
from torch import Tensor, nn
-from sbi.utils.sbiutils import standardizing_net, standardizing_transform
+import sbi.utils as utils
def build_mdn(
@@ -17,7 +17,7 @@ def build_mdn(
hidden_features: int = 50,
num_components: int = 10,
embedding_net: nn.Module = nn.Identity(),
- **kwargs
+ **kwargs,
) -> nn.Module:
"""Builds MDN p(x|y).
@@ -42,11 +42,11 @@ def build_mdn(
transform = transforms.IdentityTransform()
if z_score_x:
- transform_zx = standardizing_transform(batch_x)
+ transform_zx = utils.standardizing_transform(batch_x)
transform = transforms.CompositeTransform([transform_zx, transform])
if z_score_y:
- embedding_net = nn.Sequential(standardizing_net(batch_y), embedding_net)
+ embedding_net = nn.Sequential(utils.standardizing_net(batch_y), embedding_net)
distribution = MultivariateGaussianMDN(
features=x_numel,
diff --git a/sbi/neural_nets/mdn_snpe_a.py b/sbi/neural_nets/mdn_snpe_a.py
deleted file mode 100644
index 044aa0bf1..000000000
--- a/sbi/neural_nets/mdn_snpe_a.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
-# under the Affero General Public License v3, see .
-
-from pyknos.mdn.mdn import MultivariateGaussianMDN
-from pyknos.nflows import transforms
-from torch import Tensor, nn
-
-from sbi.neural_nets.mog_flow_snpe_a import MoGFlow_SNPE_A
-from sbi.utils.sbiutils import standardizing_net, standardizing_transform
-
-
-def build_mdn_snpe_a(
- batch_x: Tensor = None,
- batch_y: Tensor = None,
- z_score_x: bool = True,
- z_score_y: bool = True,
- hidden_features: int = 50,
- num_components: int = 10,
- embedding_net: nn.Module = nn.Identity(),
- **kwargs,
-) -> nn.Module:
- """Builds MDN p(x|y) with different sampling in training and evaluation mode.
-
- Args:
- batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring.
- batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring.
- z_score_x: Whether to z-score xs passing into the network.
- z_score_y: Whether to z-score ys passing into the network.
- hidden_features: Number of hidden features.
- num_components: Number of components.
- embedding_net: Optional embedding network for y.
- kwargs: Additional arguments that are passed by the build function but are not
- relevant for MDNs and are therefore ignored.
-
- Returns:
- Neural network.
- """
- x_numel = batch_x[0].numel()
- # Infer the output dimensionality of the embedding_net by making a forward pass.
- y_numel = embedding_net(batch_y[:1]).numel()
-
- transform = transforms.IdentityTransform()
-
- if z_score_x:
- transform_zx = standardizing_transform(batch_x)
- transform = transforms.CompositeTransform([transform_zx, transform])
-
- if z_score_y:
- embedding_net = nn.Sequential(standardizing_net(batch_y), embedding_net)
-
- distribution = MultivariateGaussianMDN(
- features=x_numel,
- context_features=y_numel,
- hidden_features=hidden_features,
- hidden_net=nn.Sequential(
- nn.Linear(y_numel, hidden_features),
- nn.ReLU(),
- nn.Dropout(p=0.0),
- nn.Linear(hidden_features, hidden_features),
- nn.ReLU(),
- nn.Linear(hidden_features, hidden_features),
- nn.ReLU(),
- ),
- num_components=num_components,
- custom_initialization=True,
- )
-
- neural_net = MoGFlow_SNPE_A(transform, distribution, embedding_net)
-
- return neural_net
diff --git a/sbi/neural_nets/mog_flow_snpe_a.py b/sbi/neural_nets/mdn_wrapper_snpe_a.py
similarity index 90%
rename from sbi/neural_nets/mog_flow_snpe_a.py
rename to sbi/neural_nets/mdn_wrapper_snpe_a.py
index 1445d57a7..0d9488de0 100644
--- a/sbi/neural_nets/mog_flow_snpe_a.py
+++ b/sbi/neural_nets/mdn_wrapper_snpe_a.py
@@ -5,6 +5,7 @@
from warnings import warn
import torch
+import torch.nn as nn
from pyknos.mdn.mdn import MultivariateGaussianMDN
from pyknos.nflows import flows
from pyknos.nflows.transforms import CompositeTransform
@@ -12,13 +13,16 @@
from torch.distributions import MultivariateNormal
import sbi.utils as utils
+import sbi.utils.sbiutils
from sbi.utils import torchutils
-class MoGFlow_SNPE_A(flows.Flow):
+class MDNWrapper_SNPE_A(nn.Module):
"""
- A wrapper for nflow's `Flow` class to enable a different log prob calculation
- sampling strategy for training and testing, tailored to SNPE-A [1]
+    A wrapper around a `flows.Flow` to enable a different log-prob calculation and
+    sampling strategy for training and testing, tailored to SNPE-A [1].
+ This class inherits from `nn.Module` since the constructor of `DirectPosterior`
+ expects the argument `neural_net` to be a `nn.Module`.
[1] _Fast epsilon-free Inference of Simulation Models with Bayesian Conditional
Density Estimation_, Papamakarios et al., NeurIPS 2016,
@@ -29,38 +33,40 @@ class MoGFlow_SNPE_A(flows.Flow):
def __init__(
self,
- transform,
- distribution,
- embedding_net=None,
+ flow: flows.Flow,
+ proposal: Union["utils.BoxUniform", MultivariateNormal, "MDNWrapper_SNPE_A"],
allow_precision_correction: bool = False,
):
"""Constructor.
Args:
- transform: A `Transform` object, it transforms data into noise.
- distribution: A `Distribution` object, the base distribution of the flow that
- generates the noise.
- embedding_net: A `nn.Module` which has trainable parameters to encode the
- context (condition). It is trained jointly with the flow.
+ flow: The trained normalizing flow, passed when building the posterior.
+ proposal: The proposal distribution.
allow_precision_correction:
Add a diagonal with the smallest eigenvalue in every entry in case
the precision matrix becomes ill-conditioned.
"""
- # Construct the flow.
- super().__init__(transform, distribution, embedding_net)
+ # Call nn.Module's constructor.
+ super().__init__()
- self._proposal = None
+ self._neural_net = flow
self._allow_precision_correction = allow_precision_correction
+ # Set the proposal.
+ self._proposal = proposal
+ # Take care of z-scoring, pre-compute and store prior terms.
+ self._set_state_for_mog_proposal()
+
@property
def proposal(
self,
- ) -> Union["utils.BoxUniform", MultivariateNormal, "MoGFlow_SNPE_A"]:
+ ) -> Union["utils.BoxUniform", MultivariateNormal, "MDNWrapper_SNPE_A"]:
"""Get the proposal of the previous round."""
return self._proposal
def set_proposal(
- self, proposal: Union["utils.BoxUniform", MultivariateNormal, "MoGFlow_SNPE_A"]
+ self,
+ proposal: Union["utils.BoxUniform", MultivariateNormal, "MDNWrapper_SNPE_A"],
):
"""Set the proposal of the previous round."""
self._proposal = proposal
@@ -70,7 +76,7 @@ def set_proposal(
def _get_first_prior_from_proposal(
self,
- ) -> Union["utils.BoxUniform", MultivariateNormal, "MoGFlow_SNPE_A"]:
+ ) -> Union["utils.BoxUniform", MultivariateNormal, "MDNWrapper_SNPE_A"]:
"""Iterate a possible chain of proposals."""
curr_prior = self._proposal
@@ -87,13 +93,15 @@ def log_prob(self, inputs, context=None):
if self._proposal is None:
             # Use Flow.log_prob() if there has been no previous proposal memorized
# in this instance. This is the case if we are in the training
- # loop, i.e. this MoGFlow_SNPE_A instance is not an attribute of a
+ # loop, i.e. this MDNWrapper_SNPE_A instance is not an attribute of a
# DirectPosterior instance.
- return super().log_prob(inputs, context) # q_phi from eq (3) in [1]
+ return self._neural_net.log_prob(
+ inputs, context
+ ) # q_phi from eq (3) in [1]
elif isinstance(self._proposal, (utils.BoxUniform, MultivariateNormal)):
# No importance re-weighting is needed if the proposal prior is the prior
- return super().log_prob(inputs, context)
+ return self._neural_net.log_prob(inputs, context)
else:
# When we want to compute the approx. posterior, a proposal prior \tilde{p}
@@ -107,7 +115,7 @@ def log_prob(self, inputs, context=None):
theta = self._maybe_z_score_theta(inputs)
# Compute the log_prob of theta under the product.
- log_prob_proposal_posterior = utils.mog_log_prob(
+ log_prob_proposal_posterior = sbi.utils.sbiutils.mog_log_prob(
theta,
logits_pp,
m_pp,
@@ -122,14 +130,14 @@ def sample(self, num_samples, context=None, batch_size=None) -> Tensor:
if self._proposal is None:
# Use Flow.sample() if there has been no previous proposal memorized
# in this instance. This is the case if we are in the training
- # loop, i.e. this MoGFlow_SNPE_A instance is not an attribute of a
+ # loop, i.e. this MDNWrapper_SNPE_A instance is not an attribute of a
# DirectPosterior instance.
- return super().sample(num_samples, context, batch_size)
+ return self._neural_net.sample(num_samples, context, batch_size)
else:
# No importance re-weighting is needed if the proposal prior is the prior
if isinstance(self._proposal, (utils.BoxUniform, MultivariateNormal)):
- return super().sample(num_samples, context, batch_size)
+ return self._neural_net.sample(num_samples, context, batch_size)
# When we want to sample from the approx. posterior, a proposal prior \tilde{p}
# has already been observed. To analytically calculate the log-prob of the
@@ -174,7 +182,7 @@ def _sample_approx_posterior_mog(
num_samples, logits_pp, m_pp, prec_factors_pp
)
- embedded_context = self._embedding_net(x)
+ embedded_context = self._neural_net._embedding_net(x)
if embedded_context is not None:
# Merge the context dimension with sample dimension in order to
# apply the transform.
@@ -183,7 +191,7 @@ def _sample_approx_posterior_mog(
embedded_context, num_reps=num_samples
)
- theta, _ = self._transform.inverse(theta, context=embedded_context)
+ theta, _ = self._neural_net._transform.inverse(theta, context=embedded_context)
if embedded_context is not None:
# Split the context dimension from sample dimension.
@@ -204,8 +212,8 @@ def _get_mixture_components(self, x: Tensor):
"""
# Evaluate the density estimator.
- encoded_x = self._embedding_net(x)
- dist = self._distribution # defined to avoid black formatting.
+ encoded_x = self._neural_net._embedding_net(x)
+ dist = self._neural_net._distribution # defined to avoid black formatting.
logits_d, m_d, prec_d, _, _ = dist.get_mixture_components(encoded_x)
norm_logits_d = logits_d - torch.logsumexp(logits_d, dim=-1, keepdim=True)
@@ -284,7 +292,7 @@ def _proposal_posterior_transformation(
precisions_d,
)
- logits_post = MoGFlow_SNPE_A._logits_posterior(
+ logits_post = MDNWrapper_SNPE_A._logits_posterior(
means_post,
precisions_post,
covariances_post,
@@ -300,7 +308,7 @@ def _proposal_posterior_transformation(
def _set_state_for_mog_proposal(self) -> None:
"""
- Set state variables of the MoGFlow_SNPE_A instance evevy time `set_proposal()`
+        Set state variables of the MDNWrapper_SNPE_A instance every time `set_proposal()`
         is called, i.e. every time a posterior is built using `SNPE_A.build_posterior()`.
This function is almost identical to `SNPE_C._set_state_for_mog_proposal()`.
@@ -314,7 +322,7 @@ def _set_state_for_mog_proposal(self) -> None:
training step if the prior is Gaussian.
"""
- self.z_score_theta = isinstance(self._transform, CompositeTransform)
+ self.z_score_theta = isinstance(self._neural_net._transform, CompositeTransform)
self._set_maybe_z_scored_prior()
@@ -348,8 +356,8 @@ def _set_maybe_z_scored_prior(self) -> None:
prior = self._get_first_prior_from_proposal()
if self.z_score_theta:
- scale = self._transform._transforms[0]._scale
- shift = self._transform._transforms[0]._shift
+ scale = self._neural_net._transform._transforms[0]._scale
+ shift = self._neural_net._transform._transforms[0]._shift
# Following the definition of the linear transform in
# `standardizing_transform` in `sbiutils.py`:
@@ -383,7 +391,7 @@ def _maybe_z_score_theta(self, theta: Tensor) -> Tensor:
"""Return potentially standardized theta if z-scoring was requested."""
if self.z_score_theta:
- theta, _ = self._transform(theta)
+ theta, _ = self._neural_net._transform(theta)
return theta
@@ -452,7 +460,7 @@ def _precisions_posterior(self, precisions_pprior: Tensor, precisions_d: Tensor)
"The precision matrix of a posterior is not positive definite! "
"This is a known issue for SNPE-A. Either try a different parameter "
"setting or pass `allow_precision_correction=True` when constructing "
- "the `MoGFlow_SNPE_A` density estimator."
+ "the `MDNWrapper_SNPE_A` density estimator."
)
covariances_p = torch.inverse(precisions_p)
diff --git a/sbi/utils/__init__.py b/sbi/utils/__init__.py
index 799b22448..dc6e9f77c 100644
--- a/sbi/utils/__init__.py
+++ b/sbi/utils/__init__.py
@@ -2,7 +2,6 @@
from sbi.utils.conditional_density import (
conditional_corrcoeff,
eval_conditional_density,
- mog_log_prob,
)
from sbi.utils.get_nn_models import classifier_nn, likelihood_nn, posterior_nn
from sbi.utils.io import get_data_root, get_log_root, get_project_root
@@ -19,6 +18,7 @@
handle_invalid_x,
logit,
mask_sims_from_prior,
+ mog_log_prob,
sample_posterior_within_prior,
standardizing_net,
standardizing_transform,
diff --git a/sbi/utils/conditional_density.py b/sbi/utils/conditional_density.py
index 3675ad018..07b194ef8 100644
--- a/sbi/utils/conditional_density.py
+++ b/sbi/utils/conditional_density.py
@@ -1,14 +1,12 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see .
-from math import pi
from typing import Any, Callable, List, Optional, Tuple, Union
from warnings import warn
import torch
from torch import Tensor
-import sbi.utils as utils
from sbi.utils.torchutils import ensure_theta_batched
@@ -332,45 +330,3 @@ def _normalize_probs(probs: Tensor, limits: Tensor) -> Tensor:
"""
limits_diff = torch.prod(limits[:, 1] - limits[:, 0])
return probs * probs.numel() / limits_diff / torch.sum(probs)
-
-
-def mog_log_prob(
- theta: Tensor,
- logits_pp: Tensor,
- means_pp: Tensor,
- precisions_pp: Tensor,
-) -> Tensor:
- r"""
- Returns the log-probability of parameter sets $\theta$ under a mixture of Gaussians.
-
- Note that the mixture can have different logits, means, covariances for any theta in
- the batch. This is because these values were computed from a batch of $x$ (and the
- $x$ in the batch are not the same).
-
- This code is similar to the code of mdn.py in pyknos, but it does not use
- log(det(Cov)) = -2*sum(log(diag(L))), L being Cholesky of Precision. Instead, it
- just computes log(det(Cov)). Also, it uses the above-defined helper
- `_batched_vmv()`.
-
- Args:
- theta: Parameters at which to evaluate the mixture.
- logits_pp: (Unnormalized) mixture components.
- means_pp: Means of all mixture components. Shape
- (batch_dim, num_components, theta_dim).
- precisions_pp: Precisions of all mixtures. Shape
- (batch_dim, num_components, theta_dim, theta_dim).
-
- Returns: The log-probability.
- """
-
- _, _, output_dim = means_pp.size()
- theta = theta.view(-1, 1, output_dim)
-
- # Split up evaluation into parts.
- weights = logits_pp - torch.logsumexp(logits_pp, dim=-1, keepdim=True)
- constant = -(output_dim / 2.0) * torch.log(torch.tensor([2 * pi]))
- log_det = 0.5 * torch.log(torch.det(precisions_pp))
- theta_minus_mean = theta.expand_as(means_pp) - means_pp
- exponent = -0.5 * utils.batched_mixture_vmv(precisions_pp, theta_minus_mean)
-
- return torch.logsumexp(weights + constant + log_det + exponent, dim=-1)
diff --git a/sbi/utils/get_nn_models.py b/sbi/utils/get_nn_models.py
index a3bf6744a..70e26af52 100644
--- a/sbi/utils/get_nn_models.py
+++ b/sbi/utils/get_nn_models.py
@@ -13,7 +13,6 @@
)
from sbi.neural_nets.flow import build_made, build_maf, build_nsf
from sbi.neural_nets.mdn import build_mdn
-from sbi.neural_nets.mdn_snpe_a import build_mdn_snpe_a
def classifier_nn(
@@ -222,7 +221,7 @@ def build_fn(batch_theta, batch_x, num_components):
# override this kwarg with functools.partial. This is necessary
# in order to make sure that the MDN in SNPE-A only has one
# component when running the Algorithm 1 part.
- return build_mdn_snpe_a(
+ return build_mdn(
batch_x=batch_theta,
batch_y=batch_x,
num_components=num_components,
diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py
index 528344b9b..5e22e61e2 100644
--- a/sbi/utils/sbiutils.py
+++ b/sbi/utils/sbiutils.py
@@ -3,6 +3,7 @@
import logging
import warnings
+from math import pi
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import torch
@@ -15,6 +16,7 @@
from torch.distributions.distribution import Distribution
from tqdm.auto import tqdm
+from sbi import utils as utils
from sbi.utils.torchutils import atleast_2d
@@ -547,3 +549,45 @@ def log_prob(self, value: Tensor) -> Tensor:
"""
value = atleast_2d(value)
return zeros(value.shape[0])
+
+
+def mog_log_prob(
+ theta: Tensor,
+ logits_pp: Tensor,
+ means_pp: Tensor,
+ precisions_pp: Tensor,
+) -> Tensor:
+ r"""
+ Returns the log-probability of parameter sets $\theta$ under a mixture of Gaussians.
+
+ Note that the mixture can have different logits, means, covariances for any theta in
+ the batch. This is because these values were computed from a batch of $x$ (and the
+ $x$ in the batch are not the same).
+
+ This code is similar to the code of mdn.py in pyknos, but it does not use
+ log(det(Cov)) = -2*sum(log(diag(L))), L being Cholesky of Precision. Instead, it
+    just computes log(det(Cov)). Also, it uses the helper
+    `batched_mixture_vmv()`.
+
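+    Concretely, with normalized weights $w_k$, means $m_k$, and precisions $P_k$
+    (mirroring the intermediate terms computed below), the returned value is
+
+        $\log \sum_k \exp\big( \log w_k - (D/2) \log(2\pi) + 0.5 \log\det(P_k)
+        - 0.5 (\theta - m_k)^\top P_k (\theta - m_k) \big)$,
+
+    where $D$ is the dimensionality of $\theta$.
+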
+ Args:
+ theta: Parameters at which to evaluate the mixture.
+ logits_pp: (Unnormalized) mixture components.
+ means_pp: Means of all mixture components. Shape
+ (batch_dim, num_components, theta_dim).
+ precisions_pp: Precisions of all mixtures. Shape
+ (batch_dim, num_components, theta_dim, theta_dim).
+
+ Returns: The log-probability.
+ """
+
+ _, _, output_dim = means_pp.size()
+ theta = theta.view(-1, 1, output_dim)
+
+ # Split up evaluation into parts.
+ weights = logits_pp - torch.logsumexp(logits_pp, dim=-1, keepdim=True)
+ constant = -(output_dim / 2.0) * torch.log(torch.tensor([2 * pi]))
+ log_det = 0.5 * torch.log(torch.det(precisions_pp))
+ theta_minus_mean = theta.expand_as(means_pp) - means_pp
+ exponent = -0.5 * utils.batched_mixture_vmv(precisions_pp, theta_minus_mean)
+
+ return torch.logsumexp(weights + constant + log_det + exponent, dim=-1)
diff --git a/tests/linearGaussian_snpe_test.py b/tests/linearGaussian_snpe_test.py
index a2d067d5b..e087afce8 100644
--- a/tests/linearGaussian_snpe_test.py
+++ b/tests/linearGaussian_snpe_test.py
@@ -288,7 +288,7 @@ def simulator(theta):
),
)
def test_api_snpe_c_posterior_correction(
- snpe_method, sample_with_mcmc, mcmc_method, prior_str, set_seed
+ snpe_method: type, sample_with_mcmc, mcmc_method, prior_str, set_seed
):
"""Test that leakage correction applied to sampling works, with both MCMC and
rejection.
@@ -342,8 +342,10 @@ def simulator(theta):
@pytest.mark.slow
-@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C])
-def test_sample_conditional(snpe_method, set_seed):
+@pytest.mark.parametrize(
+ "snpe_method, num_rounds", [(SNPE_A, 2), (SNPE_A, 5), (SNPE_C, None)]
+)
+def test_sample_conditional(snpe_method: type, num_rounds: int, set_seed):
"""
Test whether sampling from the conditional gives the same results as evaluating.
@@ -373,14 +375,17 @@ def simulator(theta):
if snpe_method == SNPE_A:
net = utils.posterior_nn("mdn_snpe_a", num_components=5, hidden_features=20)
+ extra_kwargs = dict(num_rounds=num_rounds)
else:
net = utils.posterior_nn("maf", hidden_features=20)
+ extra_kwargs = dict()
simulator, prior = prepare_for_sbi(simulator, prior)
inference = snpe_method(
prior,
density_estimator=net,
show_progress_bars=False,
+ **extra_kwargs,
)
# We need a pretty big dataset to properly model the bimodality.
@@ -444,8 +449,10 @@ def simulator(theta):
assert max_err < 0.0025
-@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C])
-def test_example_posterior(snpe_method):
+@pytest.mark.parametrize(
+ "snpe_method, num_rounds", [(SNPE_A, 2), (SNPE_A, 10), (SNPE_C, None)]
+)
+def test_example_posterior(snpe_method: type, num_rounds: int):
"""Return an inferred `NeuralPosterior` for interactive examination."""
num_dim = 2
x_o = zeros(1, num_dim)
@@ -461,11 +468,13 @@ def test_example_posterior(snpe_method):
def simulator(theta):
return linear_gaussian(theta, likelihood_shift, likelihood_cov)
+ if snpe_method == SNPE_A:
+ extra_kwargs = dict(num_rounds=num_rounds)
+ else:
+ extra_kwargs = dict()
+
simulator, prior = prepare_for_sbi(simulator, prior)
- inference = snpe_method(
- prior,
- show_progress_bars=False,
- )
+ inference = snpe_method(prior, show_progress_bars=False, **extra_kwargs)
theta, x = simulate_for_sbi(
simulator, prior, 1000, simulation_batch_size=10, num_workers=6
)