Skip to content

Commit

Permalink
add check for embedding net device
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Jun 26, 2024
1 parent 0b2b111 commit 06f83d5
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
9 changes: 5 additions & 4 deletions sbi/neural_nets/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from sbi.neural_nets.mdn import build_mdn
from sbi.neural_nets.mnle import build_mnle
from sbi.utils.nn_utils import check_net_device

model_builders = {
"mdn": build_mdn,
Expand Down Expand Up @@ -98,8 +99,8 @@ def classifier_nn(
z_score_theta,
z_score_x,
hidden_features,
embedding_net_theta,
embedding_net_x,
check_net_device(embedding_net_theta, "cpu"),
check_net_device(embedding_net_x, "cpu"),
),
),
**kwargs,
Expand Down Expand Up @@ -180,7 +181,7 @@ def likelihood_nn(
hidden_features,
num_transforms,
num_bins,
embedding_net,
check_net_device(embedding_net, "cpu"),
num_components,
),
),
Expand Down Expand Up @@ -256,7 +257,7 @@ def posterior_nn(
hidden_features,
num_transforms,
num_bins,
embedding_net,
check_net_device(embedding_net, "cpu"),
num_components,
),
),
Expand Down
28 changes: 26 additions & 2 deletions sbi/utils/nn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
def get_numel(
batch_x: Tensor,
batch_y: Tensor,
embedding_net_x: Optional[nn.Module] | None = None,
embedding_net_y: Optional[nn.Module] | None = None,
embedding_net_x: Optional[nn.Module] = None,
embedding_net_y: Optional[nn.Module] = None,
warn_on_1d: bool = False,
) -> Tuple[int, int]:
"""
Expand Down Expand Up @@ -45,3 +45,27 @@ def get_numel(
)

return x_numel, y_numel


def check_net_device(net: nn.Module, device: str) -> nn.Module:
"""
Check whether a net is on the desired device and move it there if not.
Args:
net: neural network.
device: desired device.
Returns:
Neural network on the desired device.
"""

if isinstance(net, nn.Identity):
return net
if str(next(net.parameters()).device) != device:
warn(

Check warning on line 65 in sbi/utils/nn_utils.py

View check run for this annotation

Codecov / codecov/patch

sbi/utils/nn_utils.py#L65

Added line #L65 was not covered by tests
f"Network is not on the correct device. Moving it to {device}.",
stacklevel=2,
)
return net.to(device)

Check warning on line 69 in sbi/utils/nn_utils.py

View check run for this annotation

Codecov / codecov/patch

sbi/utils/nn_utils.py#L69

Added line #L69 was not covered by tests
else:
return net

0 comments on commit 06f83d5

Please sign in to comment.