Skip to content

Commit

Permalink
include test for embedding net device.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Jul 2, 2024
1 parent baae93b commit 5afba70
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 21 deletions.
11 changes: 7 additions & 4 deletions sbi/neural_nets/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@
"zuko_bpf": build_zuko_bpf,
}

embedding_net_warn_msg = """The passed embedding net will be moved to cpu for
constructing the net building function."""


def classifier_nn(
model: str,
Expand Down Expand Up @@ -99,8 +102,8 @@ def classifier_nn(
z_score_theta,
z_score_x,
hidden_features,
check_net_device(embedding_net_theta, "cpu"),
check_net_device(embedding_net_x, "cpu"),
check_net_device(embedding_net_theta, "cpu", embedding_net_warn_msg),
check_net_device(embedding_net_x, "cpu", embedding_net_warn_msg),
),
),
**kwargs,
Expand Down Expand Up @@ -181,7 +184,7 @@ def likelihood_nn(
hidden_features,
num_transforms,
num_bins,
check_net_device(embedding_net, "cpu"),
check_net_device(embedding_net, "cpu", embedding_net_warn_msg),
num_components,
),
),
Expand Down Expand Up @@ -257,7 +260,7 @@ def posterior_nn(
hidden_features,
num_transforms,
num_bins,
check_net_device(embedding_net, "cpu"),
check_net_device(embedding_net, "cpu", embedding_net_warn_msg),
num_components,
),
),
Expand Down
58 changes: 41 additions & 17 deletions tests/inference_on_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,24 @@
)
from sbi.inference.posteriors.importance_posterior import ImportanceSamplingPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.neural_nets import classifier_nn, likelihood_nn, posterior_nn
from sbi.neural_nets.embedding_nets import FCEmbedding
from sbi.neural_nets.factory import (
classifier_nn,
embedding_net_warn_msg,
likelihood_nn,
posterior_nn,
)
from sbi.simulators import diagonal_linear_gaussian, linear_gaussian
from sbi.utils.torchutils import BoxUniform, gpu_available, process_device
from sbi.utils.user_input_checks import (
validate_theta_and_x,
)

# tests in this file are skipped if there is GPU device available
pytestmark = pytest.mark.skipif(
not gpu_available(), reason="No CUDA or MPS device available."
)


@pytest.mark.slow
@pytest.mark.gpu
Expand Down Expand Up @@ -239,19 +250,18 @@ def test_validate_theta_and_x_device(training_device: str, data_device: str) ->
)


@pytest.mark.gpu
@pytest.mark.parametrize(
"inference_method", [SNPE_A, SNPE_C, SNRE_A, SNRE_B, SNRE_C, SNLE]
)
@pytest.mark.parametrize("data_device", ("cpu", "gpu"))
@pytest.mark.parametrize("training_device", ("cpu", "gpu"))
@pytest.mark.parametrize("embedding_device", ("cpu", "gpu"))
def test_train_with_different_data_and_training_device(
inference_method, data_device: str, training_device: str
inference_method, data_device: str, training_device: str, embedding_device: str
) -> None:
assert gpu_available(), "this test requires that gpu is available."

data_device = process_device(data_device)
training_device = process_device(training_device)
embedding_device = process_device(embedding_device)

num_dim = 2
num_simulations = 32
Expand All @@ -260,19 +270,33 @@ def test_train_with_different_data_and_training_device(
)
simulator = diagonal_linear_gaussian

# moving embedding net to device to mimic user with large custom embedding.
embedding_net = FCEmbedding(input_dim=num_dim, output_dim=num_dim).to(
embedding_device
)

if inference_method in [SNRE_A, SNRE_B, SNRE_C]:
net_builder_fun = classifier_nn
kwargs = dict(model="mlp", embedding_net_x=embedding_net)
elif inference_method == SNLE:
net_builder_fun = likelihood_nn
kwargs = dict(model="mdn", embedding_net=embedding_net)
elif inference_method == SNPE_A:
net_builder_fun = posterior_nn
kwargs = dict(model="mdn_snpe_a", embedding_net=embedding_net)
else:
net_builder_fun = posterior_nn
kwargs = dict(model="mdn", embedding_net=embedding_net)

# warning must be issued when embedding not on cpu.
if embedding_device != "cpu":
with pytest.warns(UserWarning, match=embedding_net_warn_msg):
net_builder = net_builder_fun(**kwargs)
else:
net_builder = net_builder_fun(**kwargs)

inference = inference_method(
prior,
**(
dict(classifier="resnet")
if inference_method in [SNRE_A, SNRE_B, SNRE_C]
else dict(
density_estimator=(
"mdn_snpe_a" if inference_method == SNPE_A else "maf"
)
)
),
show_progress_bars=False,
device=training_device,
prior, net_builder, show_progress_bars=False, device=training_device
)

theta = prior.sample((num_simulations,))
Expand Down

0 comments on commit 5afba70

Please sign in to comment.