From 06f83d52f7504500537dd057ada3da774a12231f Mon Sep 17 00:00:00 2001 From: janfb Date: Wed, 26 Jun 2024 18:25:12 +0200 Subject: [PATCH] add check for embedding net device --- sbi/neural_nets/factory.py | 9 +++++---- sbi/utils/nn_utils.py | 28 ++++++++++++++++++++++++++-- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/sbi/neural_nets/factory.py b/sbi/neural_nets/factory.py index 5bddcc0f5..1b6d63396 100644 --- a/sbi/neural_nets/factory.py +++ b/sbi/neural_nets/factory.py @@ -28,6 +28,7 @@ ) from sbi.neural_nets.mdn import build_mdn from sbi.neural_nets.mnle import build_mnle +from sbi.utils.nn_utils import check_net_device model_builders = { "mdn": build_mdn, @@ -98,8 +99,8 @@ def classifier_nn( z_score_theta, z_score_x, hidden_features, - embedding_net_theta, - embedding_net_x, + check_net_device(embedding_net_theta, "cpu"), + check_net_device(embedding_net_x, "cpu"), ), ), **kwargs, @@ -180,7 +181,7 @@ def likelihood_nn( hidden_features, num_transforms, num_bins, - embedding_net, + check_net_device(embedding_net, "cpu"), num_components, ), ), @@ -256,7 +257,7 @@ def posterior_nn( hidden_features, num_transforms, num_bins, - embedding_net, + check_net_device(embedding_net, "cpu"), num_components, ), ), diff --git a/sbi/utils/nn_utils.py b/sbi/utils/nn_utils.py index cf7056fb9..f8ef910d8 100644 --- a/sbi/utils/nn_utils.py +++ b/sbi/utils/nn_utils.py @@ -10,8 +10,8 @@ def get_numel( batch_x: Tensor, batch_y: Tensor, - embedding_net_x: Optional[nn.Module] | None = None, - embedding_net_y: Optional[nn.Module] | None = None, + embedding_net_x: Optional[nn.Module] = None, + embedding_net_y: Optional[nn.Module] = None, warn_on_1d: bool = False, ) -> Tuple[int, int]: """ @@ -45,3 +45,27 @@ def get_numel( ) return x_numel, y_numel + + +def check_net_device(net: nn.Module, device: str) -> nn.Module: + """ + Check whether a net is on the desired device and move it there if not. + + Args: + net: neural network. + device: desired device. + + Returns: + Neural network on the desired device. + """ + + if isinstance(net, nn.Identity): + return net + if str(next(net.parameters()).device) != device: + warn( + f"Network is not on the correct device. Moving it to {device}.", + stacklevel=2, + ) + return net.to(device) + else: + return net