From d8641061f274657b568f4f15b0da396a341e6be3 Mon Sep 17 00:00:00 2001
From: janfb
Date: Tue, 8 Nov 2022 18:40:40 +0100
Subject: [PATCH] improve handling of multi-d x in SNPE, adapt iid test.

---
 sbi/utils/user_input_checks.py  | 13 ++++-
 tests/embedding_net_test.py     | 98 +++++++++++++--------------------
 tests/user_input_checks_test.py | 15 +++--
 3 files changed, 58 insertions(+), 68 deletions(-)

diff --git a/sbi/utils/user_input_checks.py b/sbi/utils/user_input_checks.py
index 66a3685ce..b5b55d2ae 100644
--- a/sbi/utils/user_input_checks.py
+++ b/sbi/utils/user_input_checks.py
@@ -272,9 +272,12 @@ def check_for_possibly_batched_x_shape(x_shape):
         has an inferred batch shape larger than one. This is not supported in
         some sbi methods for reasons depending on the scenario:
 
-        - in case you want to evaluate or sample conditioned on several xs
-          e.g., (p(theta | [x1, x2, x3])), this is not supported yet except
-          when using likelihood based SNLE and SNRE.
+        - in case you want to evaluate or sample conditioned on several iid
+          xs, e.g., p(theta | [x1, x2, x3]), this is fully supported only
+          for the likelihood-based methods SNLE and SNRE. For SNPE it is
+          supported only for a fixed number of trials and an appropriate
+          embedding net, i.e., by treating the trials as an additional data
+          dimension. In that case, pass x_o with a leading batch dimension.
 
         - in case you trained with a single round to do amortized inference
           and now you want to evaluate or sample a given theta conditioned on
@@ -569,6 +572,10 @@ def process_x(
 
     x = atleast_2d(torch.as_tensor(x, dtype=float32))
 
+    # If x_shape is provided, we can fix a missing batch dim for >1D data.
+    if x_shape is not None and len(x_shape) > len(x.shape):
+        x = x.unsqueeze(0)
+
     input_x_shape = x.shape
     if not allow_iid_x:
         check_for_possibly_batched_x_shape(input_x_shape)
diff --git a/tests/embedding_net_test.py b/tests/embedding_net_test.py
index c7ee0bc92..895911675 100644
--- a/tests/embedding_net_test.py
+++ b/tests/embedding_net_test.py
@@ -5,7 +5,7 @@
 from torch import eye, ones, zeros
 
 from sbi import utils as utils
-from sbi.inference import SNLE, SNPE, SNRE
+from sbi.inference import SNLE, SNPE, SNRE, simulate_for_sbi
 from sbi.neural_nets.embedding_nets import (
     CNNEmbedding,
     FCEmbedding,
@@ -96,9 +96,9 @@ def test_iid_embedding_api(num_trials, num_dim):
 
 
 @pytest.mark.slow
-@pytest.mark.parametrize("num_trials", [10, 50])
+@pytest.mark.parametrize("num_trials", [1, 10, 50])
 @pytest.mark.parametrize("num_dim", [2])
-@pytest.mark.parametrize("method", ["SNPE", "SNLE", "SNRE"])
+@pytest.mark.parametrize("method", ("SNPE",))
 def test_iid_inference(num_trials, num_dim, method):
     """Test accuracy in Gaussian linear simulator with iid trials.
 
@@ -112,51 +112,39 @@ def test_iid_inference(num_trials, num_dim, method):
 
     # Scale number of training samples with num_trials.
     num_thetas = 1000 + 100 * num_trials
 
-    # Likelihood-based methods train on single trials.
-    num_simulations = num_thetas
-
-    if method == "SNPE":  # SNPE needs embedding and iid trials during training.
-        theta = prior.sample((num_thetas,))
-        # simulate iid x.
-        iid_theta = theta.reshape(num_thetas, 1, num_dim).repeat(1, num_trials, 1)
-        x = torch.randn_like(iid_theta) + iid_theta
-        x_o = zeros(1, num_trials, num_dim)
-
-        # embedding
-        latent_dim = 10
-        single_trial_net = FCEmbedding(
-            input_dim=num_dim,
-            num_hiddens=40,
-            num_layers=2,
-            output_dim=latent_dim,
-        )
-        embedding_net = PermutationInvariantEmbedding(
-            single_trial_net,
-            trial_net_output_dim=latent_dim,
-            # NOTE: post-embedding is not needed really.
-            num_layers=1,
-            num_hiddens=10,
-            output_dim=10,
-        )
-
-        density_estimator = posterior_nn("maf", embedding_net=embedding_net)
-        inference = SNPE(prior, density_estimator=density_estimator)
-    else:  # likelihood-based methods: single-trial training without embeddings.
+    # simulate iid x.
+    def simulator(theta, num_trials=num_trials):
+        iid_theta = theta.reshape(theta.shape[0], 1, num_dim).repeat(1, num_trials, 1)
+        return torch.randn_like(iid_theta) + iid_theta
+
+    theta, x = simulate_for_sbi(simulator, prior, num_simulations=num_thetas)
+
+    # embedding
+    latent_dim = 10
+    single_trial_net = FCEmbedding(
+        input_dim=num_dim,
+        num_hiddens=40,
+        num_layers=2,
+        output_dim=latent_dim,
+    )
+    embedding_net = PermutationInvariantEmbedding(
+        single_trial_net,
+        trial_net_output_dim=latent_dim,
+        # NOTE: the post-embedding layers are not strictly needed here.
+        num_layers=1,
+        num_hiddens=10,
+        output_dim=10,
+    )
 
-        theta = prior.sample((num_simulations,))
-        x = torch.randn_like(theta) + theta
-        x_o = zeros(1, num_dim)
+    density_estimator = posterior_nn("maf", embedding_net=embedding_net)
 
-        if method == "SNLE":
-            inference = SNLE(prior, density_estimator=likelihood_nn("maf"))
-        elif method == "SNRE":
-            inference = SNRE(prior, classifier=classifier_nn("resnet"))
-        else:
-            raise NameError
+    inference = SNPE(prior, density_estimator=density_estimator)
 
     # get reference samples from true posterior
     num_samples = 1000
+    # define x_o without a batch dim to test the new shape handling.
+    x_o = zeros(num_trials, num_dim)
     reference_samples = true_posterior_linear_gaussian_mvn_prior(
         x_o.squeeze(),
         likelihood_shift=torch.zeros(num_dim),
@@ -172,26 +160,16 @@ def test_iid_inference(num_trials, num_dim, method):
     posterior = inference.build_posterior().set_default_x(x_o)
     samples = posterior.sample((num_samples,))
 
-    if method == "SNPE":
-        check_c2st(samples, reference_samples, alg=method)
-        # permute and test again
-        num_repeats = 2
-        for _ in range(num_repeats):
-            trial_permutet_x_o = x_o[:, torch.randperm(x_o.shape[1]), :]
-            samples = posterior.sample((num_samples,), x=trial_permutet_x_o)
-            check_c2st(samples, reference_samples, alg=method + " permuted")
-    else:
-        check_c2st(samples, reference_samples, alg=method)
+    check_c2st(samples, reference_samples, alg=method)
+    # permute the trials and test again
+    num_repeats = 2
+    for _ in range(num_repeats):
+        trial_permuted_x_o = x_o[torch.randperm(x_o.shape[0]), :]
+        samples = posterior.sample((num_samples,), x=trial_permuted_x_o)
+        check_c2st(samples, reference_samples, alg=method + " permuted")
 
 
-@pytest.mark.parametrize(
-    "input_shape",
-    [
-        (32,),
-        (32, 32),
-        (32, 64),
-    ],
-)
+@pytest.mark.parametrize("input_shape", [(32,), (32, 32), (32, 64)])
 @pytest.mark.parametrize("num_channels", (1, 2, 3))
 def test_1d_and_2d_cnn_embedding_net(input_shape, num_channels):
     import torch
diff --git a/tests/user_input_checks_test.py b/tests/user_input_checks_test.py
index f70ddd81f..36177c0ff 100644
--- a/tests/user_input_checks_test.py
+++ b/tests/user_input_checks_test.py
@@ -192,14 +192,19 @@ def test_process_prior(prior):
 
 
 @pytest.mark.parametrize(
-    "x, x_shape",
+    "x, x_shape, allow_iid",
     (
-        (ones(3), torch.Size([1, 3])),
-        (ones(1, 3), torch.Size([1, 3])),
+        (ones(3), torch.Size([1, 3]), False),
+        (ones(1, 3), torch.Size([1, 3]), False),
+        (ones(10, 3), torch.Size([1, 10, 3]), False),  # 2D data / iid SNPE
+        pytest.param(
+            ones(10, 3), None, False, marks=pytest.mark.xfail
+        ),  # 2D data / iid SNPE without x_shape
+        (ones(10, 10), torch.Size([1, 10]), True),  # iid likelihood based
     ),
 )
-def test_process_x(x, x_shape):
-    process_x(x, x_shape)
+def test_process_x(x, x_shape, allow_iid):
+    process_x(x, x_shape, allow_iid_x=allow_iid)
 
 
 @pytest.mark.parametrize(
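
Usage note (not part of the patch): a minimal sketch of how the batch-dim
handling added to process_x behaves, mirroring the new test cases above. It
assumes a sbi version with this patch applied.

import torch
from sbi.utils.user_input_checks import process_x

# 10 iid trials of 3-dim data, passed without a leading batch dimension.
x = torch.ones(10, 3)

# With x_shape given, process_x can detect the missing batch dim and
# unsqueezes x to (1, 10, 3), as expected by SNPE with an iid embedding net.
x_processed = process_x(x, x_shape=torch.Size([1, 10, 3]))
assert x_processed.shape == torch.Size([1, 10, 3])

# Without x_shape, the batch dim of 2D data is ambiguous: the first dim is
# interpreted as a batch of size 10, so the shape check raises (this is the
# xfail case in the parametrized test above).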