improve handling of multi-d x in SNPE, adapt iid test. #780

Merged · 1 commit · Nov 9, 2022
13 changes: 10 additions & 3 deletions sbi/utils/user_input_checks.py
@@ -272,9 +272,12 @@ def check_for_possibly_batched_x_shape(x_shape):
         has an inferred batch shape larger than one. This is not supported in
         some sbi methods for reasons depending on the scenario:
 
-        - in case you want to evaluate or sample conditioned on several xs
-          e.g., (p(theta | [x1, x2, x3])), this is not supported yet except
-          when using likelihood based SNLE and SNRE.
+        - in case you want to evaluate or sample conditioned on several iid
+          xs, e.g., p(theta | [x1, x2, x3]), this is fully supported only
+          for likelihood-based SNLE and SNRE. For SNPE it is supported only
+          for a fixed number of trials and with an appropriate embedding
+          net, i.e., by treating the trials as an additional data dimension.
+          In that case, make sure to pass x_o with a leading batch dimension.
 
         - in case you trained with a single round to do amortized inference
           and now you want to evaluate or sample a given theta conditioned on
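
Note: a minimal shape sketch of the distinction the added docstring draws (illustrative numbers; the commented `posterior` call assumes a hypothetical trained SNPE posterior, not an object from this diff):

```python
import torch

num_trials, num_dim = 10, 2

# SNLE / SNRE: iid trials can be passed directly as a batch of x's.
x_o_likelihood = torch.zeros(num_trials, num_dim)         # shape (10, 2)

# SNPE with a permutation-invariant embedding net: trials act as an extra
# data dimension, so x_o needs an explicit leading batch dimension.
x_o_snpe = torch.zeros(num_trials, num_dim).unsqueeze(0)  # shape (1, 10, 2)
# posterior.sample((1000,), x=x_o_snpe)  # hypothetical trained posterior
```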
@@ -569,6 +572,10 @@ def process_x(

     x = atleast_2d(torch.as_tensor(x, dtype=float32))
 
+    # If x_shape is provided, we can fix a missing batch dim for >1D data.
+    if x_shape is not None and len(x_shape) > len(x.shape):
+        x = x.unsqueeze(0)

     input_x_shape = x.shape
     if not allow_iid_x:
         check_for_possibly_batched_x_shape(input_x_shape)
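
Note: a standalone sketch of what the three added lines do, assuming the surrounding `process_x` semantics; this mirrors the sbi code above but is not itself the library implementation:

```python
import torch

def fix_missing_batch_dim(x: torch.Tensor, x_shape) -> torch.Tensor:
    """Prepend a batch dim when x has fewer dims than the expected x_shape."""
    if x_shape is not None and len(x_shape) > len(x.shape):
        x = x.unsqueeze(0)
    return x

x = torch.ones(10, 3)  # 10 iid trials of 3-dim data, batch dim missing
assert fix_missing_batch_dim(x, torch.Size([1, 10, 3])).shape == (1, 10, 3)
# An already-batched x is left untouched.
assert fix_missing_batch_dim(torch.ones(1, 10, 3), torch.Size([1, 10, 3])).shape == (1, 10, 3)
```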
98 changes: 38 additions & 60 deletions tests/embedding_net_test.py
@@ -5,7 +5,7 @@
 from torch import eye, ones, zeros
 
 from sbi import utils as utils
-from sbi.inference import SNLE, SNPE, SNRE
+from sbi.inference import SNLE, SNPE, SNRE, simulate_for_sbi
 from sbi.neural_nets.embedding_nets import (
     CNNEmbedding,
     FCEmbedding,
@@ -96,9 +96,9 @@ def test_iid_embedding_api(num_trials, num_dim):


 @pytest.mark.slow
-@pytest.mark.parametrize("num_trials", [10, 50])
+@pytest.mark.parametrize("num_trials", [1, 10, 50])
 @pytest.mark.parametrize("num_dim", [2])
-@pytest.mark.parametrize("method", ["SNPE", "SNLE", "SNRE"])
+@pytest.mark.parametrize("method", ("SNPE",))
 def test_iid_inference(num_trials, num_dim, method):
     """Test accuracy in Gaussian linear simulator with iid trials.

@@ -112,51 +112,39 @@ def test_iid_inference(num_trials, num_dim, method):

     # Scale number of training samples with num_trials.
     num_thetas = 1000 + 100 * num_trials
-    # Likelihood-based methods train on single trials.
-    num_simulations = num_thetas
 
-    if method == "SNPE":  # SNPE needs embedding and iid trials during training.
-        theta = prior.sample((num_thetas,))
-        # simulate iid x.
-        iid_theta = theta.reshape(num_thetas, 1, num_dim).repeat(1, num_trials, 1)
-        x = torch.randn_like(iid_theta) + iid_theta
-        x_o = zeros(1, num_trials, num_dim)
-
-        # embedding
-        latent_dim = 10
-        single_trial_net = FCEmbedding(
-            input_dim=num_dim,
-            num_hiddens=40,
-            num_layers=2,
-            output_dim=latent_dim,
-        )
-        embedding_net = PermutationInvariantEmbedding(
-            single_trial_net,
-            trial_net_output_dim=latent_dim,
-            # NOTE: a post-embedding net is not really needed.
-            num_layers=1,
-            num_hiddens=10,
-            output_dim=10,
-        )
-
-        density_estimator = posterior_nn("maf", embedding_net=embedding_net)
-
-        inference = SNPE(prior, density_estimator=density_estimator)
-    else:  # likelihood-based methods: single-trial training without embeddings.
-        theta = prior.sample((num_simulations,))
-        x = torch.randn_like(theta) + theta
-        x_o = zeros(1, num_dim)
-
-        if method == "SNLE":
-            inference = SNLE(prior, density_estimator=likelihood_nn("maf"))
-        elif method == "SNRE":
-            inference = SNRE(prior, classifier=classifier_nn("resnet"))
-        else:
-            raise NameError
+    # simulate iid x.
+    def simulator(theta, num_trials=num_trials):
+        iid_theta = theta.reshape(theta.shape[0], 1, num_dim).repeat(1, num_trials, 1)
+        return torch.randn_like(iid_theta) + iid_theta
+
+    theta, x = simulate_for_sbi(simulator, prior, num_simulations=num_thetas)
+
+    # embedding
+    latent_dim = 10
+    single_trial_net = FCEmbedding(
+        input_dim=num_dim,
+        num_hiddens=40,
+        num_layers=2,
+        output_dim=latent_dim,
+    )
+    embedding_net = PermutationInvariantEmbedding(
+        single_trial_net,
+        trial_net_output_dim=latent_dim,
+        # NOTE: a post-embedding net is not really needed.
+        num_layers=1,
+        num_hiddens=10,
+        output_dim=10,
+    )
+
+    density_estimator = posterior_nn("maf", embedding_net=embedding_net)
+
+    inference = SNPE(prior, density_estimator=density_estimator)
 
     # get reference samples from true posterior
     num_samples = 1000
+    # define x_o without batch dim to test handling below.
+    x_o = zeros(num_trials, num_dim)
     reference_samples = true_posterior_linear_gaussian_mvn_prior(
         x_o.squeeze(),
         likelihood_shift=torch.zeros(num_dim),
@@ -172,26 +160,16 @@ def test_iid_inference(num_trials, num_dim, method):
     posterior = inference.build_posterior().set_default_x(x_o)
     samples = posterior.sample((num_samples,))
 
-    if method == "SNPE":
-        check_c2st(samples, reference_samples, alg=method)
-        # permute and test again
-        num_repeats = 2
-        for _ in range(num_repeats):
-            trial_permuted_x_o = x_o[:, torch.randperm(x_o.shape[1]), :]
-            samples = posterior.sample((num_samples,), x=trial_permuted_x_o)
-            check_c2st(samples, reference_samples, alg=method + " permuted")
-    else:
-        check_c2st(samples, reference_samples, alg=method)
+    check_c2st(samples, reference_samples, alg=method)
+    # permute and test again
+    num_repeats = 2
+    for _ in range(num_repeats):
+        trial_permuted_x_o = x_o[torch.randperm(x_o.shape[0]), :]
+        samples = posterior.sample((num_samples,), x=trial_permuted_x_o)
+        check_c2st(samples, reference_samples, alg=method + " permuted")
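
Note: the permutation check works because iid trials are exchangeable; any permutation-invariant pooling (sum pooling in this toy sketch) must yield the same summary, hence the same posterior:

```python
import torch

x_o = torch.randn(10, 2)
perm = torch.randperm(x_o.shape[0])
# Sum pooling is permutation invariant, so the pooled summary (and therefore
# a posterior conditioned on it) cannot depend on trial order.
assert torch.allclose(x_o.sum(dim=0), x_o[perm].sum(dim=0))
```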


-@pytest.mark.parametrize(
-    "input_shape",
-    [
-        (32,),
-        (32, 32),
-        (32, 64),
-    ],
-)
+@pytest.mark.parametrize("input_shape", [(32,), (32, 32), (32, 64)])
@pytest.mark.parametrize("num_channels", (1, 2, 3))
def test_1d_and_2d_cnn_embedding_net(input_shape, num_channels):
import torch
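
Note: the reference samples in the test come from the analytic conjugate-Gaussian posterior. A sketch of the standard result behind `true_posterior_linear_gaussian_mvn_prior` (function name and signature here are illustrative, not sbi's API):

```python
import torch

def analytic_gaussian_posterior(x_o, prior_mean, prior_cov, likelihood_cov):
    """Posterior over theta given iid trials x_i ~ N(theta, likelihood_cov)."""
    num_trials = x_o.shape[0]
    prior_prec = torch.inverse(prior_cov)
    lik_prec = torch.inverse(likelihood_cov)
    # Precisions add across trials; the data enter through the sum of trials.
    post_cov = torch.inverse(prior_prec + num_trials * lik_prec)
    post_mean = post_cov @ (prior_prec @ prior_mean + lik_prec @ x_o.sum(0))
    return post_mean, post_cov
```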
15 changes: 10 additions & 5 deletions tests/user_input_checks_test.py
@@ -192,14 +192,19 @@ def test_process_prior(prior):


 @pytest.mark.parametrize(
-    "x, x_shape",
+    "x, x_shape, allow_iid",
     (
-        (ones(3), torch.Size([1, 3])),
-        (ones(1, 3), torch.Size([1, 3])),
+        (ones(3), torch.Size([1, 3]), False),
+        (ones(1, 3), torch.Size([1, 3]), False),
+        (ones(10, 3), torch.Size([1, 10, 3]), False),  # 2D data / iid SNPE
+        pytest.param(
+            ones(10, 3), None, False, marks=pytest.mark.xfail
+        ),  # 2D data / iid SNPE without x_shape
+        (ones(10, 10), torch.Size([1, 10]), True),  # iid likelihood based
     ),
 )
-def test_process_x(x, x_shape):
-    process_x(x, x_shape)
+def test_process_x(x, x_shape, allow_iid):
+    process_x(x, x_shape, allow_iid_x=allow_iid)


@pytest.mark.parametrize(
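
Note: a usage sketch of the behavior these new cases pin down, assuming `process_x`'s signature as used in the test:

```python
import torch
from torch import ones
from sbi.utils.user_input_checks import process_x

# x_shape lets process_x restore a missing batch dim for 2D iid data (SNPE).
x = process_x(ones(10, 3), torch.Size([1, 10, 3]))
assert x.shape == torch.Size([1, 10, 3])

# Without x_shape the same input is ambiguous and should raise (the xfail case).

# Likelihood-based methods accept iid batches when explicitly allowed.
x = process_x(ones(10, 10), torch.Size([1, 10]), allow_iid_x=True)
assert x.shape == torch.Size([10, 10])
```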