Contrastive Neural Ratio Estimation #787

Merged · 21 commits · Dec 19, 2022
2 changes: 2 additions & 0 deletions README.md
@@ -75,6 +75,8 @@ The following algorithms are currently available:

* [`SNRE_B`](https://www.mackelab.org/sbi/reference/#sbi.inference.snre.snre_b.SNRE_B) or `SRE` from Durkan C, Murray I, and Papamakarios G. [_On Contrastive Learning for Likelihood-free Inference_](https://arxiv.org/abs/2002.03712) (ICML 2020).

* [`SNRE_C`](https://www.mackelab.org/sbi/reference/#sbi.inference.snre.snre_c.SNRE_C) or `NRE-C` from Miller BK, Weniger C, Forré P. [_Contrastive Neural Ratio Estimation_](https://arxiv.org/abs/2210.06170) (NeurIPS 2022). Sequential support introduced in this repository; a usage sketch is shown below.

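A minimal usage sketch for `SNRE_C` (illustrative only, not part of this diff; it assumes the usual `sbi` workflow of `append_simulations`, `train`, and `build_posterior`, and the toy prior/simulator below are made up):

```python
import torch
from sbi.inference import SNRE_C
from sbi.utils import BoxUniform

# Toy prior and "simulator" for illustration.
prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
theta = prior.sample((1000,))
x = theta + 0.1 * torch.randn_like(theta)

inference = SNRE_C(prior=prior)
_ = inference.append_simulations(theta, x).train(K=9, gamma=1.0)
posterior = inference.build_posterior()
samples = posterior.sample((100,), x=x[0])
```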
#### Sequential Neural Variational Inference (SNVI)

* [`SNVI`](https://www.mackelab.org/sbi/reference/#sbi.inference.posteriors.vi_posterior) from Glöckler M, Deistler M, Macke J, [_Variational methods for simulation-based inference_](https://openreview.net/forum?id=kZ0UYdhqkNY) (ICLR 2022).
1 change: 1 addition & 0 deletions docs/docs/index.md
@@ -107,6 +107,7 @@ The following papers offer additional details on the inference methods included

- **On Contrastive Learning for Likelihood-free Inference**<br>Durkan, Murray & Papamakarios (ICML 2020) <br>[[PDF]](http://proceedings.mlr.press/v119/durkan20a/durkan20a.pdf)

- **Contrastive Neural Ratio Estimation**<br>Miller, Weniger & Forré (NeurIPS 2022) <br>[[PDF]](https://arxiv.org/pdf/2210.06170.pdf)

### Utilities

7 changes: 7 additions & 0 deletions docs/docs/reference.md
@@ -49,6 +49,13 @@
filters: [ "!^_", "^__", "!^__class__" ]
inherited_members: true

::: sbi.inference.snre.snre_c.SNRE_C
rendering:
show_root_heading: true
selection:
filters: [ "!^_", "^__", "!^__class__" ]
inherited_members: true

::: sbi.inference.abc.mcabc.MCABC
rendering:
show_root_heading: true
4 changes: 2 additions & 2 deletions sbi/inference/__init__.py
@@ -24,7 +24,7 @@
from sbi.inference.snpe.snpe_a import SNPE_A
from sbi.inference.snpe.snpe_b import SNPE_B
from sbi.inference.snpe.snpe_c import SNPE_C # noqa: F401
from sbi.inference.snre import SNRE, SNRE_A, SNRE_B # noqa: F401
from sbi.inference.snre import SNRE, SNRE_A, SNRE_B, SNRE_C # noqa: F401
from sbi.utils.user_input_checks import prepare_for_sbi

SNL = SNLE = SNLE_A
@@ -37,7 +37,7 @@

SRE = SNRE_B
AALR = SNRE_A
_snre_family = ["SNRE_A", "AALR", "SNRE_B", "SNRE", "SRE"]
_snre_family = ["SNRE_A", "AALR", "SNRE_B", "SNRE", "SRE", "SNRE_C"]

ABC = MCABC
SMC = SMCABC
2 changes: 1 addition & 1 deletion sbi/inference/potentials/ratio_based_potential.py
@@ -29,7 +29,7 @@ def ratio_estimator_based_potential(
ratio_estimator: The neural network modelling likelihood-to-evidence ratio.
prior: The prior distribution.
x_o: The observed data at which to evaluate the likelihood-to-evidence ratio.
enable_transform: Whether to transform parameters to unconstrained space.
enable_transform: When False, the returned `theta_transform` is the identity.

Returns:
The potential function and a transformation that maps
2 changes: 2 additions & 0 deletions sbi/inference/snre/__init__.py
@@ -1,6 +1,8 @@
from sbi.inference.snre.snre_a import SNRE_A
from sbi.inference.snre.snre_b import SNRE_B
from sbi.inference.snre.snre_c import SNRE_C

# Aliases
AALR = SNRE_A
SRE = SNRE = SNRE_B
CNRE = NREC = SNRE_C
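A quick import sketch (illustrative): after this change, `SNRE_C` is re-exported from `sbi.inference`, while the `CNRE`/`NREC` aliases live in `sbi.inference.snre`:

```python
from sbi.inference import SNRE_C            # re-exported at package level
from sbi.inference.snre import CNRE, NREC   # aliases defined in this file

assert CNRE is SNRE_C and NREC is SNRE_C
```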
2 changes: 1 addition & 1 deletion sbi/inference/snre/snre_b.py
@@ -97,7 +97,7 @@ def train(
return super().train(**kwargs)

def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor:
r"""Return cross-entropy loss for 1-out-of-`num_atoms` classification.
r"""Return cross-entropy (via softmax activation) loss for 1-out-of-`num_atoms` classification.

The classifier takes as input `num_atoms` $(\theta,x)$ pairs. Out of these
pairs, one pair was sampled from the joint $p(\theta,x)$ and all others from the
18 changes: 15 additions & 3 deletions sbi/inference/snre/snre_base.py
@@ -45,6 +45,9 @@ def __init__(
- SNRE_B / SRE can use more than two atoms, potentially boosting performance,
but allows for posterior evaluation **only up to a normalizing constant**,
even when training only one round.
- SNRE_C is a generalization of SNRE_A and SNRE_B which can use multiple classes
(atoms) but encourages an exact likelihood-to-evidence ratio (density evaluation)
by introducing an independently drawn class. Exactness holds only for the first round.

Args:
classifier: Classifier trained to approximate likelihood ratios. If it is
@@ -150,7 +153,8 @@ def train(
discard_prior_samples: bool = False,
retrain_from_scratch: bool = False,
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
dataloader_kwargs: Optional[dict] = None,
loss_kwargs: Dict[str, Any] = {},
) -> nn.Module:
r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.

@@ -169,6 +173,7 @@
estimator for the posterior from scratch each round.
dataloader_kwargs: Additional or updated kwargs to be passed to the training
and validation dataloaders (like, e.g., a collate_fn).
loss_kwargs: Additional or updated kwargs to be passed to the `self._loss` function.

Returns:
Classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
@@ -233,7 +238,9 @@ def train(
batch[1].to(self._device),
)

train_losses = self._loss(theta_batch, x_batch, num_atoms)
train_losses = self._loss(
theta_batch, x_batch, num_atoms, **loss_kwargs
)
train_loss = torch.mean(train_losses)
train_log_probs_sum -= train_losses.sum().item()

@@ -261,7 +268,9 @@
batch[0].to(self._device),
batch[1].to(self._device),
)
val_losses = self._loss(theta_batch, x_batch, num_atoms)
val_losses = self._loss(
theta_batch, x_batch, num_atoms, **loss_kwargs
)
val_log_prob_sum -= val_losses.sum().item()
# Take mean over all validation samples.
self._val_log_prob = val_log_prob_sum / (
@@ -326,6 +335,7 @@ def build_posterior(
mcmc_parameters: Dict[str, Any] = {},
vi_parameters: Dict[str, Any] = {},
rejection_sampling_parameters: Dict[str, Any] = {},
enable_transform: bool = True,
) -> Union[MCMCPosterior, RejectionPosterior, VIPosterior]:
r"""Build posterior from the neural density estimator.

@@ -354,6 +364,7 @@
vi_parameters: Additional kwargs passed to `VIPosterior`.
rejection_sampling_parameters: Additional kwargs passed to
`RejectionPosterior`.
enable_transform: Whether to compute a transform of the parameters to unbounded space. When False, `theta_transform` is the identity.

Returns:
Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods
@@ -382,6 +393,7 @@
ratio_estimator=ratio_estimator,
prior=prior,
x_o=None,
enable_transform=enable_transform,
)

if sample_with == "mcmc":
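The `enable_transform` flag added above is forwarded from `build_posterior` to `ratio_estimator_based_potential`. A hedged sketch of how it might be used (the `inference` object is assumed to exist, e.g. from the README snippet above, and `x_o` is an illustrative observation):

```python
# Keep MCMC sampling in the original (constrained) parameter space by making
# theta_transform the identity.
posterior = inference.build_posterior(sample_with="mcmc", enable_transform=False)
samples = posterior.sample((100,), x=x_o)
```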
181 changes: 181 additions & 0 deletions sbi/inference/snre/snre_c.py
@@ -0,0 +1,181 @@
from typing import Callable, Dict, Optional, Tuple, Union

import torch
from torch import Tensor, nn
from torch.distributions import Distribution

from sbi.inference.snre.snre_base import RatioEstimator
from sbi.types import TensorboardSummaryWriter
from sbi.utils import del_entries, repeat_rows


class SNRE_C(RatioEstimator):
def __init__(
self,
prior: Optional[Distribution] = None,
classifier: Union[str, Callable] = "resnet",
device: str = "cpu",
logging_level: Union[int, str] = "warning",
summary_writer: Optional[TensorboardSummaryWriter] = None,
show_progress_bars: bool = True,
):
r"""A sequential extension to NRE-C[1], a generalization of SNRE_A and SNRE_B.
We call the algorithm SNRE_C within sbi.

[1] _Contrastive Neural Ratio Estimation_, Benjamin Kurt Miller et al.,
NeurIPS 2022, https://arxiv.org/abs/2210.06170

Args:
prior: A probability distribution that expresses prior knowledge about the
parameters, e.g. which ranges are meaningful for them. If `None`, the
prior must be passed to `.build_posterior()`.
classifier: Classifier trained to approximate likelihood ratios. If it is
a string, use a pre-configured network of the provided type (one of
linear, mlp, resnet). Alternatively, a function that builds a custom
neural network can be provided. The function will be called with the
first batch of simulations (theta, x), which can thus be used for shape
inference and potentially for z-scoring. It needs to return a PyTorch
`nn.Module` implementing the classifier.
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
logging_level: Minimum severity of messages to log. One of the strings
INFO, WARNING, DEBUG, ERROR and CRITICAL.
summary_writer: A tensorboard `SummaryWriter` to control, among others, log
file location (default is `<current working directory>/logs`.)
show_progress_bars: Whether to show a progressbar during simulation and
sampling.
"""

kwargs = del_entries(locals(), entries=("self", "__class__"))
super().__init__(**kwargs)

def train(
self,
K: int = 9,
gamma: float = 1.0,
training_batch_size: int = 50,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
max_num_epochs: int = 2**31 - 1,
clip_max_norm: Optional[float] = 5.0,
resume_training: bool = False,
discard_prior_samples: bool = False,
retrain_from_scratch: bool = False,
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
) -> nn.Module:
r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.

Args:
K: Number of theta to classify against. Minimum 1. Similar to `num_atoms`
for SNRE_B except SNRE_C has an additional independently drawn sample.
gamma: Determines the relative weight of the sum of all $K$ dependently drawn
classes against the marginally drawn one. Specifically, $p(y=k) := p_K$,
$p(y=0) := p_0$, $p_0 = 1 - K p_K$, and finally $\gamma := K p_K / p_0$.
training_batch_size: Training batch size.
learning_rate: Learning rate for Adam optimizer.
validation_fraction: The fraction of data to use for validation.
stop_after_epochs: The number of epochs to wait for improvement on the
validation set before terminating training.
max_num_epochs: Maximum number of epochs to run. If reached, we stop
training even when the validation loss is still decreasing. Otherwise,
we train until validation loss increases (see also `stop_after_epochs`).
clip_max_norm: Value at which to clip the total gradient norm in order to
prevent exploding gradients. Use None for no clipping.
resume_training: Can be used in case training time is limited, e.g. on a
cluster. If `True`, the split between train and validation set, the
optimizer, the number of epochs, and the best validation log-prob will
be restored from the last time `.train()` was called.
discard_prior_samples: Whether to discard samples simulated in round 1, i.e.
from the prior. Training may be sped up by ignoring such less targeted
samples.
retrain_from_scratch: Whether to retrain the conditional density
estimator for the posterior from scratch each round.
show_train_summary: Whether to print the number of epochs and validation
loss and leakage after the training.
dataloader_kwargs: Additional or updated kwargs to be passed to the training
and validation dataloaders (like, e.g., a collate_fn)

Returns:
Classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
"""
kwargs = del_entries(locals(), entries=("self", "__class__"))
kwargs["num_atoms"] = kwargs.pop("K") + 1
kwargs["loss_kwargs"] = {"gamma": kwargs.pop("gamma")}
return super().train(**kwargs)

def _loss(
self, theta: Tensor, x: Tensor, num_atoms: int, gamma: float
) -> torch.Tensor:
r"""Return cross-entropy loss (via ''multi-class sigmoid'' activation) for 1-out-of-`K + 1` classification.

At optimum, this loss function returns the exact likelihood-to-evidence ratio in the first round.
Details of loss computation are described in Contrastive Neural Ratio Estimation[1]. The paper
does not discuss the sequential case.

[1] _Contrastive Neural Ratio Estimation_, Benjamin Kurt Miller et al.,
NeurIPS 2022, https://arxiv.org/abs/2210.06170
"""

# The algorithm is written with K, so we convert back to K format rather than reasoning in num_atoms.
K = num_atoms - 1
assert K >= 1

assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match."
batch_size = theta.shape[0]

# We append a contrastive theta to the marginal case because we will remove the jointly drawn
# sample in the logits_marginal[:, 0] position. That makes the remaining samples marginally drawn.
# We have a batch of `batch_size` datapoints.
logits_marginal = self._classifier_logits(theta, x, K + 1).reshape(
batch_size, K + 1
)
logits_joint = self._classifier_logits(theta, x, K).reshape(batch_size, K)

dtype = logits_marginal.dtype
device = logits_marginal.device

# Index 0 is the theta-x-pair sampled from the joint p(theta,x) and hence
# we remove the jointly drawn sample from the logits_marginal
logits_marginal = logits_marginal[:, 1:]
# ... and retain it in the logits_joint. Now we have two arrays with `K` choices.

# To use logsumexp, we add loggamma to the logits and append logK as an extra column of the denominator.
loggamma = torch.tensor(gamma, dtype=dtype, device=device).log()
logK = torch.tensor(K, dtype=dtype, device=device).log()
denominator_marginal = torch.concat(
[loggamma + logits_marginal, logK.expand((batch_size, 1))],
dim=-1,
)
denominator_joint = torch.concat(
[loggamma + logits_joint, logK.expand((batch_size, 1))],
dim=-1,
)

# Compute the contributions to the loss from each term in the classification.
log_prob_marginal = logK - torch.logsumexp(denominator_marginal, dim=-1)
log_prob_joint = (
loggamma + logits_joint[:, 0] - torch.logsumexp(denominator_joint, dim=-1)
)

# relative weights. p_marginal := p_0, and p_joint := p_K from the notation.
p_marginal, p_joint = self._get_prior_probs_marginal_and_joint(K, gamma)
return -torch.mean(
p_marginal * log_prob_marginal + p_joint * K * log_prob_joint
)

@staticmethod
def _get_prior_probs_marginal_and_joint(
K: int, gamma: float
) -> Tuple[float, float]:
"""Return a tuple (p_marginal, p_joint) where `p_marginal := `$p_0$, `p_joint := `$p_K$.

We let the joint (dependently drawn) class be equally likely across the K options.
The marginal class therefore receives the remaining probability.
"""
assert K >= 1
p_joint = gamma / (1 + gamma * K)
p_marginal = 1 / (1 + gamma * K)
return p_marginal, p_joint
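As an illustrative sanity check (not part of the diff) of this class-prior parameterization: with the defaults `K=9` and `gamma=1.0`, each of the $K$ jointly drawn classes gets $p_K = \gamma / (1 + \gamma K) = 0.1$ and the marginal class gets $p_0 = 1 / (1 + \gamma K) = 0.1$, so $p_0 + K p_K = 1$:

```python
from sbi.inference import SNRE_C

K, gamma = 9, 1.0
p_marginal, p_joint = SNRE_C._get_prior_probs_marginal_and_joint(K, gamma)
assert abs(p_marginal + K * p_joint - 1.0) < 1e-12  # class probabilities sum to one
print(p_marginal, p_joint)  # 0.1 0.1
```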