diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py
index 4d2e5a31f..49da8af77 100644
--- a/sbi/neural_nets/flow.py
+++ b/sbi/neural_nets/flow.py
@@ -9,6 +9,7 @@
 from pyknos.nflows import distributions as distributions_
 from pyknos.nflows import flows, transforms
 from pyknos.nflows.nn import nets
+from pyknos.nflows.transforms.splines import rational_quadratic
 from torch import Tensor, nn, relu, tanh, tensor, uint8
 
 from sbi.utils.sbiutils import (
@@ -179,6 +180,116 @@ def build_maf(
     return neural_net
 
 
+def build_maf_rqs(
+    batch_x: Tensor,
+    batch_y: Tensor,
+    z_score_x: Optional[str] = "independent",
+    z_score_y: Optional[str] = "independent",
+    hidden_features: int = 50,
+    num_transforms: int = 5,
+    embedding_net: nn.Module = nn.Identity(),
+    num_blocks: int = 2,
+    num_bins: int = 10,
+    tails: Optional[str] = "linear",
+    tail_bound: float = 3.0,
+    dropout_probability: float = 0.0,
+    use_batch_norm: bool = False,
+    min_bin_width: float = rational_quadratic.DEFAULT_MIN_BIN_WIDTH,
+    min_bin_height: float = rational_quadratic.DEFAULT_MIN_BIN_HEIGHT,
+    min_derivative: float = rational_quadratic.DEFAULT_MIN_DERIVATIVE,
+    **kwargs,
+) -> nn.Module:
+    """Builds a MAF p(x|y) in which the diffeomorphisms are rational-quadratic
+    splines (RQS).
+
+    Args:
+        batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring.
+        batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring.
+        z_score_x: Whether to z-score xs passing into the network, can be one of:
+            - `none`, or None: do not z-score.
+            - `independent`: z-score each dimension independently.
+            - `structured`: treat dimensions as related, therefore compute mean and
+              std over the entire batch, instead of per-dimension. Should be used
+              when each sample is, for example, a time series or an image.
+        z_score_y: Whether to z-score ys passing into the network, same options as
+            z_score_x.
+        hidden_features: Number of hidden features.
+        num_transforms: Number of transforms.
+        embedding_net: Optional embedding network for y.
+        num_blocks: Number of blocks used in the residual net for context embedding.
+        num_bins: Number of bins of the RQS.
+        tails: Whether to use constrained or unconstrained RQS, can be one of:
+            - None: constrained RQS.
+            - `linear`: unconstrained RQS (the RQS transform is applied only on the
+              domain [-B, B]; outside of [-B, B], the identity transform is applied,
+              i.e. the tails are linear).
+        tail_bound: The RQS transform is applied on the domain [-B, B], where B
+            equals `tail_bound`.
+        dropout_probability: Dropout probability for regularization in the
+            residual net.
+        use_batch_norm: Whether to use batch norm in the residual net.
+        min_bin_width: Minimum bin width.
+        min_bin_height: Minimum bin height.
+        min_derivative: Minimum derivative at the knots of the bins.
+        kwargs: Additional arguments that are passed by the build function but are
+            not relevant for MAFs and are therefore ignored.
+
+    Returns:
+        Neural network.
+    """
+    x_numel = batch_x[0].numel()
+    # Infer the output dimensionality of the embedding_net by making a forward pass.
+    check_data_device(batch_x, batch_y)
+    check_embedding_net_device(embedding_net=embedding_net, datum=batch_y)
+    y_numel = embedding_net(batch_y[:1]).numel()
+
+    if x_numel == 1:
+        warn("In one-dimensional output space, this flow reduces to a conditional spline transform.")
+
+    transform_list = []
+    for _ in range(num_transforms):
+        block = [
+            transforms.MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
+                features=x_numel,
+                hidden_features=hidden_features,
+                context_features=y_numel,
+                num_bins=num_bins,
+                tails=tails,
+                tail_bound=tail_bound,
+                num_blocks=num_blocks,
+                use_residual_blocks=False,
+                random_mask=False,
+                activation=tanh,
+                dropout_probability=dropout_probability,
+                use_batch_norm=use_batch_norm,
+                min_bin_width=min_bin_width,
+                min_bin_height=min_bin_height,
+                min_derivative=min_derivative,
+            ),
+            transforms.RandomPermutation(features=x_numel),
+        ]
+        transform_list += block
+
+    z_score_x_bool, structured_x = z_score_parser(z_score_x)
+    if z_score_x_bool:
+        transform_list = [
+            standardizing_transform(batch_x, structured_x)
+        ] + transform_list
+
+    z_score_y_bool, structured_y = z_score_parser(z_score_y)
+    if z_score_y_bool:
+        embedding_net = nn.Sequential(
+            standardizing_net(batch_y, structured_y), embedding_net
+        )
+
+    # Combine transforms.
+    transform = transforms.CompositeTransform(transform_list)
+
+    distribution = distributions_.StandardNormal((x_numel,))
+    neural_net = flows.Flow(transform, distribution, embedding_net)
+
+    return neural_net
+
+
 def build_nsf(
     batch_x: Tensor,
     batch_y: Tensor,
diff --git a/sbi/utils/get_nn_models.py b/sbi/utils/get_nn_models.py
index 68e88ffc1..71938df37 100644
--- a/sbi/utils/get_nn_models.py
+++ b/sbi/utils/get_nn_models.py
@@ -11,7 +11,7 @@
     build_mlp_classifier,
     build_resnet_classifier,
 )
-from sbi.neural_nets.flow import build_made, build_maf, build_nsf
+from sbi.neural_nets.flow import build_made, build_maf, build_maf_rqs, build_nsf
 from sbi.neural_nets.mdn import build_mdn
 from sbi.neural_nets.mnle import build_mnle
 
@@ -109,7 +109,7 @@ def likelihood_nn(
 
     Args:
         model: The type of density estimator that will be created. One of [`mdn`,
-            `made`, `maf`, `nsf`].
+            `made`, `maf`, `maf_rqs`, `nsf`].
         z_score_theta: Whether to z-score parameters $\theta$ before passing them into
             the network, can take one of the following:
             - `none`, or None: do not z-score.
@@ -158,10 +158,12 @@
     def build_fn(batch_theta, batch_x):
         if model == "mdn":
             return build_mdn(batch_x=batch_x, batch_y=batch_theta, **kwargs)
-        if model == "made":
+        elif model == "made":
             return build_made(batch_x=batch_x, batch_y=batch_theta, **kwargs)
-        if model == "maf":
+        elif model == "maf":
             return build_maf(batch_x=batch_x, batch_y=batch_theta, **kwargs)
+        elif model == "maf_rqs":
+            return build_maf_rqs(batch_x=batch_x, batch_y=batch_theta, **kwargs)
         elif model == "nsf":
             return build_nsf(batch_x=batch_x, batch_y=batch_theta, **kwargs)
         elif model == "mnle":
@@ -191,7 +193,7 @@ def posterior_nn(
 
     Args:
         model: The type of density estimator that will be created. One of [`mdn`,
-            `made`, `maf`, `nsf`].
+            `made`, `maf`, `maf_rqs`, `nsf`].
         z_score_theta: Whether to z-score parameters $\theta$ before passing them into
             the network, can take one of the following:
             - `none`, or None: do not z-score.
@@ -261,6 +263,8 @@ def build_fn(batch_theta, batch_x):
             return build_made(batch_x=batch_theta, batch_y=batch_x, **kwargs)
         elif model == "maf":
             return build_maf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
+        elif model == "maf_rqs":
+            return build_maf_rqs(batch_x=batch_theta, batch_y=batch_x, **kwargs)
         elif model == "nsf":
             return build_nsf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
         else:
diff --git a/tests/inference_on_device_test.py b/tests/inference_on_device_test.py
index 3acc5724e..baaa956e5 100644
--- a/tests/inference_on_device_test.py
+++ b/tests/inference_on_device_test.py
@@ -51,10 +51,12 @@
         (SNPE_C, "mdn", "rejection"),
         (SNPE_C, "maf", "slice"),
         (SNPE_C, "maf", "direct"),
+        (SNPE_C, "maf_rqs", "direct"),
         (SNLE, "maf", "slice"),
         (SNLE, "nsf", "slice_np"),
         (SNLE, "nsf", "rejection"),
         (SNLE, "maf", "importance"),
+        (SNLE, "maf_rqs", "slice"),
         (SNRE_A, "mlp", "slice_np_vectorized"),
         (SNRE_B, "resnet", "nuts"),
         (SNRE_B, "resnet", "rejection"),
diff --git a/tests/linearGaussian_snpe_test.py b/tests/linearGaussian_snpe_test.py
index 460f517c3..039c393a4 100644
--- a/tests/linearGaussian_snpe_test.py
+++ b/tests/linearGaussian_snpe_test.py
@@ -146,7 +146,7 @@ def test_c2st_snpe_on_linearGaussian(
 
 
 @pytest.mark.slow
-@pytest.mark.parametrize("density_estimtor", ["mdn", "maf", "nsf"])
+@pytest.mark.parametrize("density_estimtor", ["mdn", "maf", "maf_rqs", "nsf"])
 def test_density_estimators_on_linearGaussian(density_estimtor):
     """Test SNPE with different density estimators on linear Gaussian example."""
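
Reviewer note (not part of the patch): below is a minimal sketch of how the new
`maf_rqs` estimator is selected through the `posterior_nn` factory touched above.
The uniform prior and the toy linear-Gaussian simulator are illustrative
placeholders, and only kwargs that `posterior_nn` already exposes
(`num_transforms`, `num_bins`) are set; spline-specific options such as `tails`
or `tail_bound` would rely on `**kwargs` forwarding to `build_maf_rqs`.

import torch
from sbi.inference import SNPE
from sbi.utils import BoxUniform
from sbi.utils.get_nn_models import posterior_nn

# Placeholder prior over a 3-dimensional parameter space.
prior = BoxUniform(low=-2 * torch.ones(3), high=2 * torch.ones(3))

# Select the new MAF-RQS density estimator by name; num_transforms and
# num_bins are passed through to build_maf_rqs by the build function.
density_estimator_build_fn = posterior_nn(
    model="maf_rqs", num_transforms=5, num_bins=10
)

inference = SNPE(prior=prior, density_estimator=density_estimator_build_fn)

# Toy linear-Gaussian "simulator", purely for illustration.
theta = prior.sample((1000,))
x = theta + 0.1 * torch.randn_like(theta)

density_estimator = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior(density_estimator)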