From 5afba70aed4064a18fdf2ad5b0017fa5c8a37d00 Mon Sep 17 00:00:00 2001 From: janfb Date: Tue, 2 Jul 2024 13:39:09 +0200 Subject: [PATCH] include test for embedding net device. --- sbi/neural_nets/factory.py | 11 +++--- tests/inference_on_device_test.py | 58 ++++++++++++++++++++++--------- 2 files changed, 48 insertions(+), 21 deletions(-) diff --git a/sbi/neural_nets/factory.py b/sbi/neural_nets/factory.py index 1b6d63396..e38af92ad 100644 --- a/sbi/neural_nets/factory.py +++ b/sbi/neural_nets/factory.py @@ -48,6 +48,9 @@ "zuko_bpf": build_zuko_bpf, } +embedding_net_warn_msg = """The passed embedding net will be moved to cpu for + constructing the net building function.""" + def classifier_nn( model: str, @@ -99,8 +102,8 @@ def classifier_nn( z_score_theta, z_score_x, hidden_features, - check_net_device(embedding_net_theta, "cpu"), - check_net_device(embedding_net_x, "cpu"), + check_net_device(embedding_net_theta, "cpu", embedding_net_warn_msg), + check_net_device(embedding_net_x, "cpu", embedding_net_warn_msg), ), ), **kwargs, @@ -181,7 +184,7 @@ def likelihood_nn( hidden_features, num_transforms, num_bins, - check_net_device(embedding_net, "cpu"), + check_net_device(embedding_net, "cpu", embedding_net_warn_msg), num_components, ), ), @@ -257,7 +260,7 @@ def posterior_nn( hidden_features, num_transforms, num_bins, - check_net_device(embedding_net, "cpu"), + check_net_device(embedding_net, "cpu", embedding_net_warn_msg), num_components, ), ), diff --git a/tests/inference_on_device_test.py b/tests/inference_on_device_test.py index 07cffa329..b22d75ddd 100644 --- a/tests/inference_on_device_test.py +++ b/tests/inference_on_device_test.py @@ -26,13 +26,24 @@ ) from sbi.inference.posteriors.importance_posterior import ImportanceSamplingPosterior from sbi.inference.potentials.base_potential import BasePotential -from sbi.neural_nets import classifier_nn, likelihood_nn, posterior_nn +from sbi.neural_nets.embedding_nets import FCEmbedding +from sbi.neural_nets.factory import ( + classifier_nn, + embedding_net_warn_msg, + likelihood_nn, + posterior_nn, +) from sbi.simulators import diagonal_linear_gaussian, linear_gaussian from sbi.utils.torchutils import BoxUniform, gpu_available, process_device from sbi.utils.user_input_checks import ( validate_theta_and_x, ) +# tests in this file are skipped if there is GPU device available +pytestmark = pytest.mark.skipif( + not gpu_available(), reason="No CUDA or MPS device available." +) + @pytest.mark.slow @pytest.mark.gpu @@ -239,19 +250,18 @@ def test_validate_theta_and_x_device(training_device: str, data_device: str) -> ) -@pytest.mark.gpu @pytest.mark.parametrize( "inference_method", [SNPE_A, SNPE_C, SNRE_A, SNRE_B, SNRE_C, SNLE] ) @pytest.mark.parametrize("data_device", ("cpu", "gpu")) @pytest.mark.parametrize("training_device", ("cpu", "gpu")) +@pytest.mark.parametrize("embedding_device", ("cpu", "gpu")) def test_train_with_different_data_and_training_device( - inference_method, data_device: str, training_device: str + inference_method, data_device: str, training_device: str, embedding_device: str ) -> None: - assert gpu_available(), "this test requires that gpu is available." - data_device = process_device(data_device) training_device = process_device(training_device) + embedding_device = process_device(embedding_device) num_dim = 2 num_simulations = 32 @@ -260,19 +270,33 @@ def test_train_with_different_data_and_training_device( ) simulator = diagonal_linear_gaussian + # moving embedding net to device to mimic user with large custom embedding. + embedding_net = FCEmbedding(input_dim=num_dim, output_dim=num_dim).to( + embedding_device + ) + + if inference_method in [SNRE_A, SNRE_B, SNRE_C]: + net_builder_fun = classifier_nn + kwargs = dict(model="mlp", embedding_net_x=embedding_net) + elif inference_method == SNLE: + net_builder_fun = likelihood_nn + kwargs = dict(model="mdn", embedding_net=embedding_net) + elif inference_method == SNPE_A: + net_builder_fun = posterior_nn + kwargs = dict(model="mdn_snpe_a", embedding_net=embedding_net) + else: + net_builder_fun = posterior_nn + kwargs = dict(model="mdn", embedding_net=embedding_net) + + # warning must be issued when embedding not on cpu. + if embedding_device != "cpu": + with pytest.warns(UserWarning, match=embedding_net_warn_msg): + net_builder = net_builder_fun(**kwargs) + else: + net_builder = net_builder_fun(**kwargs) + inference = inference_method( - prior, - **( - dict(classifier="resnet") - if inference_method in [SNRE_A, SNRE_B, SNRE_C] - else dict( - density_estimator=( - "mdn_snpe_a" if inference_method == SNPE_A else "maf" - ) - ) - ), - show_progress_bars=False, - device=training_device, + prior, net_builder, show_progress_bars=False, device=training_device ) theta = prior.sample((num_simulations,))