Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add MAF with RQS as density estimator #819

Merged
merged 6 commits into from
Mar 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions sbi/neural_nets/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pyknos.nflows import distributions as distributions_
from pyknos.nflows import flows, transforms
from pyknos.nflows.nn import nets
from pyknos.nflows.transforms.splines import rational_quadratic
from torch import Tensor, nn, relu, tanh, tensor, uint8

from sbi.utils.sbiutils import (
Expand Down Expand Up @@ -179,6 +180,116 @@ def build_maf(
return neural_net


def build_maf_rqs(
    batch_x: Tensor,
    batch_y: Tensor,
    z_score_x: Optional[str] = "independent",
    z_score_y: Optional[str] = "independent",
    hidden_features: int = 50,
    num_transforms: int = 5,
    embedding_net: nn.Module = nn.Identity(),
    num_blocks: int = 2,
    num_bins: int = 10,
    tails: Optional[str] = "linear",
    tail_bound: float = 3.0,
    dropout_probability: float = 0.0,
    use_batch_norm: bool = False,
    min_bin_width: float = rational_quadratic.DEFAULT_MIN_BIN_WIDTH,
    min_bin_height: float = rational_quadratic.DEFAULT_MIN_BIN_HEIGHT,
    min_derivative: float = rational_quadratic.DEFAULT_MIN_DERIVATIVE,
    **kwargs,
) -> nn.Module:
    """Build a conditional MAF p(x|y) whose layers are rational-quadratic splines.

    The flow stacks `num_transforms` pairs of (masked autoregressive RQS
    transform, random permutation) on top of a standard-normal base
    distribution, optionally preceded by a z-scoring transform for x. The
    context y is passed through `embedding_net`, optionally preceded by a
    z-scoring layer.

    Args:
        batch_x: Batch of xs, used to infer dimensionality and (optional)
            z-scoring.
        batch_y: Batch of ys, used to infer dimensionality and (optional)
            z-scoring.
        z_score_x: Whether to z-score xs passing into the network, one of:
            - `none`, or None: do not z-score.
            - `independent`: z-score each dimension independently.
            - `structured`: treat dimensions as related, i.e. compute mean and
              std over the entire batch instead of per-dimension. Should be
              used when each sample is, for example, a time series or an image.
        z_score_y: Whether to z-score ys passing into the network, same options
            as `z_score_x`.
        hidden_features: Number of hidden features.
        num_transforms: Number of (spline transform, permutation) pairs.
        embedding_net: Optional embedding network for y.
        num_blocks: Number of blocks used for the residual net for context
            embedding.
        num_bins: Number of bins of the RQS.
        tails: Whether to use constrained or unconstrained RQS, one of:
            - None: constrained RQS.
            - 'linear': unconstrained RQS (the RQS transformation is only
              applied on the domain [-B, B]; outside of it, `linear` tails
              return the identity transformation).
        tail_bound: The bound B of the domain [-B, B] on which the RQS
            transformation is applied.
        dropout_probability: Dropout probability for regularization in the
            residual net.
        use_batch_norm: Whether to use batch norm in the residual net.
        min_bin_width: Minimum bin width.
        min_bin_height: Minimum bin height.
        min_derivative: Minimum derivative at knot values of bins.
        kwargs: Additional arguments that are passed by the build function but
            are not relevant for this flow; they are ignored.

    Returns:
        Neural network.
    """
    x_numel = batch_x[0].numel()

    # The context dimensionality is inferred from a forward pass of a single
    # datum through the (not yet z-score-wrapped) embedding net, after making
    # sure data and network live on compatible devices.
    check_data_device(batch_x, batch_y)
    check_embedding_net_device(embedding_net=embedding_net, datum=batch_y)
    y_numel = embedding_net(batch_y[:1]).numel()

    if x_numel == 1:
        # NOTE(review): this message looks carried over from the affine MAF
        # builder; with spline transforms the 1-D case may not actually be
        # restricted to Gaussians -- confirm before relying on it.
        warn("In one-dimensional output space, this flow is limited to Gaussians")

    # Settings shared by every autoregressive spline layer.
    spline_kwargs = dict(
        features=x_numel,
        hidden_features=hidden_features,
        context_features=y_numel,
        num_bins=num_bins,
        tails=tails,
        tail_bound=tail_bound,
        num_blocks=num_blocks,
        use_residual_blocks=False,
        random_mask=False,
        activation=tanh,
        dropout_probability=dropout_probability,
        use_batch_norm=use_batch_norm,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )

    flow_layers = []
    for _ in range(num_transforms):
        # Spline layer first, then the permutation, so that layer construction
        # (and any random initialisation it performs) happens in a fixed order.
        flow_layers.append(
            transforms.MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
                **spline_kwargs
            )
        )
        flow_layers.append(transforms.RandomPermutation(features=x_numel))

    standardize_x, x_is_structured = z_score_parser(z_score_x)
    if standardize_x:
        # Standardisation of x must be the very first transform of the flow.
        flow_layers.insert(0, standardizing_transform(batch_x, x_is_structured))

    standardize_y, y_is_structured = z_score_parser(z_score_y)
    if standardize_y:
        # Standardise the context before it reaches the embedding network.
        embedding_net = nn.Sequential(
            standardizing_net(batch_y, y_is_structured), embedding_net
        )

    composite = transforms.CompositeTransform(flow_layers)
    base_distribution = distributions_.StandardNormal((x_numel,))

    return flows.Flow(composite, base_distribution, embedding_net)


def build_nsf(
batch_x: Tensor,
batch_y: Tensor,
Expand Down
14 changes: 9 additions & 5 deletions sbi/utils/get_nn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
build_mlp_classifier,
build_resnet_classifier,
)
from sbi.neural_nets.flow import build_made, build_maf, build_nsf
from sbi.neural_nets.flow import build_made, build_maf, build_maf_rqs, build_nsf
from sbi.neural_nets.mdn import build_mdn
from sbi.neural_nets.mnle import build_mnle

Expand Down Expand Up @@ -109,7 +109,7 @@ def likelihood_nn(
Args:
model: The type of density estimator that will be created. One of [`mdn`,
`made`, `maf`, `nsf`].
`made`, `maf`, `maf_rqs`, `nsf`].
z_score_theta: Whether to z-score parameters $\theta$ before passing them into
the network, can take one of the following:
- `none`, or None: do not z-score.
Expand Down Expand Up @@ -158,10 +158,12 @@ def likelihood_nn(
def build_fn(batch_theta, batch_x):
if model == "mdn":
return build_mdn(batch_x=batch_x, batch_y=batch_theta, **kwargs)
if model == "made":
elif model == "made":
return build_made(batch_x=batch_x, batch_y=batch_theta, **kwargs)
if model == "maf":
elif model == "maf":
return build_maf(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "maf_rqs":
return build_maf_rqs(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "nsf":
return build_nsf(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "mnle":
Expand Down Expand Up @@ -191,7 +193,7 @@ def posterior_nn(
Args:
model: The type of density estimator that will be created. One of [`mdn`,
`made`, `maf`, `nsf`].
`made`, `maf`, `maf_rqs`, `nsf`].
z_score_theta: Whether to z-score parameters $\theta$ before passing them into
the network, can take one of the following:
- `none`, or None: do not z-score.
Expand Down Expand Up @@ -261,6 +263,8 @@ def build_fn(batch_theta, batch_x):
return build_made(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "maf":
return build_maf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "maf_rqs":
return build_maf_rqs(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "nsf":
return build_nsf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
else:
Expand Down
2 changes: 2 additions & 0 deletions tests/inference_on_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@
(SNPE_C, "mdn", "rejection"),
(SNPE_C, "maf", "slice"),
(SNPE_C, "maf", "direct"),
(SNPE_C, "maf_rqs", "direct"),
(SNLE, "maf", "slice"),
(SNLE, "nsf", "slice_np"),
(SNLE, "nsf", "rejection"),
(SNLE, "maf", "importance"),
(SNLE, "maf_rqs", "slice"),
(SNRE_A, "mlp", "slice_np_vectorized"),
(SNRE_B, "resnet", "nuts"),
(SNRE_B, "resnet", "rejection"),
Expand Down
2 changes: 1 addition & 1 deletion tests/linearGaussian_snpe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_c2st_snpe_on_linearGaussian(


@pytest.mark.slow
@pytest.mark.parametrize("density_estimtor", ["mdn", "maf", "nsf"])
@pytest.mark.parametrize("density_estimtor", ["mdn", "maf", "maf_rqs", "nsf"])
def test_density_estimators_on_linearGaussian(density_estimtor):
"""Test SNPE with different density estimators on linear Gaussian example."""

Expand Down