add size invariant iid embedding nets, tests. #808

Merged 3 commits on Mar 1, 2023
2 changes: 0 additions & 2 deletions sbi/inference/base.py
@@ -161,8 +161,6 @@ def get_simulations(
Args:
starting_round: The earliest round to return samples from (we start counting
from zero).
exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞`
during training.
warn_on_invalid: Whether to give out a warning if invalid simulations were
found.

2 changes: 1 addition & 1 deletion sbi/inference/snpe/snpe_base.py
@@ -10,7 +10,6 @@
from torch import Tensor, nn, ones, optim
from torch.distributions import Distribution
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.utils import data
from torch.utils.tensorboard.writer import SummaryWriter

from sbi import utils as utils
@@ -306,6 +305,7 @@ def train(
# Get theta,x to initialize NN
theta, x, _ = self.get_simulations(starting_round=start_idx)
# Use only training data for building the neural net (z-scoring transforms)

self._neural_net = self._build_neural_net(
theta[self.train_indices].to("cpu"),
x[self.train_indices].to("cpu"),
81 changes: 53 additions & 28 deletions sbi/neural_nets/embedding_nets.py
@@ -1,7 +1,7 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor, nn
@@ -13,7 +13,7 @@ def __init__(
input_dim: int,
output_dim: int = 20,
num_layers: int = 2,
num_hiddens: int = 40,
num_hiddens: int = 50,
):
"""Fully-connected multi-layer neural network to be used as embedding network.

@@ -216,17 +216,23 @@ def __init__(
self,
trial_net: nn.Module,
trial_net_output_dim: int,
combining_operation: str = "mean",
aggregation_fn: Optional[str] = "sum",
num_hiddens: int = 100,
num_layers: int = 2,
num_hiddens: int = 40,
output_dim: int = 20,
aggregation_dim: int = 1,
):
"""Permutation invariant multi-layer NN.

The trial_net is applied to each "trial" of the input
and is combined by the combining_operation (mean or sum) to construct a
permutation invariant embedding across iid trials.
This embedding is embedded again using an additional fully connected net.
Applies the trial_net to every trial to obtain trial embeddings.
It then aggregates the trial embeddings across the aggregation dimension to
construct a permutation invariant embedding across iid trials.
The resulting embedding is processed further using an additional fully
connected net. The input to the final embedding net is the trial_net output
plus the number of trials N: (batch, trial_net_output_dim + 1).

If the data x has a varying number of trials per batch element, missing trials
should be encoded as NaNs. In the forward pass, the NaNs are masked.

Args:
trial_net: Network to process one trial. The combining_operation is
@@ -236,24 +242,26 @@ def __init__(
Remark: This network should be large enough as it acts on all (iid)
inputs separately and needs enough capacity to process the information
of all inputs.
trial_net_output_dim: Dimensionality of the output of the trial_net / input
to the fully connected layers.
combining_operation: How to combine the permutational dimensions, one of
'mean' or 'sum'.
trial_net_output_dim: Dimensionality of the output of the trial_net.
aggregation_fn: Function to aggregate the trial embeddings. Defaults to
taking the sum over the non-NaN values.
num_layers: Number of fully connected layers, minimum of 2.
num_hiddens: Number of hidden dimensions in fully-connected layers.
output_dim: Dimensionality of the output.
aggregation_dim: Dimension along which to aggregate the trial embeddings.
"""
super().__init__()
self.trial_net = trial_net
self.combining_operation = combining_operation

if combining_operation not in ["sum", "mean"]:
raise ValueError("combining_operation must be in ['sum', 'mean'].")
self.aggregation_dim = aggregation_dim
assert aggregation_fn in [
"mean",
"sum",
], "aggregation_fn must be 'mean' or 'sum'."
self.aggregation_fn = aggregation_fn

# construct fully connected layers
self.fc_subnet = FCEmbedding(
input_dim=trial_net_output_dim,
input_dim=trial_net_output_dim + 1, # +1 to encode number of trials
output_dim=output_dim,
num_layers=num_layers,
num_hiddens=num_hiddens,
@@ -266,19 +274,36 @@ def forward(self, x: Tensor) -> Tensor:
Returns:
Network output (batch_size, output_dim).
"""
batch, permutation_dim, _ = x.shape

iid_embeddings = self.trial_net(x.view(batch * permutation_dim, -1)).view(
batch, permutation_dim, -1
# Get number of trials from non-NaN entries
num_batch, max_num_trials = x.shape[0], x.shape[self.aggregation_dim]
nan_counts = (
torch.isnan(x)
.sum(dim=self.aggregation_dim) # count NaNs over trial dimension
.reshape(num_batch, -1)[:, :1] # counts are the same across data dims
.float() # shape (batch, 1) and float, to match embeddings below
)

if self.combining_operation == "mean":
e = iid_embeddings.mean(dim=1)
elif self.combining_operation == "sum":
e = iid_embeddings.sum(dim=1)
# number of non-nan trials
trial_counts = max_num_trials - nan_counts

# get nan entries
is_nan = torch.isnan(x)
# apply trial net with nan entries replaced with 0
masked_x = torch.nan_to_num(x, nan=0.0)
trial_embeddings = self.trial_net(masked_x)
# zero out the embeddings of NaN (padded) trials
trial_embeddings = trial_embeddings * (~is_nan.all(-1, keepdim=True)).float()

# Take the sum over the trial dimension and divide by the number of
# valid trials (instead of using torch.mean) to account for the NaN masking.
if self.aggregation_fn == "mean":
combined_embedding = (
trial_embeddings.sum(dim=self.aggregation_dim) / trial_counts
)
else:
raise ValueError("combining_operation must be in ['sum', 'mean'].")
combined_embedding = trial_embeddings.sum(dim=self.aggregation_dim)

embedding = self.fc_subnet(e)
assert not torch.isnan(combined_embedding).any(), "NaNs in embedding."

return embedding
# add number of trials as additional input
return self.fc_subnet(torch.cat([combined_embedding, trial_counts], dim=1))
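As a usage illustration of the NaN masking and trial-count encoding above, here is a minimal sketch. It assumes only what this diff shows: `FCEmbedding` and `PermutationInvariantEmbedding` with the signatures above, importable from `sbi.neural_nets.embedding_nets` (the module this diff touches).

```python
import torch

from sbi.neural_nets.embedding_nets import FCEmbedding, PermutationInvariantEmbedding

num_dim, max_trials, latent_dim = 2, 10, 5
trial_net = FCEmbedding(input_dim=num_dim, output_dim=latent_dim)
embedding_net = PermutationInvariantEmbedding(
    trial_net,
    trial_net_output_dim=latent_dim,
    aggregation_fn="sum",
    output_dim=latent_dim,
)

# Batch of 3 with 2, 5, and 10 valid trials; missing trials are NaN-padded.
x = torch.full((3, max_trials, num_dim), float("nan"))
for i, n in enumerate([2, 5, 10]):
    x[i, :n] = torch.randn(n, num_dim)

# NaN trials are masked in forward() and the trial count is appended before the
# fully connected subnet, so each batch element gets one fixed-size embedding.
assert embedding_net(x).shape == (3, latent_dim)
```

Because the trial count enters the final subnet as an extra feature, the same network can embed batches with any number of valid trials up to `max_trials`.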
20 changes: 17 additions & 3 deletions sbi/utils/sbiutils.py
@@ -216,14 +216,20 @@ def standardizing_net(
t_std = torch.std(batch_t[is_valid_t], dim=0)
t_std[t_std < min_std] = min_std
else:
t_std = 1
t_std = torch.ones(1)
logging.warning(
"""Using a one-dimensional batch will instantiate a Standardize transform
with (mean, std) parameters which are not representative of the data. We
allow this behavior because you might be loading a pre-trained. If this is
not the case, please be sure to use a larger batch."""
allow this behavior because you might be loading a pre-trained net.
If this is not the case, please be sure to use a larger batch."""
)

nan_in_stats = torch.logical_or(torch.isnan(t_mean).any(), torch.isnan(t_std).any())
assert not nan_in_stats, """Training data mean or std for standardizing net must not
contain NaNs. In case you are encoding missing trials with
NaNs, consider setting z_score_x='none' to disable
z-scoring."""

return Standardize(t_mean, t_std)
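Why this guard matters: with NaN-encoded missing trials, a single NaN anywhere in the batch propagates into the per-dimension statistics. A minimal sketch with hypothetical values:

```python
import torch

# One NaN-padded entry is enough to poison the standardization statistics.
batch = torch.tensor([[1.0, float("nan")], [0.5, 0.2]])
t_mean, t_std = batch.mean(dim=0), batch.std(dim=0)
print(t_mean)  # tensor([0.7500, nan])
print(t_std)   # tensor([0.3536, nan])

# The nan_in_stats assert above rejects exactly this case; with NaN-encoded
# trials, set z_score_x='none' instead of z-scoring.
```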


Expand All @@ -243,6 +249,8 @@ def handle_invalid_x(
# Flatten to cover all dimensions in case of multidimensional x.
x = x.reshape(batch_size, -1)

# TODO: add option to allow for NaNs in certain dimensions, e.g., to encode varying
# numbers of trials.
x_is_nan = torch.isnan(x).any(dim=1)
x_is_inf = torch.isinf(x).any(dim=1)
num_nans = int(x_is_nan.sum().item())
@@ -253,6 +261,12 @@
else:
is_valid_x = ones(batch_size, dtype=torch.bool)

assert (
is_valid_x.sum() > 0
), """No valid data entries left after excluding NaNs and Infs. In case you are
encoding missing trials with NaNs consider setting exclude_invalid_x=False and
z_score_x = 'none' to disable z-scoring."""

return is_valid_x, num_nans, num_infs


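Together with the standardizing guard, this assert defines the contract for NaN-encoded trials. A minimal sketch (hypothetical tensors; `handle_invalid_x` is imported from `sbi.utils.sbiutils`, as in the tests below):

```python
import torch

from sbi.utils.sbiutils import handle_invalid_x

x = torch.tensor([[1.0, float("nan")], [0.5, 0.2]])  # first row is NaN-padded

# Default behavior: rows containing NaNs are marked invalid and excluded.
is_valid, num_nans, num_infs = handle_invalid_x(x, exclude_invalid_x=True)
print(is_valid, num_nans, num_infs)  # tensor([False,  True]) 1 0

# For NaN-encoded trials, keep all rows and let the embedding net mask them.
is_valid, *_ = handle_invalid_x(x, exclude_invalid_x=False)
print(is_valid)  # tensor([True, True])
```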
76 changes: 73 additions & 3 deletions tests/embedding_net_test.py
@@ -16,7 +16,8 @@
true_posterior_linear_gaussian_mvn_prior,
)
from sbi.utils import classifier_nn, likelihood_nn, posterior_nn
from tests.test_utils import check_c2st

from .test_utils import check_c2st


@pytest.mark.parametrize("method", ["SNPE", "SNLE", "SNRE"])
@@ -94,6 +95,75 @@ def test_iid_embedding_api(num_trials, num_dim):
_ = posterior.potential(s)


@pytest.mark.slow
def test_iid_embedding_varying_num_trials(trial_factor=40, max_num_trials=100):
"""Test embedding net with varying number of trials."""
num_dim = 2
prior = torch.distributions.MultivariateNormal(
torch.zeros(num_dim), torch.eye(num_dim)
)

# Scale number of training samples with the maximum number of trials.
num_thetas = 5000 + trial_factor * max_num_trials

theta = prior.sample(sample_shape=torch.Size((num_thetas,)))
num_trials = torch.randint(1, max_num_trials, size=(num_thetas,))

# simulate iid x, pad missing trials with NaNs.
x = ones(num_thetas, max_num_trials, num_dim) * float("nan")

for i in range(num_thetas):
th = theta[i].repeat(num_trials[i], 1)
x[i, : num_trials[i]] = torch.randn_like(th) + th

# build embedding net
output_dim = 5
single_trial_net = FCEmbedding(input_dim=num_dim, output_dim=output_dim)
embedding_net = PermutationInvariantEmbedding(
single_trial_net,
trial_net_output_dim=output_dim,
output_dim=output_dim,
aggregation_fn="sum",
)

# test embedding net
assert embedding_net(x[:3]).shape == (3, output_dim)

density_estimator = posterior_nn(
model="mdn",
embedding_net=embedding_net,
z_score_x="none", # turn off z-scoring because of NaN encodinds.
z_score_theta="independent",
)
inference = SNPE(prior, density_estimator=density_estimator)

# do not exclude invalid x, as we padded with nans.
_ = inference.append_simulations(theta, x, exclude_invalid_x=False).train(
training_batch_size=100
)

num_samples = 1000
# test different number of trials
num_test_trials = torch.linspace(1, max_num_trials, 5, dtype=torch.int)
for num_trials in num_test_trials:
# x_o must have the same number of trials as x, thus we pad with nans.
x_o = ones(1, max_num_trials, num_dim) * float("nan")
x_o[:, :num_trials] = 0.0

# get reference samples from true posterior
reference_samples = true_posterior_linear_gaussian_mvn_prior(
x_o[0, :num_trials, :], # omit nans
likelihood_shift=torch.zeros(num_dim),
likelihood_cov=torch.eye(num_dim),
prior_cov=prior.covariance_matrix,
prior_mean=prior.loc,
).sample((num_samples,))

posterior = inference.build_posterior().set_default_x(x_o)
samples = posterior.sample((num_samples,))
check_c2st(samples, reference_samples, alg=f"iid-NPE with {num_trials} trials")


@pytest.mark.slow
@pytest.mark.parametrize("num_trials", [1, 10, 50])
@pytest.mark.parametrize("num_dim", [2])
@@ -110,7 +180,7 @@ def test_iid_inference(num_trials, num_dim, method):
)

# Scale number of training samples with num_trials.
num_thetas = 1000 + 100 * num_trials
num_thetas = 1000 + 110 * num_trials

# simulate iid x.
def simulator(theta, num_trials=num_trials):
@@ -202,7 +272,7 @@ def simulator1d(theta):
prior = MultivariateNormal(torch.zeros(num_dim), torch.eye(num_dim))

num_simulations = 1000
theta = prior.sample((num_simulations,))
theta = prior.sample(torch.Size((num_simulations,)))
x = simulator(theta)
if num_channels > 1:
x = x.unsqueeze(1).repeat(
6 changes: 3 additions & 3 deletions tests/inference_with_NaN_simulator_test.py
@@ -22,14 +22,13 @@
)
from sbi.utils import RestrictionEstimator
from sbi.utils.sbiutils import handle_invalid_x
from tests.test_utils import check_c2st

from .test_utils import check_c2st


@pytest.mark.parametrize(
"x_shape",
(
torch.Size((1, 1)),
torch.Size((1, 10)),
torch.Size((10, 1)),
torch.Size((10, 10)),
),
@@ -38,6 +37,7 @@ def test_handle_invalid_x(x_shape):
x = torch.rand(x_shape)
x[x < 0.1] = float("nan")
x[x > 0.9] = float("inf")
x[-1, :] = 1.0 # make sure there is one row of valid entries.

x_is_valid, *_ = handle_invalid_x(x, exclude_invalid_x=True)
