Commit
SNPE-A review changes
famura authored and michaeldeistler committed May 4, 2021
1 parent 93152e1 commit 96d3a3b
Showing 10 changed files with 177 additions and 197 deletions.
87 changes: 60 additions & 27 deletions sbi/inference/snpe/snpe_a.py
@@ -14,6 +14,7 @@
import sbi.utils as utils
from sbi.inference.posteriors.direct_posterior import DirectPosterior
from sbi.inference.snpe.snpe_base import PosteriorEstimator
from sbi.neural_nets.mdn_wrapper_snpe_a import MDNWrapper_SNPE_A
from sbi.types import TensorboardSummaryWriter, TorchModule


@@ -32,9 +33,6 @@ def __init__(
):
r"""SNPE-A [1].
https://github.com/mackelab/sbi/blob/main/sbi/inference/snpe/snpe_c.py
https://github.com/mackelab/sbi/blob/main/sbi/neural_nets/mdn.py
[1] _Fast epsilon-free Inference of Simulation Models with Bayesian Conditional
Density Estimation_, Papamakarios et al., NeurIPS 2016,
https://arxiv.org/abs/1605.06376.
@@ -44,18 +42,24 @@ def __init__(
parameters, e.g. which ranges are meaningful for them. Any
object with `.log_prob()` and `.sample()` (for example, a PyTorch
distribution) can be used.
density_estimator: If it is a string, use a pre-configured network of the
provided type (one of nsf, maf, mdn, made). Alternatively, a function
density_estimator: If it is a string (only "mdn_snpe_a" is valid), use a
pre-configured mixture-of-Gaussians density network. Alternatively, a function
that builds a custom neural network can be provided. The function will
be called with the first batch of simulations (theta, x), which can
thus be used for shape inference and potentially for z-scoring. It
needs to return a PyTorch `nn.Module` implementing the density
estimator. The density estimator needs to provide the methods
`.log_prob` and `.sample()`.
Note that until the last round only a single (multivariate) Gaussian
component is used for training (see Algorithm 1 in [1]). In the last
round, this component is replicated `num_components` times, its parameters
are perturbed with very small noise, and the final training round is then
run with the expanded Gaussian mixture as the estimator of the
proposal posterior.
num_components:
Number of components of the mixture of Gaussians. This number is set to 1 before
running Algorithm 1, and then later set to the specified value before running
Algorithm 2.
Number of components of the mixture of Gaussians. This number is set to
1 before running Algorithm 1, and then later set to the specified value
before running Algorithm 2.
num_rounds: Total number of training rounds. For all but the last round, Algorithm 1
from [1] is executed. For the last round, Algorithm 2 from [1] is executed once.
By default, `num_rounds` is set to 1, i.e. only Algorithm 2 is executed once
@@ -71,16 +75,29 @@ def __init__(
0.14.0 is more mature, we will remove this argument.
"""

# Catch invalid inputs.
if not ((density_estimator == "mdn_snpe_a") or callable(density_estimator)):
raise TypeError(
"The `density_estimator` passed to SNPE_A needs to be a "
"callable or the string 'mdn_snpe_a'!"
)

self._num_rounds = num_rounds
# TODO How to extract the number of components from density_estimator if that is a callable?
self._num_components = num_components

# WARNING: sneaky trick ahead. We proxy the parent's `train` here,
# requiring the signature to have `num_atoms`, save it for use below, and
# continue. It's sneaky because we are using the object (self) as a namespace
# to pass arguments between functions, and that's implicit state management.
kwargs = utils.del_entries(
locals(), entries=("self", "__class__", "unused_args", "num_rounds", "num_components")
locals(),
entries=(
"self",
"__class__",
"unused_args",
"num_rounds",
"num_components",
),
)
super().__init__(**kwargs)

@@ -100,7 +117,9 @@ def train(
dataloader_kwargs: Optional[Dict] = None,
) -> DirectPosterior:
r"""
Return density estimator that approximates the distribution $p(\theta|x)$.
Return density estimator that approximates the proposal posterior's
distribution $\tilde{p}(\theta|x)$.
Args:
training_batch_size: Training batch size.
learning_rate: Learning rate for Adam optimizer.
@@ -170,7 +189,9 @@ def train(

def build_posterior(
self,
proposal: Union[MultivariateNormal, utils.BoxUniform, DirectPosterior],
proposal: Optional[
Union[MultivariateNormal, utils.BoxUniform, DirectPosterior]
] = None,
density_estimator: Optional[TorchModule] = None,
rejection_sampling_parameters: Optional[Dict[str, Any]] = None,
sample_with_mcmc: bool = False,
@@ -180,7 +201,7 @@ def build_posterior(
r"""
Build posterior from the neural density estimator.
For SNPE, the posterior distribution that is returned here implements the TODO
For SNPE, the posterior distribution that is returned here implements the
following functionality over the raw neural density estimator:
- correct the calculation of the log probability such that it compensates for
@@ -189,8 +210,14 @@ def build_posterior(
- alternatively, if leakage is very high (which can happen for multi-round
SNPE), sample from the posterior with MCMC.
The DirectPosterior class assumes that the density estimator approximates the
posterior. In SNPE-A, however, the density estimator approximates the proposal
posterior. Hence, importance reweighting is needed during evaluation.
Args:
proposal: The distribution that the parameters $\theta$ were sampled from.
proposal: The proposal prior obtained from the maximum-likelihood training.
If None, the posterior of the previous round is used. In the first round,
the prior is used as the proposal.
density_estimator: The density estimator that the posterior is based on.
If `None`, use the latest neural density estimator that was trained.
rejection_sampling_parameters: Dictionary overriding the default parameters
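
The importance-reweighting correction mentioned above can be summarized in a short sketch. This is not the actual `MDNWrapper_SNPE_A` code (which is not shown in this diff); the helper below is hypothetical and only illustrates the idea, assuming the estimator follows the nflows `.log_prob(inputs, context)` convention.

def _illustrative_corrected_log_prob(theta, x, density_estimator, prior, proposal):
    # The density estimator was trained on parameters drawn from the proposal,
    # so it approximates the proposal posterior p~(theta|x). Reweighting by
    # prior / proposal recovers the true posterior up to normalization:
    #   log p(theta|x) = log p~(theta|x) + log p(theta) - log p~(theta) + const.
    log_q = density_estimator.log_prob(theta, context=x)
    return log_q + prior.log_prob(theta) - proposal.log_prob(theta)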
@@ -229,20 +256,30 @@ def build_posterior(

# Set proposal of the density estimator.
# This also invokes the z-scoring correction if necessary.
if isinstance(proposal, (MultivariateNormal, utils.BoxUniform)):
density_estimator.set_proposal(proposal)
if proposal is None:
if self._model_bank:
proposal = self._model_bank[-1].net
else:
proposal = self._prior
elif isinstance(proposal, (MultivariateNormal, utils.BoxUniform)):
pass
elif isinstance(proposal, DirectPosterior):
# Extract the MoGFlow_SNPE_A from the DirectPosterior.
density_estimator.set_proposal(proposal.net)
# Extract the MDNWrapper_SNPE_A from the DirectPosterior.
proposal = proposal.net
else:
raise TypeError(
"So far, only MultivariateNormal, BoxUniform, and DirectPosterior are"
"supported for the `proposal` arg in SNPE_A.build_posterior()."
)

# Create the MDNWrapper_SNPE_A
wrapped_density_estimator = MDNWrapper_SNPE_A(
flow=density_estimator, proposal=proposal
)

self._posterior = DirectPosterior(
method_family="snpe",
neural_net=density_estimator,
neural_net=wrapped_density_estimator,
prior=self._prior,
x_shape=self._x_shape,
rejection_sampling_parameters=rejection_sampling_parameters,
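
Continuing the instantiation sketch above, a minimal end-to-end usage could look as follows. How simulations are passed to the inference object is not part of this diff, so `append_simulations` is assumed, and `x_o` stands for an observation.

# theta, x: simulated parameters and corresponding data (assumed to exist)
_ = inference.append_simulations(theta, x).train()  # training interface assumed
posterior = inference.build_posterior()  # proposal=None: prior in round 1, else the previous round's posterior
samples = posterior.sample((1000,), x=x_o)
log_probs = posterior.log_prob(samples, x=x_o)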
@@ -266,13 +303,8 @@ def _log_prob_proposal_posterior(
"""
Return the log-probability of the proposal posterior.
.. note::
For SNPE-A this is the same as `self._neural_net.log_prob(theta, x)` in
`_loss()` to be found in `snpe_base.py`.
If the proposal is a MoG, the density estimator is a MoG, and the prior is
either Gaussian or uniform, we use non-atomic loss. Else, use atomic loss (which
suffers from leakage).
For SNPE-A this is the same as `self._neural_net.log_prob(theta, x)` in
`_loss()` to be found in `snpe_base.py`.
Args:
theta: Batch of parameters θ.
@@ -293,7 +325,8 @@ def _expand_mog(self, eps: float = 1e-5):
symmetry such that the gradients in the subsequent training are not
all identical.
:param eps: Standard deviation for the random perturbation.
Args:
eps: Standard deviation for the random perturbation.
"""
assert isinstance(self._neural_net._distribution, MultivariateGaussianMDN)
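
A sketch of the expansion this method performs. The helper and tensor shapes are hypothetical (the actual code operates on the parameters of the `MultivariateGaussianMDN` asserted above); the point is that the single component is replicated and slightly perturbed to break symmetry.

import torch

def _illustrative_expand_mog(logits, means, precisions, num_components, eps=1e-5):
    # logits: (batch, 1), means: (batch, 1, dim), precisions: (batch, 1, dim, dim)
    logits = logits.repeat(1, num_components)
    means = means.repeat(1, num_components, 1)
    means = means + eps * torch.randn_like(means)  # tiny noise so gradients differ per component
    precisions = precisions.repeat(1, num_components, 1, 1)
    return logits, means, precisions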

11 changes: 6 additions & 5 deletions sbi/inference/snpe/snpe_c.py
@@ -2,7 +2,6 @@
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.


from math import pi
from typing import Any, Callable, Dict, Optional, Union

import torch
@@ -252,7 +251,7 @@ def _set_maybe_z_scored_prior(self) -> None:

if isinstance(self._prior, MultivariateNormal):
self._maybe_z_scored_prior = MultivariateNormal(
almost_zero_mean, torch.diag(almost_one_std),
almost_zero_mean, torch.diag(almost_one_std)
)
else:
range_ = torch.sqrt(almost_one_std * 3.0)
@@ -417,7 +416,9 @@ def _log_prob_proposal_posterior_mog(
)

# Compute the log_prob of theta under the product.
log_prob_proposal_posterior = utils.mog_log_prob(theta, logits_pp, m_pp, prec_pp)
log_prob_proposal_posterior = utils.mog_log_prob(
theta, logits_pp, m_pp, prec_pp
)
utils.assert_all_finite(log_prob_proposal_posterior, "proposal posterior eval")

return log_prob_proposal_posterior
@@ -471,7 +472,7 @@ def _automatic_posterior_transformation(
)

means_pp = self._means_proposal_posterior(
covariances_pp, means_p, precisions_p, means_d, precisions_d,
covariances_pp, means_p, precisions_p, means_d, precisions_d
)

logits_pp = self._logits_proposal_posterior(
@@ -489,7 +490,7 @@
return logits_pp, means_pp, precisions_pp, covariances_pp

def _precisions_proposal_posterior(
self, precisions_p: Tensor, precisions_d: Tensor,
self, precisions_p: Tensor, precisions_d: Tensor
):
"""
Return the precisions and covariances of the proposal posterior.
8 changes: 4 additions & 4 deletions sbi/neural_nets/mdn.py
@@ -6,7 +6,7 @@
from pyknos.nflows import flows, transforms
from torch import Tensor, nn

from sbi.utils.sbiutils import standardizing_net, standardizing_transform
import sbi.utils as utils


def build_mdn(
@@ -17,7 +17,7 @@ def build_mdn(
hidden_features: int = 50,
num_components: int = 10,
embedding_net: nn.Module = nn.Identity(),
**kwargs
**kwargs,
) -> nn.Module:
"""Builds MDN p(x|y).
@@ -42,11 +42,11 @@ def build_mdn(
transform = transforms.IdentityTransform()

if z_score_x:
transform_zx = standardizing_transform(batch_x)
transform_zx = utils.standardizing_transform(batch_x)
transform = transforms.CompositeTransform([transform_zx, transform])

if z_score_y:
embedding_net = nn.Sequential(standardizing_net(batch_y), embedding_net)
embedding_net = nn.Sequential(utils.standardizing_net(batch_y), embedding_net)

distribution = MultivariateGaussianMDN(
features=x_numel,
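
A brief usage sketch for `build_mdn` as it appears above. When used for posterior estimation (as in SNPE), `batch_x` would hold parameters and `batch_y` simulator outputs, and the first batch fixes the shapes and, if requested, the z-scoring transforms; the shapes below are made up for illustration.

import torch
from sbi.neural_nets.mdn import build_mdn

theta = torch.randn(100, 3)  # parameters: the "x" of the MDN p(x|y)
x = torch.randn(100, 5)      # simulator outputs: the conditioning "y"
net = build_mdn(batch_x=theta, batch_y=x, z_score_x=True, z_score_y=True, num_components=10)
log_probs = net.log_prob(theta, context=x)  # nflows `(inputs, context)` convention assumed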
70 changes: 0 additions & 70 deletions sbi/neural_nets/mdn_snpe_a.py

This file was deleted.
