
Commit fe29549: refactoring

janfb committed Mar 1, 2023
1 parent 8e7617b
Showing 4 changed files with 57 additions and 50 deletions.
2 changes: 0 additions & 2 deletions sbi/inference/base.py
@@ -161,8 +161,6 @@ def get_simulations(
Args:
starting_round: The earliest round to return samples from (we start counting
from zero).
- exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞`
-     during training.
warn_on_invalid: Whether to give out a warning if invalid simulations were
found.
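With `exclude_invalid_x` gone from this docstring, a call to `get_simulations` only needs the round index. A minimal usage sketch, mirroring the call in `snpe_base.py` below (the `inference` object is assumed to be an existing inference instance):

theta, x, _ = inference.get_simulations(starting_round=0)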
1 change: 1 addition & 0 deletions sbi/inference/snpe/snpe_base.py
@@ -306,6 +306,7 @@ def train(
# Get theta,x to initialize NN
theta, x, _ = self.get_simulations(starting_round=start_idx)
+ # Use only training data for building the neural net (z-scoring transforms)

self._neural_net = self._build_neural_net(
theta[self.train_indices].to("cpu"),
x[self.train_indices].to("cpu"),
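The added comment records why `self.train_indices` is used here: z-scoring statistics should be computed from the training split only, so the validation split cannot leak into the normalization. A toy sketch of that idea, with made-up tensors rather than the sbi API:

import torch

theta = torch.randn(1000, 3)
train_indices = torch.arange(800)  # e.g., an 80/20 train/validation split
mean = theta[train_indices].mean(dim=0)
std = theta[train_indices].std(dim=0)
theta_z = (theta - mean) / std  # normalization stats from training data only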
93 changes: 49 additions & 44 deletions sbi/neural_nets/embedding_nets.py
@@ -1,7 +1,7 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

- from typing import List, Tuple, Union
+ from typing import Callable, List, Optional, Tuple, Union

import torch
from torch import Tensor, nn
@@ -13,7 +13,7 @@ def __init__(
input_dim: int,
output_dim: int = 20,
num_layers: int = 2,
- num_hiddens: int = 40,
+ num_hiddens: int = 50,
):
"""Fully-connected multi-layer neural network to be used as embedding network.
@@ -216,17 +216,23 @@ def __init__(
self,
trial_net: nn.Module,
trial_net_output_dim: int,
combining_operation: str = "mean",
aggregation_dim: int = 1,
aggregation_fn: Optional[str] = "sum",
num_hiddens: int = 100,
num_layers: int = 2,
num_hiddens: int = 40,
output_dim: int = 20,
):
"""Permutation invariant multi-layer NN.
- The trial_net is applied to each "trial" of the input
- and is combined by the combining_operation (mean or sum) to construct a
- permutation invariant embedding across iid trials.
- This embedding is embedded again using an additional fully connected net.
+ Applies the trial_net to every trial to obtain trial embeddings.
+ It then aggregates the trial embeddings across the aggregation dimension to
+ construct a permutation invariant embedding across iid trials.
+ The resulting embedding is processed further using an additional fully
+ connected net. The input to the final embedding net is the trial_net output
+ plus the number of trials N: (batch, trial_net_output_dim + 1).
+ If the data x has a varying number of trials per batch element, missing trials
+ should be encoded as NaNs. In the forward pass, the NaNs are masked.
Args:
trial_net: Network to process one trial. The combining_operation is
@@ -236,25 +242,22 @@ def __init__(
Remark: This network should be large enough as it acts on all (iid)
inputs separately and needs enough capacity to process the information
of all inputs.
- trial_net_output_dim: Dimensionality of the output of the trial_net / input
-     to the fully connected layers.
- combining_operation: How to combine the permutational dimensions, one of
-     'mean' or 'sum'.
+ trial_net_output_dim: Dimensionality of the output of the trial_net.
+ aggregation_dim: Dimension along which to aggregate the trial embeddings.
+ aggregation_fn: Function to aggregate the trial embeddings. Defaults to
+     taking the sum over the non-nan values.
num_layers: Number of fully connected layers, minimum of 2.
num_hiddens: Number of hidden dimensions in fully-connected layers.
output_dim: Dimensionality of the output.
"""
super().__init__()
self.trial_net = trial_net
- self.combining_operation = combining_operation

- # define function for permutation invariant embedding
- if combining_operation == "mean":
-     self.combining_function = torch.mean
- elif self.combining_operation == "sum":
-     self.combining_function = torch.sum
- else:
-     raise ValueError("combining_operation must be in ['sum', 'mean'].")
+ self.aggregation_dim = aggregation_dim
+ assert aggregation_fn in [
+     "mean",
+     "sum",
+ ], "aggregation_fn must be 'mean' or 'sum'."
+ self.aggregation_fn = aggregation_fn

# construct fully connected layers
self.fc_subnet = FCEmbedding(
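Putting the new constructor arguments together, a usage sketch that mirrors the test below (the dimensions are illustrative, not prescribed by this commit):

single_trial_net = FCEmbedding(input_dim=2, output_dim=20)
embedding_net = PermutationInvariantEmbedding(
    single_trial_net,
    trial_net_output_dim=20,
    aggregation_dim=1,
    aggregation_fn="sum",
    output_dim=20,
)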
@@ -271,34 +274,36 @@ def forward(self, x: Tensor) -> Tensor:
Returns:
Network output (batch_size, output_dim).
"""
- batch, permutation_dim, _ = x.shape

- # if no NaNs for padding varying trial lengths we can batch the computation
- if not torch.isnan(x).any():
-     trial_embeddings = self.trial_net(x.view(batch * permutation_dim, -1)).view(
-         batch, permutation_dim, -1
+ # Get number of trials from non-nan entries
+ num_batch, max_num_trials = x.shape[0], x.shape[self.aggregation_dim]
+ nan_counts = (
+     torch.isnan(x)
+     .sum(dim=self.aggregation_dim)  # count nans over trial dimension
+     .reshape(-1)[:num_batch]  # counts are the same across data dims
+     .unsqueeze(-1)  # make it (batch, 1) to match embeddings below
)
+ # number of non-nan trials
+ trial_counts = max_num_trials - nan_counts

+ # get nan entries
+ is_nan = torch.isnan(x)
+ # apply trial net with nan entries replaced with 0
+ masked_x = torch.nan_to_num(x, nan=0.0)
+ trial_embeddings = self.trial_net(masked_x)
+ # replace previous nan entries with zeros
+ trial_embeddings = trial_embeddings * (~is_nan.all(-1, keepdim=True)).float()

+ # Take mean over permutation dimension: divide by number of trials
+ # (instead of just taking torch.mean) to account for masking.
+ if self.aggregation_fn == "mean":
+     combined_embedding = (
+         trial_embeddings.sum(dim=self.aggregation_dim) / trial_counts
+     )
- combined_embedding = self.combining_function(trial_embeddings, dim=1)
- trial_counts = torch.ones(batch, 1, dtype=torch.float32) * permutation_dim

- # otherwise we need to loop over the batch to account for varying trial lengths
else:
- combined_embedding = []
- trial_counts = torch.zeros(batch, 1)
- for i in range(batch):
-     # remove NaNs
-     valid_x = x[i, ~torch.isnan(x[i, :, 0]), :]
-     trial_counts[i] = valid_x.shape[0]
-     trial_embeddings = self.trial_net(valid_x)
-     # apply combining operation over permutation dimension
-     combined_embedding.append(
-         self.combining_function(trial_embeddings, dim=0)
-     )

- combined_embedding = torch.stack(combined_embedding, dim=0)
+ combined_embedding = trial_embeddings.sum(dim=self.aggregation_dim)

+ assert not torch.isnan(combined_embedding).any(), "NaNs in embedding."

# add number of trials as additional input
return self.fc_subnet(torch.cat([combined_embedding, trial_counts], dim=1))
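Note that the `.reshape(-1)[:num_batch]` shortcut above assumes the NaN counts are identical across data dimensions. A toy check of the masking idea, computed with an explicit per-element mask instead (identity trial net and made-up shapes, not the committed code):

import torch

x = torch.randn(2, 3, 2)   # (batch=2, max_trials=3, dim=2)
x[1, 2, :] = float("nan")  # second element has only 2 valid trials

valid = ~torch.isnan(x).any(dim=-1)            # (batch, trials) validity mask
trial_counts = valid.sum(dim=1, keepdim=True)  # tensor([[3], [2]])

masked_x = torch.nan_to_num(x, nan=0.0)                    # zero out padding
trial_embeddings = masked_x * valid.unsqueeze(-1).float()  # identity trial net
mean_embedding = trial_embeddings.sum(dim=1) / trial_counts  # masked mean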
11 changes: 7 additions & 4 deletions tests/embedding_net_test.py
@@ -95,7 +95,7 @@ def test_iid_embedding_api(num_trials, num_dim):


@pytest.mark.slow
- def test_iid_embedding_varying_num_trials(max_num_trials=100, trial_factor=50):
+ def test_iid_embedding_varying_num_trials(trial_factor=40, max_num_trials=100):
"""Test embedding net with varying number of trials."""
num_dim = 2
prior = torch.distributions.MultivariateNormal(
@@ -122,14 +122,17 @@ def test_iid_embedding_varying_num_trials(max_num_trials=100, trial_factor=50):
single_trial_net,
trial_net_output_dim=output_dim,
output_dim=output_dim,
aggregation_fn="sum",
)

# test embedding net
- e = embedding_net(x[:10])
- assert e.shape == (10, output_dim)
+ assert embedding_net(x[:3]).shape == (3, output_dim)

density_estimator = posterior_nn(
"mdn", embedding_net=embedding_net, z_score_x="none"
model="mdn",
embedding_net=embedding_net,
z_score_x="none",
z_score_theta="none",
)
inference = SNPE(prior, density_estimator=density_estimator)

