refactor: remove embedding net device handling. unify get_numel.
janfb committed Jun 25, 2024
1 parent e798a03 commit 0b2b111
Showing 7 changed files with 79 additions and 249 deletions.
39 changes: 19 additions & 20 deletions sbi/neural_nets/classifier.py
@@ -7,8 +7,8 @@
from pyknos.nflows.nn import nets
from torch import Tensor, nn, relu

from sbi.utils.nn_utils import get_numel
from sbi.utils.sbiutils import standardizing_net, z_score_parser
from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device


class StandardizeInputs(nn.Module):
@@ -114,13 +114,13 @@ def build_linear_classifier(
Returns:
Neural network.
"""
check_data_device(batch_x, batch_y)
check_embedding_net_device(embedding_net=embedding_net_x, datum=batch_y)
check_embedding_net_device(embedding_net=embedding_net_y, datum=batch_y)

# Infer the output dimensionalities of the embedding_net by making a forward pass.
x_numel = embedding_net_x(batch_x[:1]).numel()
y_numel = embedding_net_y(batch_y[:1]).numel()
x_numel, y_numel = get_numel(
batch_x,
batch_y,
embedding_net_x=embedding_net_x,
embedding_net_y=embedding_net_y,
)

neural_net = nn.Linear(x_numel + y_numel, 1)

@@ -164,13 +164,13 @@ def build_mlp_classifier(
Returns:
Neural network.
"""
check_data_device(batch_x, batch_y)
check_embedding_net_device(embedding_net=embedding_net_x, datum=batch_y)
check_embedding_net_device(embedding_net=embedding_net_y, datum=batch_y)

# Infer the output dimensionalities of the embedding_net by making a forward pass.
x_numel = embedding_net_x(batch_x[:1]).numel()
y_numel = embedding_net_y(batch_y[:1]).numel()
x_numel, y_numel = get_numel(
batch_x,
batch_y,
embedding_net_x=embedding_net_x,
embedding_net_y=embedding_net_y,
)

neural_net = nn.Sequential(
nn.Linear(x_numel + y_numel, hidden_features),
@@ -225,13 +225,12 @@ def build_resnet_classifier(
Returns:
Neural network.
"""
check_data_device(batch_x, batch_y)
check_embedding_net_device(embedding_net=embedding_net_x, datum=batch_y)
check_embedding_net_device(embedding_net=embedding_net_y, datum=batch_y)

# Infer the output dimensionalities of the embedding_net by making a forward pass.
x_numel = embedding_net_x(batch_x[:1]).numel()
y_numel = embedding_net_y(batch_y[:1]).numel()
x_numel, y_numel = get_numel(
batch_x,
batch_y,
embedding_net_x=embedding_net_x,
embedding_net_y=embedding_net_y,
)

neural_net = nets.ResidualNet(
in_features=x_numel + y_numel,
46 changes: 11 additions & 35 deletions sbi/neural_nets/flow.py
@@ -2,8 +2,7 @@
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from functools import partial
from typing import List, Optional, Sequence, Tuple, Union
from warnings import warn
from typing import List, Optional, Sequence, Union

import torch
import zuko
@@ -16,45 +15,18 @@
from torch import Tensor, nn, relu, tanh, tensor, uint8

from sbi.neural_nets.density_estimators import NFlowsFlow, ZukoFlow
from sbi.utils.nn_utils import get_numel
from sbi.utils.sbiutils import (
standardizing_net,
standardizing_transform,
standardizing_transform_zuko,
z_score_parser,
)
from sbi.utils.torchutils import create_alternating_binary_mask
from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device

nflow_specific_kwargs = ["num_bins", "num_components"]


def get_numel(batch_x: Tensor, batch_y: Tensor, embedding_net) -> Tuple[int, int]:
"""
Get the number of elements in the input and output space.
Args:
batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring.
batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring.
embedding_net: Optional embedding network for y.
Returns:
Tuple of the number of elements in the input and output space.
"""
x_numel = batch_x[0].numel()
# Infer the output dimensionality of the embedding_net by making a forward pass.
check_data_device(batch_x, batch_y)
check_embedding_net_device(embedding_net=embedding_net, datum=batch_y)
y_numel = embedding_net(batch_y[:1]).numel()
if x_numel == 1:
warn(
"In one-dimensional output space, this flow is limited to Gaussians",
stacklevel=2,
)

return x_numel, y_numel


def build_made(
batch_x: Tensor,
batch_y: Tensor,
@@ -87,7 +59,7 @@ def build_made(
Returns:
Neural network.
"""
x_numel, y_numel = get_numel(batch_x, batch_y, embedding_net)
x_numel, y_numel = get_numel(batch_x, batch_y, embedding_net_y=embedding_net)

transform = transforms.IdentityTransform()

@@ -162,7 +134,9 @@ def build_maf(
Returns:
Neural network.
"""
x_numel, y_numel = get_numel(batch_x, batch_y, embedding_net)
x_numel, y_numel = get_numel(
batch_x, batch_y, embedding_net_y=embedding_net, warn_on_1d=True
)

transform_list = []
for _ in range(num_transforms):
@@ -262,7 +236,9 @@ def build_maf_rqs(
Returns:
Neural network.
"""
x_numel, y_numel = get_numel(batch_x, batch_y, embedding_net)
x_numel, y_numel = get_numel(
batch_x, batch_y, embedding_net_y=embedding_net, warn_on_1d=True
)

transform_list = []
for _ in range(num_transforms):
@@ -357,7 +333,7 @@ def build_nsf(
Returns:
Neural network.
"""
x_numel, y_numel = get_numel(batch_x, batch_y, embedding_net)
x_numel, y_numel = get_numel(batch_x, batch_y, embedding_net_y=embedding_net)

# Define mask function to alternate between predicted x-dimensions.
def mask_in_layer(i):
@@ -1046,7 +1022,7 @@ def build_zuko_flow(
ZukoFlow: The constructed Zuko normalizing flow model.
"""

x_numel, y_numel = get_numel(batch_x, batch_y, embedding_net)
x_numel, y_numel = get_numel(batch_x, batch_y, embedding_net_y=embedding_net)

# keep only zuko kwargs
kwargs = {k: v for k, v in kwargs.items() if k not in nflow_specific_kwargs}
9 changes: 2 additions & 7 deletions sbi/neural_nets/mdn.py
@@ -8,12 +8,12 @@
from torch import Tensor, nn

from sbi.neural_nets.density_estimators import NFlowsFlow
from sbi.utils.nn_utils import get_numel
from sbi.utils.sbiutils import (
standardizing_net,
standardizing_transform,
z_score_parser,
)
from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device


def build_mdn(
@@ -48,12 +48,7 @@ def build_mdn(
Returns:
Neural network.
"""
x_numel = batch_x[0].numel()
# Infer the output dimensionality of the embedding_net by making a forward pass.
check_data_device(batch_x, batch_y)
check_embedding_net_device(embedding_net=embedding_net, datum=batch_y)
embedding_net.eval()
y_numel = embedding_net(batch_y[:1]).numel()
x_numel, y_numel = get_numel(batch_x, batch_y, embedding_net_y=embedding_net)

transform = transforms.IdentityTransform()

1 change: 0 additions & 1 deletion sbi/utils/__init__.py
@@ -1,6 +1,5 @@
# flake8: noqa
from sbi.utils.analysis_utils import get_1d_marginal_peaks_from_kde
from sbi.utils.conditional_density_utils import extract_and_transform_mog
from sbi.utils.io import get_data_root, get_log_root, get_project_root
from sbi.utils.kde import KDEWrapper, get_kde
from sbi.utils.potentialutils import pyro_potential_wrapper, transformed_potential
47 changes: 47 additions & 0 deletions sbi/utils/nn_utils.py
@@ -0,0 +1,47 @@
# import all needed modules
from typing import Optional, Tuple
from warnings import warn

from torch import Tensor, nn

from sbi.utils.user_input_checks import check_data_device


def get_numel(
batch_x: Tensor,
batch_y: Tensor,
embedding_net_x: Optional[nn.Module] = None,
embedding_net_y: Optional[nn.Module] = None,
warn_on_1d: bool = False,
) -> Tuple[int, int]:
"""
Get the number of elements in the input and output space.
Args:
batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring.
batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring.
embedding_net_x: Optional embedding network for x.
embedding_net_y: Optional embedding network for y.
warn_on_1d: Whether to warn if the output space is one-dimensional.
Returns:
Tuple of the number of elements in the input and output space.
"""
if embedding_net_x is None:
embedding_net_x = nn.Identity()
if embedding_net_y is None:
embedding_net_y = nn.Identity()

# Infer the output dimensionality of the embedding_net by making a forward pass.
check_data_device(batch_x, batch_y)
# Make sure the embedding_net is on the same device as the data.
x_numel = embedding_net_x.to(batch_x.device)(batch_x[:1]).numel()
y_numel = embedding_net_y.to(batch_y.device)(batch_y[:1]).numel()
if x_numel == 1 and warn_on_1d:
warn(
"In one-dimensional output space, this flow is limited to Gaussians",
stacklevel=2,
)

return x_numel, y_numel
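
For reference, a minimal usage sketch of the new helper (not part of this commit; the batch shapes and the toy embedding net below are made up): builders now delegate the data-device check and the embedding forward pass to get_numel instead of calling check_embedding_net_device themselves.

import torch
from torch import nn

from sbi.utils.nn_utils import get_numel

batch_x = torch.zeros(10, 3)  # e.g. 10 parameter vectors with 3 entries each
batch_y = torch.zeros(10, 784)  # e.g. 10 flattened 28x28 observations
embedding_net = nn.Sequential(nn.Linear(784, 16), nn.ReLU())

# Omitted embedding nets default to nn.Identity; the device check and the
# forward pass through the embedding net happen inside get_numel.
x_numel, y_numel = get_numel(batch_x, batch_y, embedding_net_y=embedding_net)
assert (x_numel, y_numel) == (3, 16)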
30 changes: 0 additions & 30 deletions sbi/utils/user_input_checks.py
@@ -426,36 +426,6 @@ def check_prior_support(prior):
) from err


def check_embedding_net_device(embedding_net: nn.Module, datum: torch.Tensor) -> None:
"""Checks if the device for the `embedding_net`'s weights is the same as the device
for the fed `datum`. In case of discrepancy, warn the user and move the
embedding_net` to the `datum`'s device.
Args:
embedding_net: torch `Module` embedding data
datum torch `Tensor` from the training device
"""
datum_device = datum.device
embedding_net_devices = [p.device for p in embedding_net.parameters()]
if len(embedding_net_devices) > 0:
embedding_net_device = embedding_net_devices[0]
if embedding_net_device != datum_device:
warnings.warn(
"Mismatch between the device of the data fed "
"to the embedding_net and the device of the "
"embedding_net's weights. Fed data has device "
f"'{datum_device}' vs embedding_net weights have "
f"device '{embedding_net_device}'. "
"Automatically switching the embedding_net's device to "
f"'{datum_device}', which could otherwise be done manually "
f"""using the line `embedding_net.to('{datum_device}')`.""",
stacklevel=2,
)
embedding_net.to(datum_device)
else:
pass


def check_data_device(datum_1: torch.Tensor, datum_2: torch.Tensor) -> None:
"""Checks if two tensors have the seme device. Fails if there is a device
discrepancy