Skip to content

Commit

Permalink
add MAF with RQS as density estimator (#819)
Browse files Browse the repository at this point in the history
* For SNPE, MAF with RQS is available now.

* add density estimators to tests.

* For SNLE, MAF with RQS is available now; MAF with RQS added to tests; Black formatting.

* Adjustment of the MAF RQS default values for consistency to the NSF.

* Black formatting.

---------

Co-authored-by: Imahn <imahn.shekhzadeh@desy.de>
Co-authored-by: janfb <jan.boelts@tum.de>
  • Loading branch information
3 people authored Mar 15, 2023
1 parent 31c6076 commit 7a26b70
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 6 deletions.
111 changes: 111 additions & 0 deletions sbi/neural_nets/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pyknos.nflows import distributions as distributions_
from pyknos.nflows import flows, transforms
from pyknos.nflows.nn import nets
from pyknos.nflows.transforms.splines import rational_quadratic
from torch import Tensor, nn, relu, tanh, tensor, uint8

from sbi.utils.sbiutils import (
Expand Down Expand Up @@ -179,6 +180,116 @@ def build_maf(
return neural_net


def build_maf_rqs(
batch_x: Tensor,
batch_y: Tensor,
z_score_x: Optional[str] = "independent",
z_score_y: Optional[str] = "independent",
hidden_features: int = 50,
num_transforms: int = 5,
embedding_net: nn.Module = nn.Identity(),
num_blocks: int = 2,
num_bins: int = 10,
tails: Optional[str] = "linear",
tail_bound: float = 3.0,
dropout_probability: float = 0.0,
use_batch_norm: bool = False,
min_bin_width: float = rational_quadratic.DEFAULT_MIN_BIN_WIDTH,
min_bin_height: float = rational_quadratic.DEFAULT_MIN_BIN_HEIGHT,
min_derivative: float = rational_quadratic.DEFAULT_MIN_DERIVATIVE,
**kwargs,
) -> nn.Module:
"""Builds MAF p(x|y), where the diffeomorphisms are rational-quadratic
splines (RQS).
Args:
batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring.
batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring.
z_score_x: Whether to z-score xs passing into the network, can be one of:
- `none`, or None: do not z-score.
- `independent`: z-score each dimension independently.
- `structured`: treat dimensions as related, therefore compute mean and std
over the entire batch, instead of per-dimension. Should be used when each
sample is, for example, a time series or an image.
z_score_y: Whether to z-score ys passing into the network, same options as
z_score_x.
hidden_features: Number of hidden features.
num_transforms: Number of transforms.
embedding_net: Optional embedding network for y.
num_blocks: number of blocks used for residual net for context embedding.
num_bins: Number of bins of the RQS.
tails: Whether to use constrained or unconstrained RQS, can be one of:
- None: constrained RQS.
- 'linear': unconstrained RQS (RQS transformation is only
applied on domain [-B, B], with `linear` tails, outside [-B, B],
identity transformation is returned).
tail_bound: RQS transformation is applied on domain [-B, B],
`tail_bound` is equal to B.
dropout_probability: dropout probability for regularization in residual net.
use_batch_norm: whether to use batch norm in residual net.
min_bin_width: Minimum bin width.
min_bin_height: Minimum bin height.
min_derivative: Minimum derivative at knot values of bins.
kwargs: Additional arguments that are passed by the build function but are not
relevant for maf and are therefore ignored.
Returns:
Neural network.
"""
x_numel = batch_x[0].numel()
# Infer the output dimensionality of the embedding_net by making a forward pass.
check_data_device(batch_x, batch_y)
check_embedding_net_device(embedding_net=embedding_net, datum=batch_y)
y_numel = embedding_net(batch_y[:1]).numel()

if x_numel == 1:
warn("In one-dimensional output space, this flow is limited to Gaussians")

transform_list = []
for _ in range(num_transforms):
block = [
transforms.MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
features=x_numel,
hidden_features=hidden_features,
context_features=y_numel,
num_bins=num_bins,
tails=tails,
tail_bound=tail_bound,
num_blocks=num_blocks,
use_residual_blocks=False,
random_mask=False,
activation=tanh,
dropout_probability=dropout_probability,
use_batch_norm=use_batch_norm,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
),
transforms.RandomPermutation(features=x_numel),
]
transform_list += block

z_score_x_bool, structured_x = z_score_parser(z_score_x)
if z_score_x_bool:
transform_list = [
standardizing_transform(batch_x, structured_x)
] + transform_list

z_score_y_bool, structured_y = z_score_parser(z_score_y)
if z_score_y_bool:
embedding_net = nn.Sequential(
standardizing_net(batch_y, structured_y), embedding_net
)

# Combine transforms.
transform = transforms.CompositeTransform(transform_list)

distribution = distributions_.StandardNormal((x_numel,))
neural_net = flows.Flow(transform, distribution, embedding_net)

return neural_net


def build_nsf(
batch_x: Tensor,
batch_y: Tensor,
Expand Down
14 changes: 9 additions & 5 deletions sbi/utils/get_nn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
build_mlp_classifier,
build_resnet_classifier,
)
from sbi.neural_nets.flow import build_made, build_maf, build_nsf
from sbi.neural_nets.flow import build_made, build_maf, build_maf_rqs, build_nsf
from sbi.neural_nets.mdn import build_mdn
from sbi.neural_nets.mnle import build_mnle

Expand Down Expand Up @@ -109,7 +109,7 @@ def likelihood_nn(
Args:
model: The type of density estimator that will be created. One of [`mdn`,
`made`, `maf`, `nsf`].
`made`, `maf`, `maf_rqs`, `nsf`].
z_score_theta: Whether to z-score parameters $\theta$ before passing them into
the network, can take one of the following:
- `none`, or None: do not z-score.
Expand Down Expand Up @@ -158,10 +158,12 @@ def likelihood_nn(
def build_fn(batch_theta, batch_x):
if model == "mdn":
return build_mdn(batch_x=batch_x, batch_y=batch_theta, **kwargs)
if model == "made":
elif model == "made":
return build_made(batch_x=batch_x, batch_y=batch_theta, **kwargs)
if model == "maf":
elif model == "maf":
return build_maf(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "maf_rqs":
return build_maf_rqs(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "nsf":
return build_nsf(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "mnle":
Expand Down Expand Up @@ -191,7 +193,7 @@ def posterior_nn(
Args:
model: The type of density estimator that will be created. One of [`mdn`,
`made`, `maf`, `nsf`].
`made`, `maf`, `maf_rqs`, `nsf`].
z_score_theta: Whether to z-score parameters $\theta$ before passing them into
the network, can take one of the following:
- `none`, or None: do not z-score.
Expand Down Expand Up @@ -261,6 +263,8 @@ def build_fn(batch_theta, batch_x):
return build_made(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "maf":
return build_maf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "maf_rqs":
return build_maf_rqs(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "nsf":
return build_nsf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
else:
Expand Down
2 changes: 2 additions & 0 deletions tests/inference_on_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@
(SNPE_C, "mdn", "rejection"),
(SNPE_C, "maf", "slice"),
(SNPE_C, "maf", "direct"),
(SNPE_C, "maf_rqs", "direct"),
(SNLE, "maf", "slice"),
(SNLE, "nsf", "slice_np"),
(SNLE, "nsf", "rejection"),
(SNLE, "maf", "importance"),
(SNLE, "maf_rqs", "slice"),
(SNRE_A, "mlp", "slice_np_vectorized"),
(SNRE_B, "resnet", "nuts"),
(SNRE_B, "resnet", "rejection"),
Expand Down
2 changes: 1 addition & 1 deletion tests/linearGaussian_snpe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_c2st_snpe_on_linearGaussian(


@pytest.mark.slow
@pytest.mark.parametrize("density_estimtor", ["mdn", "maf", "nsf"])
@pytest.mark.parametrize("density_estimtor", ["mdn", "maf", "maf_rqs", "nsf"])
def test_density_estimators_on_linearGaussian(density_estimtor):
"""Test SNPE with different density estimators on linear Gaussian example."""

Expand Down

0 comments on commit 7a26b70

Please sign in to comment.