Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Rename classifier (for NRE) to critic #1103

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/docs/reference/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@

::: sbi.neural_nets.factory.likelihood_nn

::: sbi.neural_nets.factory.classifier_nn
::: sbi.neural_nets.factory.critic_nn

::: sbi.neural_nets.density_estimators.DensityEstimator
4 changes: 2 additions & 2 deletions sbi/inference/potentials/ratio_based_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _log_ratios_over_trials(
Args:
x: batch of iid data.
theta: batch of parameters
net: neural net representing the classifier to approximate the ratio.
net: neural net representing the critic to approximate the ratio.
track_gradients: Whether to track gradients.
Returns:
log_ratio_trial_sum: log ratio for each parameter, summed over all
Expand All @@ -128,7 +128,7 @@ def _log_ratios_over_trials(

# Calculate ratios in one batch.
with torch.set_grad_enabled(track_gradients):
log_ratio_trial_batch = net([theta_repeated, x_repeated])
log_ratio_trial_batch = net(theta_repeated, x_repeated)
bkmi marked this conversation as resolved.
Show resolved Hide resolved
# Reshape to (x-trials x parameters), sum over trial-log likelihoods.
log_ratio_trial_sum = log_ratio_trial_batch.reshape(x.shape[0], -1).sum(0)

Expand Down
13 changes: 7 additions & 6 deletions sbi/inference/snre/bnre.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class BNRE(SNRE_A):
def __init__(
self,
prior: Optional[Distribution] = None,
classifier: Union[str, Callable] = "resnet",
critic: Union[str, Callable] = "resnet",
device: str = "cpu",
logging_level: Union[int, str] = "warning",
summary_writer: Optional[TensorboardSummaryWriter] = None,
Expand All @@ -31,13 +31,13 @@ def __init__(
prior: A probability distribution that expresses prior knowledge about the
parameters, e.g. which ranges are meaningful for them. If `None`, the
prior must be passed to `.build_posterior()`.
classifier: Classifier trained to approximate likelihood ratios. If it is
critic: Critic trained to approximate likelihood ratios. If it is
a string, use a pre-configured network of the provided type (one of
linear, mlp, resnet). Alternatively, a function that builds a custom
neural network can be provided. The function will be called with the
first batch of simulations $(\theta, x)$, which can thus be used for
shape inference and potentially for z-scoring. It needs to return a
PyTorch `nn.Module` implementing the classifier.
PyTorch `nn.Module` implementing the critic.
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
logging_level: Minimum severity of messages to log. One of the strings
INFO, WARNING, DEBUG, ERROR and CRITICAL.
Expand Down Expand Up @@ -65,7 +65,7 @@ def train(
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
) -> nn.Module:
r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
r"""Return critic that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
Args:

regularization_strength: The multiplicative coefficient applied to the
Expand Down Expand Up @@ -96,7 +96,7 @@ def train(
dataloader_kwargs: Additional or updated kwargs to be passed to the training
and validation dataloaders (like, e.g., a collate_fn)
Returns:
Classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
Critic that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
"""
kwargs = del_entries(locals(), entries=("self", "__class__"))
kwargs["loss_kwargs"] = {
Expand All @@ -107,7 +107,8 @@ def train(
def _loss(
self, theta: Tensor, x: Tensor, num_atoms: int, regularization_strength: float
) -> Tensor:
"""Returns the binary cross-entropy loss for the trained classifier.
"""Returns the binary cross-entropy loss for the classifier
(defined by the critic).

The classifier takes as input a $(\theta,x)$ pair. It is trained to predict 1
if the pair was sampled from the joint $p(\theta,x)$, and to predict 0 if the
Expand Down
13 changes: 7 additions & 6 deletions sbi/inference/snre/snre_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class SNRE_A(RatioEstimator):
def __init__(
self,
prior: Optional[Distribution] = None,
classifier: Union[str, Callable] = "resnet",
critic: Union[str, Callable] = "resnet",
device: str = "cpu",
logging_level: Union[int, str] = "warning",
summary_writer: Optional[TensorboardSummaryWriter] = None,
Expand All @@ -28,13 +28,13 @@ def __init__(
prior: A probability distribution that expresses prior knowledge about the
parameters, e.g. which ranges are meaningful for them. If `None`, the
prior must be passed to `.build_posterior()`.
classifier: Classifier trained to approximate likelihood ratios. If it is
critic: Critic trained to approximate likelihood ratios. If it is
a string, use a pre-configured network of the provided type (one of
linear, mlp, resnet). Alternatively, a function that builds a custom
neural network can be provided. The function will be called with the
first batch of simulations (theta, x), which can thus be used for shape
inference and potentially for z-scoring. It needs to return a PyTorch
`nn.Module` implementing the classifier.
`nn.Module` implementing the critic.
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
logging_level: Minimum severity of messages to log. One of the strings
INFO, WARNING, DEBUG, ERROR and CRITICAL.
Expand Down Expand Up @@ -62,7 +62,7 @@ def train(
dataloader_kwargs: Optional[Dict] = None,
loss_kwargs: Optional[Dict[str, Any]] = None,
) -> nn.Module:
r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
r"""Return critic that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.

Args:
training_batch_size: Training batch size.
Expand Down Expand Up @@ -91,7 +91,7 @@ def train(
loss_kwargs: Additional or updated kwargs to be passed to the self._loss fn.

Returns:
Classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
Critic that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
"""

# AALR is defined for `num_atoms=2`.
Expand All @@ -100,7 +100,8 @@ def train(
return super().train(**kwargs, num_atoms=2)

def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor:
"""Returns the binary cross-entropy loss for the trained classifier.
"""Returns the binary cross-entropy loss for the classifier
(defined by the critic).

The classifier takes as input a $(\theta,x)$ pair. It is trained to predict 1
if the pair was sampled from the joint $p(\theta,x)$, and to predict 0 if the
Expand Down
12 changes: 6 additions & 6 deletions sbi/inference/snre/snre_b.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class SNRE_B(RatioEstimator):
def __init__(
self,
prior: Optional[Distribution] = None,
classifier: Union[str, Callable] = "resnet",
critic: Union[str, Callable] = "resnet",
device: str = "cpu",
logging_level: Union[int, str] = "warning",
summary_writer: Optional[TensorboardSummaryWriter] = None,
Expand All @@ -28,13 +28,13 @@ def __init__(
prior: A probability distribution that expresses prior knowledge about the
parameters, e.g. which ranges are meaningful for them. If `None`, the
prior must be passed to `.build_posterior()`.
classifier: Classifier trained to approximate likelihood ratios. If it is
critic: Critic trained to approximate likelihood ratios. If it is
a string, use a pre-configured network of the provided type (one of
linear, mlp, resnet). Alternatively, a function that builds a custom
neural network can be provided. The function will be called with the
first batch of simulations (theta, x), which can thus be used for shape
inference and potentially for z-scoring. It needs to return a PyTorch
`nn.Module` implementing the classifier.
`nn.Module` implementing the critic.
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
logging_level: Minimum severity of messages to log. One of the strings
INFO, WARNING, DEBUG, ERROR and CRITICAL.
Expand Down Expand Up @@ -62,7 +62,7 @@ def train(
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
) -> nn.Module:
r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
r"""Return critic that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.

Args:
num_atoms: Number of atoms to use for classification.
Expand Down Expand Up @@ -91,14 +91,14 @@ def train(
and validation dataloaders (like, e.g., a collate_fn)

Returns:
Classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
Critic that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
"""
kwargs = del_entries(locals(), entries=("self", "__class__"))
return super().train(**kwargs)

def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor:
r"""Return cross-entropy (via softmax activation) loss for 1-out-of-`num_atoms`
classification.
classification (defined by the critic).

The classifier takes as input `num_atoms` $(\theta,x)$ pairs. Out of these
pairs, one pair was sampled from the joint $p(\theta,x)$ and all others from the
Expand Down
25 changes: 13 additions & 12 deletions sbi/inference/snre/snre_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sbi.inference.base import NeuralInference
from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
from sbi.inference.potentials import ratio_estimator_based_potential
from sbi.neural_nets import classifier_nn
from sbi.neural_nets import critic_nn
from sbi.utils import (
check_estimator_arg,
check_prior,
Expand All @@ -25,7 +25,7 @@ class RatioEstimator(NeuralInference, ABC):
def __init__(
self,
prior: Optional[Distribution] = None,
classifier: Union[str, Callable] = "resnet",
critic: Union[str, Callable] = "resnet",
device: str = "cpu",
logging_level: Union[int, str] = "warning",
summary_writer: Optional[SummaryWriter] = None,
Expand All @@ -49,13 +49,13 @@ def __init__(
(normalizing constant) of the data $x$.

Args:
classifier: Classifier trained to approximate likelihood ratios. If it is
critic: Critic trained to approximate likelihood ratios. If it is
a string, use a pre-configured network of the provided type (one of
linear, mlp, resnet). Alternatively, a function that builds a custom
neural network can be provided. The function will be called with the
first batch of simulations (theta, x), which can thus be used for shape
inference and potentially for z-scoring. It needs to return a PyTorch
`nn.Module` implementing the classifier.
`nn.Module` implementing the critic.

See docstring of `NeuralInference` class for all other arguments.
"""
Expand All @@ -73,11 +73,11 @@ def __init__(
# `_build_neural_net`. It will be called in the first round and receive
# thetas and xs as inputs, so that they can be used for shape inference and
# potentially for z-scoring.
check_estimator_arg(classifier)
if isinstance(classifier, str):
self._build_neural_net = classifier_nn(model=classifier)
check_estimator_arg(critic)
if isinstance(critic, str):
self._build_neural_net = critic_nn(model=critic)
else:
self._build_neural_net = classifier
self._build_neural_net = critic

def append_simulations(
self,
Expand Down Expand Up @@ -139,7 +139,7 @@ def train(
dataloader_kwargs: Optional[Dict] = None,
loss_kwargs: Optional[Dict[str, Any]] = None,
) -> nn.Module:
r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
r"""Return critic that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.

Args:
num_atoms: Number of atoms to use for classification.
Expand All @@ -159,7 +159,7 @@ def train(
loss_kwargs: Additional or updated kwargs to be passed to the self._loss fn.

Returns:
Classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
Critic that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
"""
# Load data from most recent round.
self._round = max(self._data_round_index)
Expand Down Expand Up @@ -285,7 +285,8 @@ def train(
return deepcopy(self._neural_net)

def _classifier_logits(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor:
"""Return logits obtained through classifier forward pass.
"""Return logits obtained through classifier (defined by the critic's)
forward pass.

The logits are obtained from atomic sets of (theta,x) pairs.
"""
Expand All @@ -303,7 +304,7 @@ def _classifier_logits(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor
batch_size * num_atoms, -1
)

return self._neural_net([atomic_theta, repeated_x])
return self._neural_net(atomic_theta, repeated_x)

@abstractmethod
def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor:
Expand Down
10 changes: 5 additions & 5 deletions sbi/inference/snre/snre_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class SNRE_C(RatioEstimator):
def __init__(
self,
prior: Optional[Distribution] = None,
classifier: Union[str, Callable] = "resnet",
critic: Union[str, Callable] = "resnet",
device: str = "cpu",
logging_level: Union[int, str] = "warning",
summary_writer: Optional[TensorboardSummaryWriter] = None,
Expand Down Expand Up @@ -42,13 +42,13 @@ def __init__(
prior: A probability distribution that expresses prior knowledge about the
parameters, e.g. which ranges are meaningful for them. If `None`, the
prior must be passed to `.build_posterior()`.
classifier: Classifier trained to approximate likelihood ratios. If it is
critic: Critic trained to approximate likelihood ratios. If it is
a string, use a pre-configured network of the provided type (one of
linear, mlp, resnet). Alternatively, a function that builds a custom
neural network can be provided. The function will be called with the
first batch of simulations (theta, x), which can thus be used for shape
inference and potentially for z-scoring. It needs to return a PyTorch
`nn.Module` implementing the classifier.
`nn.Module` implementing the critic.
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
logging_level: Minimum severity of messages to log. One of the strings
INFO, WARNING, DEBUG, ERROR and CRITICAL.
Expand Down Expand Up @@ -77,7 +77,7 @@ def train(
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
) -> nn.Module:
r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
r"""Return critic that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.

Args:
num_classes: Number of theta to classify against, corresponds to $K$ in
Expand Down Expand Up @@ -116,7 +116,7 @@ def train(
and validation dataloaders (like, e.g., a collate_fn)

Returns:
Classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
Critic that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
"""
kwargs = del_entries(locals(), entries=("self", "__class__"))
kwargs["num_atoms"] = kwargs.pop("num_classes") + 1
Expand Down
12 changes: 5 additions & 7 deletions sbi/neural_nets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from sbi.neural_nets.classifier import (
StandardizeInputs,
build_input_layer,
build_linear_classifier,
build_mlp_classifier,
build_resnet_classifier,
from sbi.neural_nets.critic import (
build_linear_critic,
build_mlp_critic,
build_resnet_critic,
)
from sbi.neural_nets.density_estimators import DensityEstimator, NFlowsFlow
from sbi.neural_nets.embedding_nets import (
CNNEmbedding,
FCEmbedding,
PermutationInvariantEmbedding,
)
from sbi.neural_nets.factory import classifier_nn, likelihood_nn, posterior_nn
from sbi.neural_nets.factory import critic_nn, likelihood_nn, posterior_nn
from sbi.neural_nets.flow import (
build_made,
build_maf,
Expand Down
Loading