From bf65102c4f726286e0c78647ffde295cca768e9d Mon Sep 17 00:00:00 2001 From: Nastya Krouglova <41705732+anastasiakrouglova@users.noreply.github.com> Date: Wed, 27 Mar 2024 16:16:39 +0100 Subject: [PATCH 01/14] Zuko density estimators (#1088) * update zuko to 1.1.0 * test zuko_gmm commit * build_zuko_nsf added * add build_zuko_naf, update test * add license change to pr template. * CLN pyproject.toml (#1009) * CLN pyproject.toml * CLN optional deps comment * CLN alphabetical order * fix x_o and broken link tutorial 7 (#1003) * fix x_o and broken link tutorial 7 * typo in title * suppress plotting output --------- Co-authored-by: Matthijs * replace prepare_for_sbi in tutorials (#1013) * add zuko density estimators * not working gmm * update tests for PR * update PR for pyright * resolve pyright * add reportArgumentType * resolve pyright issue * resolve all issues pyright * resolve pyright * add typing and docstring * add functions from factory to test * remove comment mdn file * add docstrings flow file * add docstring in density_estimator_test.py * Update sbi/neural_nets/flow.py Co-authored-by: Sebastian Bischoff * Update sbi/neural_nets/flow.py Co-authored-by: Sebastian Bischoff * Update sbi/neural_nets/flow.py Co-authored-by: Sebastian Bischoff * removed pyright --------- Co-authored-by: bkmi <12955549+bkmi@users.noreply.github.com> Co-authored-by: Nastya Krouglova Co-authored-by: Jan Boelts Co-authored-by: Thomas Moreau Co-authored-by: Matthijs Pals <34062419+Matthijspals@users.noreply.github.com> Co-authored-by: Matthijs Co-authored-by: zinaStef <49067201+zinaStef@users.noreply.github.com> Co-authored-by: Sebastian Bischoff --- pyproject.toml | 2 +- .../density_estimators/zuko_flow.py | 8 +- sbi/neural_nets/factory.py | 50 ++ sbi/neural_nets/flow.py | 765 +++++++++++++++--- tests/density_estimator_test.py | 106 ++- tests/neural_nets_factory.py | 76 +- tutorials/00_getting_started.ipynb | 5 +- tutorials/05_embedding_net.ipynb | 2 +- tutorials/07_conditional_distributions.ipynb | 2 +- 9 files changed, 881 insertions(+), 135 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 780604ade..05e3932b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ dependencies = [ "tensorboard", "torch>=1.8.0", "tqdm", - "zuko>=1.0.0", + "zuko>=1.1.0", ] [project.optional-dependencies] diff --git a/sbi/neural_nets/density_estimators/zuko_flow.py b/sbi/neural_nets/density_estimators/zuko_flow.py index 5b2f98af4..234d08a71 100644 --- a/sbi/neural_nets/density_estimators/zuko_flow.py +++ b/sbi/neural_nets/density_estimators/zuko_flow.py @@ -2,7 +2,7 @@ import torch from torch import Tensor, nn -from zuko.flows import Flow +from zuko.flows.core import Flow from sbi.neural_nets.density_estimators.base import DensityEstimator from sbi.sbi_types import Shape @@ -76,6 +76,7 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor: emb_cond = emb_cond.expand(batch_shape + (emb_cond.shape[-1],)) dists = self.net(emb_cond) + log_probs = dists.log_prob(input) return log_probs @@ -117,7 +118,7 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor: emb_cond = self._embedding_net(condition) dists = self.net(emb_cond) - # zuko.sample() returns (*sample_shape, *batch_shape, input_size). 
+ samples = dists.sample(sample_shape).reshape(*batch_shape, *sample_shape, -1) return samples @@ -141,9 +142,8 @@ def sample_and_log_prob( emb_cond = self._embedding_net(condition) dists = self.net(emb_cond) - samples, log_probs = dists.rsample_and_log_prob(sample_shape) - # zuko.sample_and_log_prob() returns (*sample_shape, *batch_shape, ...). + samples, log_probs = dists.rsample_and_log_prob(sample_shape) samples = samples.reshape(*batch_shape, *sample_shape, -1) log_probs = log_probs.reshape(*batch_shape, *sample_shape) diff --git a/sbi/neural_nets/factory.py b/sbi/neural_nets/factory.py index 1273b4587..d1eae3f84 100644 --- a/sbi/neural_nets/factory.py +++ b/sbi/neural_nets/factory.py @@ -16,7 +16,16 @@ build_maf, build_maf_rqs, build_nsf, + build_zuko_bpf, + build_zuko_cnf, + build_zuko_gf, build_zuko_maf, + build_zuko_naf, + build_zuko_ncsf, + build_zuko_nice, + build_zuko_nsf, + build_zuko_sospf, + build_zuko_unaf, ) from sbi.neural_nets.mdn import build_mdn from sbi.neural_nets.mnle import build_mnle @@ -174,8 +183,26 @@ def build_fn(batch_theta, batch_x): return build_nsf(batch_x=batch_x, batch_y=batch_theta, **kwargs) elif model == "mnle": return build_mnle(batch_x=batch_x, batch_y=batch_theta, **kwargs) + elif model == "zuko_nice": + return build_zuko_nice(batch_x=batch_x, batch_y=batch_theta, **kwargs) elif model == "zuko_maf": return build_zuko_maf(batch_x=batch_x, batch_y=batch_theta, **kwargs) + elif model == "zuko_nsf": + return build_zuko_nsf(batch_x=batch_x, batch_y=batch_theta, **kwargs) + elif model == "zuko_ncsf": + return build_zuko_ncsf(batch_x=batch_x, batch_y=batch_theta, **kwargs) + elif model == "zuko_sospf": + return build_zuko_sospf(batch_x=batch_x, batch_y=batch_theta, **kwargs) + elif model == "zuko_naf": + return build_zuko_naf(batch_x=batch_x, batch_y=batch_theta, **kwargs) + elif model == "zuko_unaf": + return build_zuko_unaf(batch_x=batch_x, batch_y=batch_theta, **kwargs) + elif model == "zuko_cnf": + return build_zuko_cnf(batch_x=batch_x, batch_y=batch_theta, **kwargs) + elif model == "zuko_gf": + return build_zuko_gf(batch_x=batch_x, batch_y=batch_theta, **kwargs) + elif model == "zuko_bpf": + return build_zuko_bpf(batch_x=batch_x, batch_y=batch_theta, **kwargs) else: raise NotImplementedError @@ -266,6 +293,9 @@ def build_fn_snpe_a(batch_theta, batch_x, num_components): def build_fn(batch_theta, batch_x): if model == "mdn": + # The naming might be a bit confusing. + # batch_x are the latent variables, batch_y the conditioned variables. + # batch_theta are the parameters and batch_x the observable variables. 
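+            # Concretely: `posterior_nn` estimates p(theta | x), so the builders'
+            # `batch_x` argument (the variable whose density is learned) receives
+            # `batch_theta`, while `batch_y` (the conditioning variable) receives
+            # `batch_x`.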
return build_mdn(batch_x=batch_theta, batch_y=batch_x, **kwargs) elif model == "made": return build_made(batch_x=batch_theta, batch_y=batch_x, **kwargs) @@ -275,8 +305,28 @@ def build_fn(batch_theta, batch_x): return build_maf_rqs(batch_x=batch_theta, batch_y=batch_x, **kwargs) elif model == "nsf": return build_nsf(batch_x=batch_theta, batch_y=batch_x, **kwargs) + elif model == "mnle": + return build_mnle(batch_x=batch_theta, batch_y=batch_x, **kwargs) + elif model == "zuko_nice": + return build_zuko_nice(batch_x=batch_theta, batch_y=batch_x, **kwargs) elif model == "zuko_maf": return build_zuko_maf(batch_x=batch_theta, batch_y=batch_x, **kwargs) + elif model == "zuko_nsf": + return build_zuko_nsf(batch_x=batch_theta, batch_y=batch_x, **kwargs) + elif model == "zuko_ncsf": + return build_zuko_ncsf(batch_x=batch_theta, batch_y=batch_x, **kwargs) + elif model == "zuko_sospf": + return build_zuko_sospf(batch_x=batch_theta, batch_y=batch_x, **kwargs) + elif model == "zuko_naf": + return build_zuko_naf(batch_x=batch_theta, batch_y=batch_x, **kwargs) + elif model == "zuko_unaf": + return build_zuko_unaf(batch_x=batch_theta, batch_y=batch_x, **kwargs) + elif model == "zuko_cnf": + return build_zuko_cnf(batch_x=batch_theta, batch_y=batch_x, **kwargs) + elif model == "zuko_gf": + return build_zuko_gf(batch_x=batch_theta, batch_y=batch_x, **kwargs) + elif model == "zuko_bpf": + return build_zuko_bpf(batch_x=batch_theta, batch_y=batch_x, **kwargs) else: raise NotImplementedError diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index 23c531654..709690284 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -2,7 +2,7 @@ # under the Affero General Public License v3, see . from functools import partial -from typing import List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Tuple, Union from warnings import warn import torch @@ -11,21 +11,48 @@ from pyknos.nflows import flows, transforms from pyknos.nflows.nn import nets from pyknos.nflows.transforms.splines import ( - rational_quadratic, # pyright: ignore[reportAttributeAccessIssue] + rational_quadratic, ) from torch import Tensor, nn, relu, tanh, tensor, uint8 -from zuko.flows import LazyTransform from sbi.neural_nets.density_estimators import NFlowsFlow, ZukoFlow from sbi.utils.sbiutils import ( standardizing_net, standardizing_transform, + standardizing_transform_zuko, z_score_parser, ) from sbi.utils.torchutils import create_alternating_binary_mask from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device +def get_numel(batch_x: Tensor, batch_y: Tensor, embedding_net) -> Tuple[Tensor, Tensor]: + """ + Get the number of elements in the input and output space. + + Args: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + embedding_net: Optional embedding network for y. + + Returns: + Tuple of the number of elements in the input and output space. + + """ + x_numel = batch_x[0].numel() + # Infer the output dimensionality of the embedding_net by making a forward pass. 
+ check_data_device(batch_x, batch_y) + check_embedding_net_device(embedding_net=embedding_net, datum=batch_y) + y_numel = embedding_net(batch_y[:1]).numel() + if x_numel == 1: + warn( + "In one-dimensional output space, this flow is limited to Gaussians", + stacklevel=2, + ) + + return x_numel, y_numel + + def build_made( batch_x: Tensor, batch_y: Tensor, @@ -58,18 +85,7 @@ def build_made( Returns: Neural network. """ - x_numel = batch_x[0].numel() - # Infer the output dimensionality of the embedding_net by making a forward pass. - check_data_device(batch_x, batch_y) - check_embedding_net_device(embedding_net=embedding_net, datum=batch_y) - embedding_net.eval() - y_numel = embedding_net(batch_y[:1]).numel() - - if x_numel == 1: - warn( - "In one-dimensional output space, this flow is limited to Gaussians", - stacklevel=2, - ) + x_numel, y_numel = get_numel(batch_x, batch_y, embedding_net) transform = transforms.IdentityTransform() @@ -142,18 +158,7 @@ def build_maf( Returns: Neural network. """ - x_numel = batch_x[0].numel() - # Infer the output dimensionality of the embedding_net by making a forward pass. - check_data_device(batch_x, batch_y) - check_embedding_net_device(embedding_net=embedding_net, datum=batch_y) - embedding_net.eval() - y_numel = embedding_net(batch_y[:1]).numel() - - if x_numel == 1: - warn( - "In one-dimensional output space, this flow is limited to Gaussians", - stacklevel=2, - ) + x_numel, y_numel = get_numel(batch_x, batch_y, embedding_net) transform_list = [] for _ in range(num_transforms): @@ -185,7 +190,7 @@ def build_maf( standardizing_net(batch_y, structured_y), embedding_net ) - # Combine transforms. + # Combine transforms transform = transforms.CompositeTransform(transform_list) distribution = get_base_dist(x_numel, **kwargs) @@ -251,18 +256,7 @@ def build_maf_rqs( Returns: Neural network. """ - x_numel = batch_x[0].numel() - # Infer the output dimensionality of the embedding_net by making a forward pass. - check_data_device(batch_x, batch_y) - check_embedding_net_device(embedding_net=embedding_net, datum=batch_y) - embedding_net.eval() - y_numel = embedding_net(batch_y[:1]).numel() - - if x_numel == 1: - warn( - "In one-dimensional output space, this flow is limited to Gaussians", - stacklevel=2, - ) + x_numel, y_numel = get_numel(batch_x, batch_y, embedding_net) transform_list = [] for _ in range(num_transforms): @@ -355,12 +349,7 @@ def build_nsf( Returns: Neural network. """ - x_numel = batch_x[0].numel() - # Infer the output dimensionality of the embedding_net by making a forward pass. - check_data_device(batch_x, batch_y) - check_embedding_net_device(embedding_net=embedding_net, datum=batch_y) - embedding_net.eval() - y_numel = embedding_net(batch_y[:1]).numel() + x_numel, y_numel = get_numel(batch_x, batch_y, embedding_net) # Define mask function to alternate between predicted x-dimensions. def mask_in_layer(i): @@ -433,6 +422,61 @@ def mask_in_layer(i): return flow +def build_zuko_nice( + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 5, + embedding_net: nn.Module = nn.Identity(), + randmask: bool = False, + **kwargs, +) -> ZukoFlow: + """ + Build a Non-linear Independent Components Estimation (NICE) flow. + + Affine transformations are used by default, instead of the additive transformations + used by Dinh et al. (2014) originally. 
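+
+    Example (a minimal sketch; the toy shapes are assumptions):
+        >>> import torch
+        >>> flow = build_zuko_nice(torch.randn(100, 2), torch.randn(100, 3))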
+ + References: + | NICE: Non-linear Independent Components Estimation (Dinh et al., 2014) + | https://arxiv.org/abs/1410.8516 + + Arguments: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. + - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. Should be used when each + sample is, for example, a time series or an image. + z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 5. + embedding_net: The embedding network to use. Defaults to nn.Identity(). + randmask: Whether to use random masks in the flow. Defaults to False. + **kwargs: Additional keyword arguments to pass to the flow constructor. + """ + which_nf = "NICE" + additional_kwargs = {"randmask": randmask, **kwargs} + flow = build_zuko_flow( + which_nf, + batch_x, + batch_y, + z_score_x, + z_score_y, + hidden_features, + num_transforms, + embedding_net, + **additional_kwargs, + ) + + return flow + + def build_zuko_maf( batch_x: Tensor, batch_y: Tensor, @@ -441,13 +485,17 @@ def build_zuko_maf( hidden_features: Union[Sequence[int], int] = 50, num_transforms: int = 5, embedding_net: nn.Module = nn.Identity(), - residual: bool = True, randperm: bool = False, **kwargs, ) -> ZukoFlow: - """Builds MAF p(x|y). + """ + Build a Masked Autoregressive Flow (MAF). - Args: + References: + | Masked Autoregressive Flow for Density Estimation (Papamakarios et al., 2017) + | https://arxiv.org/abs/1705.07057 + + Arguments: batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. z_score_x: Whether to z-score xs passing into the network, can be one of: @@ -458,69 +506,598 @@ def build_zuko_maf( sample is, for example, a time series or an image. z_score_y: Whether to z-score ys passing into the network, same options as z_score_x. - hidden_features: Number of hidden features. - num_transforms: Number of transforms. - embedding_net: Optional embedding network for y. - residual: whether to use residual blocks in the coupling layer. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 5. + embedding_net: The embedding network to use. Defaults to nn.Identity(). + randperm: Whether to use random permutations in the flow. Defaults to False. + **kwargs: Additional keyword arguments to pass to the flow constructor. 
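+
+    Example (a minimal sketch; the toy shapes are assumptions):
+        >>> import torch
+        >>> flow = build_zuko_maf(torch.randn(100, 2), torch.randn(100, 3),
+        ...                       hidden_features=32, num_transforms=3)
+        >>> # `flow` is a ZukoFlow estimating p(x | y).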
+ """ + which_nf = "MAF" + additional_kwargs = {"randperm": randperm, **kwargs} + flow = build_zuko_flow( + which_nf, + batch_x, + batch_y, + z_score_x, + z_score_y, + hidden_features, + num_transforms, + embedding_net, + **additional_kwargs, + ) + + return flow + + +def build_zuko_nsf( + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 5, + embedding_net: nn.Module = nn.Identity(), + num_bins: int = 8, + **kwargs, +) -> ZukoFlow: + """ + Build a Neural Spline Flow (NSF) with monotonic rational-quadratic spline + transformations. + + By default, transformations are fully autoregressive. Coupling transformations + can be obtained by setting :py:`passes=2`. + + Warning: + Spline transformations are defined over the domain :math:`[-5, 5]`. Any feature + outside of this domain is not transformed. It is recommended to standardize + features (zero mean, unit variance) before training. + + References: + | Neural Spline Flows (Durkan et al., 2019) + | https://arxiv.org/abs/1906.04032 + + Arguments: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. + - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. Should be used when each + sample is, for example, a time series or an image. + z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 5. + embedding_net: The embedding network to use. Defaults to nn.Identity(). + num_bins: The number of bins in the spline transformations. Defaults to 8. + **kwargs: Additional keyword arguments to pass to the flow constructor. + """ + which_nf = "NSF" + additional_kwargs = {"bins": num_bins, **kwargs} + flow = build_zuko_flow( + which_nf, + batch_x, + batch_y, + z_score_x, + z_score_y, + hidden_features, + num_transforms, + embedding_net, + **additional_kwargs, + ) + + return flow + + +def build_zuko_ncsf( + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 5, + embedding_net: nn.Module = nn.Identity(), + num_bins: int = 8, + **kwargs, +) -> ZukoFlow: + r""" + Build a Neural Circular Spline Flow (NCSF). + + Circular spline transformations are obtained by composing circular domain shifts + with regular spline transformations. Features are assumed to lie in the half-open + interval :math:`[-\pi, \pi[`. + + References: + | Normalizing Flows on Tori and Spheres (Rezende et al., 2020) + | https://arxiv.org/abs/2002.02428 + + Arguments: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. 
+ - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. Should be used when each + sample is, for example, a time series or an image. + z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 5. + embedding_net: The embedding network to use. Defaults to nn.Identity(). + num_bins: The number of bins in the spline transformations. Defaults to 8. + **kwargs: Additional keyword arguments to pass to the flow constructor. + """ + which_nf = "NCSF" + additional_kwargs = {"bins": num_bins, **kwargs} + flow = build_zuko_flow( + which_nf, + batch_x, + batch_y, + z_score_x, + z_score_y, + hidden_features, + num_transforms, + embedding_net, + **additional_kwargs, + ) + + return flow + + +def build_zuko_sospf( + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 5, + embedding_net: nn.Module = nn.Identity(), + degree: int = 4, + polynomials: int = 3, + **kwargs, +) -> ZukoFlow: + """ + Build a Sum-of-Squares Polynomial Flow (SOSPF). + + References: + | Sum-of-Squares Polynomial Flow (Jaini et al., 2019) + | https://arxiv.org/abs/1905.02325 + + Arguments: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. + - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. Should be used when each + sample is, for example, a time series or an image. + z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 5. + embedding_net: The embedding network to use. Defaults to nn.Identity(). + degree: The degree of the polynomials. Defaults to 4. + polynomials: The number of polynomials. Defaults to 3. + **kwargs: Additional keyword arguments to pass to the flow constructor. + """ + which_nf = "SOSPF" + additional_kwargs = {"degree": degree, "polynomials": polynomials, **kwargs} + flow = build_zuko_flow( + which_nf, + batch_x, + batch_y, + z_score_x, + z_score_y, + hidden_features, + num_transforms, + embedding_net, + **additional_kwargs, + ) + + return flow + + +def build_zuko_naf( + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 5, + embedding_net: nn.Module = nn.Identity(), + randperm: bool = False, + signal: int = 16, + **kwargs, +) -> ZukoFlow: + """ + Build a Neural Autoregressive Flow (NAF). + + Warning: + Invertibility is only guaranteed for features within the interval :math:`[-10, + 10]`. It is recommended to standardize features (zero mean, unit variance) + before training. 
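+
+    Example (a minimal sketch; the toy shapes are assumptions, and the default
+    z-scoring already standardizes features as the warning recommends):
+        >>> import torch
+        >>> flow = build_zuko_naf(torch.randn(100, 2), torch.randn(100, 3))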
+ + References: + | Neural Autoregressive Flows (Huang et al., 2018) + | https://arxiv.org/abs/1804.00779 + + Arguments: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. + - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. Should be used when each + sample is, for example, a time series or an image. + z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 5. + embedding_net: The embedding network to use. Defaults to nn.Identity(). randperm: Whether features are randomly permuted between transformations or not. - kwargs: Additional arguments that are passed by the build function but are not - relevant for maf and are therefore ignored. + If :py:`False`, features are in ascending (descending) order for even + (odd) transformations. + signal: The number of signal features of the monotonic network. + **kwargs: Additional keyword arguments to pass to the flow constructor. + """ + which_nf = "NAF" + additional_kwargs = { + "randperm": randperm, + "signal": signal, + # "network": network, + **kwargs, + } + flow = build_zuko_flow( + which_nf, + batch_x, + batch_y, + z_score_x, + z_score_y, + hidden_features, + num_transforms, + embedding_net, + **additional_kwargs, + ) + + return flow + + +def build_zuko_unaf( + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 5, + embedding_net: nn.Module = nn.Identity(), + randperm: bool = False, + signal: int = 16, + **kwargs, +) -> ZukoFlow: + """ + Build an Unconstrained Neural Autoregressive Flow (UNAF). + + Warning: + Invertibility is only guaranteed for features within the interval :math:`[-10, + 10]`. It is recommended to standardize features (zero mean, unit variance) + before training. + + References: + | Unconstrained Monotonic Neural Networks (Wehenkel et al., 2019) + | https://arxiv.org/abs/1908.05164 + + Arguments: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. + - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. Should be used when each + sample is, for example, a time series or an image. + z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 5. + embedding_net: The embedding network to use. Defaults to nn.Identity(). + randperm: Whether features are randomly permuted between transformations or not. + If :py:`False`, features are in ascending (descending) order for even + (odd) transformations. 
+ signal: The number of signal features of the monotonic network. + **kwargs: Additional keyword arguments to pass to the flow constructor. + """ + which_nf = "UNAF" + additional_kwargs = { + "randperm": randperm, + "signal": signal, + # "network": network, + **kwargs, + } + flow = build_zuko_flow( + which_nf, + batch_x, + batch_y, + z_score_x, + z_score_y, + hidden_features, + num_transforms, + embedding_net, + **additional_kwargs, + ) + + return flow + + +def build_zuko_cnf( + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 5, + embedding_net: nn.Module = nn.Identity(), + **kwargs, +) -> ZukoFlow: + """ + Build a Continuous Normalizing Flow (CNF) with a free-form Jacobian transformation. + + References: + | Neural Ordinary Differential Equations (Chen el al., 2018) + | https://arxiv.org/abs/1806.07366 + + | FFJORD: Free-form Continuous Dynamics for Scalable Reversible + | Generative Models (Grathwohl et al., 2018) + | https://arxiv.org/abs/1810.01367 + + Arguments: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. + - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. Should be used when each + sample is, for example, a time series or an image. + z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 5. + embedding_net: The embedding network to use. Defaults to nn.Identity(). + **kwargs: Additional keyword arguments to pass to the flow constructor. + """ + which_nf = "CNF" + additional_kwargs = {**kwargs} + flow = build_zuko_flow( + which_nf, + batch_x, + batch_y, + z_score_x, + z_score_y, + hidden_features, + num_transforms, + embedding_net, + **additional_kwargs, + ) + + return flow + + +def build_zuko_gf( + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 3, + embedding_net: nn.Module = nn.Identity(), + components: int = 8, + **kwargs, +) -> ZukoFlow: + """ + Build a Gaussianization Flow (GF). + + Warning: + Invertibility is only guaranteed for features within the interval :math:`[-10, + 10]`. It is recommended to standardize features (zero mean, unit variance) + before training. + + References: + | Gaussianization Flows (Meng et al., 2020) + | https://arxiv.org/abs/2003.01941 + + Arguments: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. + - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. Should be used when each + sample is, for example, a time series or an image. 
+ z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 5. + embedding_net: The embedding network to use. Defaults to nn.Identity(). + components: The number of components in the Gaussian mixture model. + **kwargs: Additional keyword arguments to pass to the flow constructor. + """ + which_nf = "GF" + additional_kwargs = {"components": components, **kwargs} + flow = build_zuko_flow( + which_nf, + batch_x, + batch_y, + z_score_x, + z_score_y, + hidden_features, + num_transforms, + embedding_net, + **additional_kwargs, + ) + + return flow + + +def build_zuko_bpf( + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 3, + embedding_net: nn.Module = nn.Identity(), + degree: int = 16, + linear: bool = False, + **kwargs, +) -> ZukoFlow: + """ + Build a Bernstein polynomial flow (BPF). + + Warning: + Invertibility is only guaranteed for features within the interval :math:`[-10, + 10]`. It is recommended to standardize features (zero mean, unit variance) + before training. + + References: + | Short-Term Density Forecasting of Low-Voltage Load using + | Bernstein-Polynomial Normalizing Flows (Arpogaus et al., 2022) + | https://arxiv.org/abs/2204.13939 + + Arguments: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. + - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. Should be used when each + sample is, for example, a time series or an image. + z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 5. + embedding_net: The embedding network to use. Defaults to nn.Identity(). + degree: The degree :math:`M` of the Bernstein polynomial. + linear: Whether to use a linear or sigmoid mapping to :math:`[0, 1]`. + **kwargs: Additional keyword arguments to pass to the flow constructor. + """ + which_nf = "BPF" + additional_kwargs = {"degree": degree, "linear": linear, **kwargs} + flow = build_zuko_flow( + which_nf, + batch_x, + batch_y, + z_score_x, + z_score_y, + hidden_features, + num_transforms, + embedding_net, + **additional_kwargs, + ) + + return flow + + +def build_zuko_flow( + which_nf: str, + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 5, + embedding_net: nn.Module = nn.Identity(), + **kwargs, +) -> ZukoFlow: + """ + Fundamental building blocks to build a Zuko normalizing flow model. + + Args: + which_nf (str): The type of normalizing flow to build. + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. 
+ z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. + - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. Should be used when each + sample is, for example, a time series or an image. + z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 5. + embedding_net: The embedding network to use. Defaults to nn.Identity(). + **kwargs: Additional keyword arguments to pass to the flow constructor. Returns: - Neural network. + ZukoFlow: The constructed Zuko normalizing flow model. """ - x_numel = batch_x[0].numel() - # Infer the output dimensionality of the embedding_net by making a forward pass. - check_data_device(batch_x, batch_y) - check_embedding_net_device(embedding_net=embedding_net, datum=batch_y) - embedding_net.eval() - y_numel = embedding_net(batch_y[:1]).numel() - if x_numel == 1: - warn( - "In one-dimensional output space, this flow is limited to Gaussians", - stacklevel=1, - ) + + x_numel, y_numel = get_numel(batch_x, batch_y, embedding_net) if isinstance(hidden_features, int): hidden_features = [hidden_features] * num_transforms - if x_numel == 1: - maf = zuko.flows.MAF( - features=x_numel, - context=y_numel, - hidden_features=hidden_features, - transforms=num_transforms, + build_nf = getattr(zuko.flows, which_nf) + + if which_nf == "CNF": + flow_built = build_nf( + features=x_numel, context=y_numel, hidden_features=hidden_features, **kwargs ) else: - maf = zuko.flows.MAF( + flow_built = build_nf( features=x_numel, context=y_numel, hidden_features=hidden_features, transforms=num_transforms, - randperm=randperm, - residual=residual, + **kwargs, ) - transforms: Union[Sequence[LazyTransform], LazyTransform] - transforms = maf.transform.transforms # pyright: ignore[reportAssignmentType] - z_score_x_bool, structured_x = z_score_parser(z_score_x) - if z_score_x_bool: - # transforms = transforms - transforms = ( - *transforms, - # Ideally `standardizing_transform` would return a `LazyTransform` instead of ` AffineTransform | Unconditional`, maybe all three are compatible - standardizing_transform(batch_x, structured_x, backend="zuko"), # pyright: ignore[reportAssignmentType] - ) + # Continuous normalizing flows (CNF) only have one transform, + # so we need to handle them slightly differently. + if which_nf == "CNF": + transform = flow_built.transform - z_score_y_bool, structured_y = z_score_parser(z_score_y) - if z_score_y_bool: - # Prepend standardizing transform to y-embedding. - embedding_net = nn.Sequential( - standardizing_net(batch_y, structured_y), embedding_net - ) + z_score_x_bool, structured_x = z_score_parser(z_score_x) + if z_score_x_bool: + transform = ( + transform, + standardizing_transform_zuko(batch_x, structured_x, backend="zuko"), + ) - # Combine transforms. - neural_net = zuko.flows.Flow(transforms, maf.base) + z_score_y_bool, structured_y = z_score_parser(z_score_y) + if z_score_y_bool: + # Prepend standardizing transform to y-embedding. + embedding_net = nn.Sequential( + standardizing_transform_zuko(batch_y, structured_y), embedding_net + ) + + # Combine transforms. 
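+        # At this point `transform` is either the CNF's single lazy transform
+        # or, when x is z-scored, a tuple of (transform, standardizer); the
+        # `zuko.flows.Flow` call below handles both, mirroring the
+        # sequence-of-transforms branch in the `else` case.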
+ neural_net = zuko.flows.Flow(transform, flow_built.base) + else: + transforms = flow_built.transform.transforms + + z_score_x_bool, structured_x = z_score_parser(z_score_x) + if z_score_x_bool: + transforms = ( + *transforms, + standardizing_transform_zuko(batch_x, structured_x), + ) + + z_score_y_bool, structured_y = z_score_parser(z_score_y) + if z_score_y_bool: + # Prepend standardizing transform to y-embedding. + embedding_net = nn.Sequential( + standardizing_net(batch_y, structured_y), embedding_net + ) + + # Combine transforms. + neural_net = zuko.flows.Flow(transforms, flow_built.base) flow = ZukoFlow(neural_net, embedding_net, condition_shape=batch_y[0].shape) diff --git a/tests/density_estimator_test.py b/tests/density_estimator_test.py index 2468a0fbc..cc7c97f7a 100644 --- a/tests/density_estimator_test.py +++ b/tests/density_estimator_test.py @@ -10,21 +10,84 @@ from torch import eye, zeros from torch.distributions import MultivariateNormal -from sbi.neural_nets.density_estimators import NFlowsFlow, ZukoFlow from sbi.neural_nets.density_estimators.shape_handling import reshape_to_iid_batch_event -from sbi.neural_nets.flow import build_nsf, build_zuko_maf +from sbi.neural_nets.flow import ( + build_maf, + build_maf_rqs, + build_nsf, + build_zuko_bpf, + build_zuko_cnf, + build_zuko_gf, + build_zuko_maf, + build_zuko_naf, + build_zuko_ncsf, + build_zuko_nice, + build_zuko_nsf, + build_zuko_sospf, + build_zuko_unaf, +) + + +def get_batch_input(nsamples: int, input_dims: int) -> torch.Tensor: + r"""Generate a batch of input samples from a multivariate normal distribution. + + Args: + nsamples (int): The number of samples to generate. + input_dims (int): The dimensionality of the input samples. + + Returns: + torch.Tensor: A tensor of shape (nsamples, input_dims) + containing the generated samples. + """ + input_mvn = MultivariateNormal( + loc=zeros(input_dims), covariance_matrix=eye(input_dims) + ) + return input_mvn.sample((nsamples,)) + + +def get_batch_context(nsamples: int, condition_shape: tuple[int, ...]) -> torch.Tensor: + r"""Generate a batch of context samples from a multivariate normal distribution. + + Args: + nsamples (int): The number of context samples to generate. + condition_shape (tuple[int, ...]): The shape of the condition for each sample. + + Returns: + torch.Tensor: A tensor containing the generated context samples. + """ + context_mvn = MultivariateNormal( + loc=zeros(*condition_shape), covariance_matrix=eye(condition_shape[-1]) + ) + return context_mvn.sample((nsamples,)) -@pytest.mark.parametrize("density_estimator", (NFlowsFlow, ZukoFlow)) +@pytest.mark.parametrize( + "build_density_estimator", + ( + build_maf, + build_maf_rqs, + build_nsf, + build_zuko_nice, + build_zuko_maf, + build_zuko_nsf, + build_zuko_ncsf, + build_zuko_sospf, + build_zuko_naf, + build_zuko_unaf, + build_zuko_cnf, + build_zuko_gf, + build_zuko_bpf, + ), +) @pytest.mark.parametrize("input_dims", (1, 2)) @pytest.mark.parametrize( "condition_shape", ((1,), (2,), (1, 1), (2, 2), (1, 1, 1), (2, 2, 2)) ) -def test_api_density_estimator(density_estimator, input_dims, condition_shape): +def test_api_density_estimator(build_density_estimator, input_dims, condition_shape): r"""Checks whether we can evaluate and sample from density estimators correctly. Args: - density_estimator: DensityEstimator subclass. + build_density_estimator: function that creates a DensityEstimator subclass. input_dim: Dimensionality of the input. context_shape: Dimensionality of the context. 
""" @@ -32,14 +95,8 @@ def test_api_density_estimator(density_estimator, input_dims, condition_shape): nsamples = 10 nsamples_test = 5 - input_mvn = MultivariateNormal( - loc=zeros(input_dims), covariance_matrix=eye(input_dims) - ) - batch_input = input_mvn.sample((nsamples,)) - context_mvn = MultivariateNormal( - loc=zeros(*condition_shape), covariance_matrix=eye(condition_shape[-1]) - ) - batch_context = context_mvn.sample((nsamples,)) + batch_input = get_batch_input(nsamples, input_dims) + batch_context = get_batch_context(nsamples, condition_shape) class EmbeddingNet(torch.nn.Module): def forward(self, x): @@ -47,22 +104,13 @@ def forward(self, x): x = torch.sum(x, dim=-1) return x - if density_estimator == NFlowsFlow: - estimator = build_nsf( - batch_input, - batch_context, - hidden_features=10, - num_transforms=2, - embedding_net=EmbeddingNet(), - ) - elif density_estimator == ZukoFlow: - estimator = build_zuko_maf( - batch_input, - batch_context, - hidden_features=10, - num_transforms=2, - embedding_net=EmbeddingNet(), - ) + estimator = build_density_estimator( + batch_input, + batch_context, + hidden_features=10, + num_transforms=2, + embedding_net=EmbeddingNet(), + ) # Loss is only required to work for batched inputs and contexts loss = estimator.loss(batch_input, batch_context) diff --git a/tests/neural_nets_factory.py b/tests/neural_nets_factory.py index f5a2775f3..6e6d8751f 100644 --- a/tests/neural_nets_factory.py +++ b/tests/neural_nets_factory.py @@ -14,8 +14,42 @@ def test_deprecated_import_classifier_nn(model: str): @pytest.mark.parametrize( "model", - ["mdn", "made", "maf", "maf_rqs", "nsf", "mnle", "zuko_maf"], - ids=["mdn", "made", "maf", "maf_rqs", "nsf", "mnle", "zuko_maf"], + [ + "mdn", + "made", + "maf", + "maf_rqs", + "nsf", + "mnle", + "zuko_bpf", + "zuko_cnf", + "zuko_gf", + "zuko_maf", + "zuko_naf", + "zuko_ncsf", + "zuko_nice", + "zuko_nsf", + "zuko_sospf", + "zuko_unaf", + ], + ids=[ + "mdn", + "made", + "maf", + "maf_rqs", + "nsf", + "mnle", + "zuko_bpf", + "zuko_cnf", + "zuko_gf", + "zuko_maf", + "zuko_naf", + "zuko_ncsf", + "zuko_nice", + "zuko_nsf", + "zuko_sospf", + "zuko_unaf", + ], ) def test_deprecated_import_likelihood_nn(model: str): with pytest.warns(DeprecationWarning): @@ -25,8 +59,42 @@ def test_deprecated_import_likelihood_nn(model: str): @pytest.mark.parametrize( "model", - ["mdn", "made", "maf", "maf_rqs", "nsf", "mnle", "zuko_maf"], - ids=["mdn", "made", "maf", "maf_rqs", "nsf", "mnle", "zuko_maf"], + [ + "mdn", + "made", + "maf", + "maf_rqs", + "nsf", + "mnle", + "zuko_bpf", + "zuko_cnf", + "zuko_gf", + "zuko_maf", + "zuko_naf", + "zuko_ncsf", + "zuko_nice", + "zuko_nsf", + "zuko_sospf", + "zuko_unaf", + ], + ids=[ + "mdn", + "made", + "maf", + "maf_rqs", + "nsf", + "mnle", + "zuko_bpf", + "zuko_cnf", + "zuko_gf", + "zuko_maf", + "zuko_naf", + "zuko_ncsf", + "zuko_nice", + "zuko_nsf", + "zuko_sospf", + "zuko_unaf", + ], ) def test_deprecated_import_posterior_nn(model: str): with pytest.warns(DeprecationWarning): diff --git a/tutorials/00_getting_started.ipynb b/tutorials/00_getting_started.ipynb index d9b48e85f..507dc4207 100644 --- a/tutorials/00_getting_started.ipynb +++ b/tutorials/00_getting_started.ipynb @@ -98,7 +98,10 @@ "source": [ "# Other methods are \"SNLE\" or \"SNRE\".\n", "posterior = infer(simulator, prior, method=\"SNPE\", num_simulations=1000)\n", - "# Using `init_kwargs`, `train_kwargs` and `build_posterior_kwargs`, you can also pass additional keyword arguments to `__init__`, `train` and `build_posterior` of the inference 
method. But we recommend to use the flexible interface which is introduced in a later tutorial." + "# Using `init_kwargs`, `train_kwargs` and `build_posterior_kwargs`,\n", + "# you can also pass additional keyword arguments to `__init__`, `train` and\n", + "# `build_posterior` of the inference method. But we recommend to use the\n", + "# flexible interface which is introduced in a later tutorial." ] }, { diff --git a/tutorials/05_embedding_net.ipynb b/tutorials/05_embedding_net.ipynb index e7adde4ad..f0023bc1f 100644 --- a/tutorials/05_embedding_net.ipynb +++ b/tutorials/05_embedding_net.ipynb @@ -339,7 +339,7 @@ "name": "stdout", "output_type": "stream", "text": [ - " Neural network successfully converged after 96 epochs." + " Neural network successfully converged after 220 epochs." ] } ], diff --git a/tutorials/07_conditional_distributions.ipynb b/tutorials/07_conditional_distributions.ipynb index a98549f88..9e718ca8b 100644 --- a/tutorials/07_conditional_distributions.ipynb +++ b/tutorials/07_conditional_distributions.ipynb @@ -26124,7 +26124,7 @@ "name": "stdout", "output_type": "stream", "text": [ - " Neural network successfully converged after 118 epochs." + " Neural network successfully converged after 65 epochs." ] } ], From 5c65d7d9974077d8f05c08ac34f4ee5cca4a1873 Mon Sep 17 00:00:00 2001 From: Nastya Krouglova Date: Wed, 3 Apr 2024 20:13:55 +0200 Subject: [PATCH 02/14] merge --- sbi/neural_nets/flow.py | 41 +++++++- tutorials/05_embedding_net.ipynb | 163 +++++++++++++++++++++++++++++++ 2 files changed, 203 insertions(+), 1 deletion(-) diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index fedca6972..929568886 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -3,6 +3,7 @@ from functools import partial from typing import List, Optional, Sequence, Tuple, Union +from typing import List, Optional, Sequence, Tuple, Union from warnings import warn import torch @@ -12,6 +13,7 @@ from pyknos.nflows.nn import nets from pyknos.nflows.transforms.splines import ( rational_quadratic, + rational_quadratic, ) from torch import Tensor, nn, relu, tanh, tensor, uint8 @@ -26,6 +28,33 @@ from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device +def get_numel(batch_x: Tensor, batch_y: Tensor, embedding_net) -> Tuple[Tensor, Tensor]: + """ + Get the number of elements in the input and output space. + + Args: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + embedding_net: Optional embedding network for y. + + Returns: + Tuple of the number of elements in the input and output space. + + """ + x_numel = batch_x[0].numel() + # Infer the output dimensionality of the embedding_net by making a forward pass. + check_data_device(batch_x, batch_y) + check_embedding_net_device(embedding_net=embedding_net, datum=batch_y) + y_numel = embedding_net(batch_y[:1]).numel() + if x_numel == 1: + warn( + "In one-dimensional output space, this flow is limited to Gaussians", + stacklevel=2, + ) + + return x_numel, y_numel + + def get_numel(batch_x: Tensor, batch_y: Tensor, embedding_net) -> Tuple[Tensor, Tensor]: """ Get the number of elements in the input and output space. 
@@ -1058,10 +1087,20 @@ def build_zuko_flow( **kwargs, ) +<<<<<<< HEAD + transforms = maf.transform + z_score_x_bool, structured_x = z_score_parser(z_score_x) + if z_score_x_bool: + transforms = ( + transforms, + standardizing_transform_zuko(batch_x, structured_x), + ) +======= # Continuous normalizing flows (CNF) only have one transform, # so we need to handle them slightly differently. - if which_nf == "CNF": + if which_nf == "CNF": transform = flow_built.transform +>>>>>>> bf65102c (Zuko density estimators (#1088)) z_score_x_bool, structured_x = z_score_parser(z_score_x) if z_score_x_bool: diff --git a/tutorials/05_embedding_net.ipynb b/tutorials/05_embedding_net.ipynb index 78bae6598..0bf13d5d4 100644 --- a/tutorials/05_embedding_net.ipynb +++ b/tutorials/05_embedding_net.ipynb @@ -460,8 +460,171 @@ " x = F.relu(self.fc(x))\n", " return x\n", "\n", +<<<<<<< HEAD "# instantiate the custom embedding_net\n", "embedding_net_custom = SummaryNet()" +======= + "\n", + "embedding_net = SummaryNet()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The inference procedure\n", + "\n", + "With the `embedding_net` defined and instantiated, we can follow the usual workflow of an inference procedure in `sbi`. The `embedding_net` object appears as an input argument when instantiating the neural density estimator with `sbi.neural_nets.posterior_nn`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# set prior distribution for the parameters\n", + "prior = utils.BoxUniform(\n", + " low=torch.tensor([0.0, 0.0]), high=torch.tensor([1.0, 2 * torch.pi])\n", + ")\n", + "\n", + "# make a SBI-wrapper on the simulator object for compatibility\n", + "prior, num_parameters, prior_returns_numpy = process_prior(prior)\n", + "simulator_wrapper = process_simulator(simulator_model, prior, prior_returns_numpy)\n", + "check_sbi_inputs(simulator_wrapper, prior)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "from sbi.neural_nets import posterior_nn\n", + "\n", + "# instantiate the neural density estimator\n", + "neural_posterior = posterior_nn(\n", + " model=\"maf\", embedding_net=embedding_net, hidden_features=10, num_transforms=2\n", + ")\n", + "\n", + "# setup the inference procedure with the SNPE-C procedure\n", + "inferer = SNPE(prior=prior, density_estimator=neural_posterior)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a073de8ca9094ec9a4909f74ba837bfb", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Running 10000 simulations.: 0%| | 0/10000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# create the figure\n", + "fig, ax = analysis.pairplot(\n", + " samples,\n", + " points=true_parameter,\n", + " labels=[\"r\", r\"$\\theta$\"],\n", + " limits=[[0, 1], [0, 2 * torch.pi]],\n", + " points_colors=\"r\",\n", + " points_offdiag={\"markersize\": 6},\n", + " figsize=(5, 5),\n", + ")" +>>>>>>> bf65102c (Zuko density estimators (#1088)) ] } ], From 5a2db8700320801e09723f739127d6c6f3eec929 Mon Sep 17 00:00:00 2001 From: Nastya Krouglova Date: Wed, 3 Apr 2024 20:11:40 +0200 Subject: [PATCH 03/14] hate --- sbi/neural_nets/flow.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index 929568886..b3028f7e8 100644 
--- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -1087,20 +1087,10 @@ def build_zuko_flow( **kwargs, ) -<<<<<<< HEAD - transforms = maf.transform - z_score_x_bool, structured_x = z_score_parser(z_score_x) - if z_score_x_bool: - transforms = ( - transforms, - standardizing_transform_zuko(batch_x, structured_x), - ) -======= # Continuous normalizing flows (CNF) only have one transform, # so we need to handle them slightly differently. if which_nf == "CNF": transform = flow_built.transform ->>>>>>> bf65102c (Zuko density estimators (#1088)) z_score_x_bool, structured_x = z_score_parser(z_score_x) if z_score_x_bool: From cb87ed041eb22ca90165d045dc575dc5675a648c Mon Sep 17 00:00:00 2001 From: Nastya Krouglova Date: Wed, 3 Apr 2024 20:16:45 +0200 Subject: [PATCH 04/14] merge --- examples/00_HH_simulator.ipynb | 2 +- sbi/neural_nets/flow.py | 29 ------------------- tutorials/00_getting_started_flexible.ipynb | 4 +-- tutorials/01_gaussian_amortized.ipynb | 4 +-- .../17_importance_sampled_posteriors.ipynb | 28 +++++++++--------- 5 files changed, 19 insertions(+), 48 deletions(-) diff --git a/examples/00_HH_simulator.ipynb b/examples/00_HH_simulator.ipynb index 269cd3663..3b88cce0b 100644 --- a/examples/00_HH_simulator.ipynb +++ b/examples/00_HH_simulator.ipynb @@ -256,7 +256,7 @@ "ax.set_xticks([])\n", "ax.set_yticks([-80, -20, 40])\n", "\n", - "# plot the injected current \n", + "# plot the injected current\n", "ax = plt.subplot(gs[1])\n", "plt.plot(t, I_inj * A_soma * 1e3, \"k\", lw=2)\n", "plt.xlabel(\"time (ms)\")\n", diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index b3028f7e8..709690284 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -3,7 +3,6 @@ from functools import partial from typing import List, Optional, Sequence, Tuple, Union -from typing import List, Optional, Sequence, Tuple, Union from warnings import warn import torch @@ -13,7 +12,6 @@ from pyknos.nflows.nn import nets from pyknos.nflows.transforms.splines import ( rational_quadratic, - rational_quadratic, ) from torch import Tensor, nn, relu, tanh, tensor, uint8 @@ -28,33 +26,6 @@ from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device -def get_numel(batch_x: Tensor, batch_y: Tensor, embedding_net) -> Tuple[Tensor, Tensor]: - """ - Get the number of elements in the input and output space. - - Args: - batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. - batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. - embedding_net: Optional embedding network for y. - - Returns: - Tuple of the number of elements in the input and output space. - - """ - x_numel = batch_x[0].numel() - # Infer the output dimensionality of the embedding_net by making a forward pass. - check_data_device(batch_x, batch_y) - check_embedding_net_device(embedding_net=embedding_net, datum=batch_y) - y_numel = embedding_net(batch_y[:1]).numel() - if x_numel == 1: - warn( - "In one-dimensional output space, this flow is limited to Gaussians", - stacklevel=2, - ) - - return x_numel, y_numel - - def get_numel(batch_x: Tensor, batch_y: Tensor, embedding_net) -> Tuple[Tensor, Tensor]: """ Get the number of elements in the input and output space. 
diff --git a/tutorials/00_getting_started_flexible.ipynb b/tutorials/00_getting_started_flexible.ipynb index 552bb71fc..b8094b100 100644 --- a/tutorials/00_getting_started_flexible.ipynb +++ b/tutorials/00_getting_started_flexible.ipynb @@ -136,7 +136,7 @@ "metadata": {}, "outputs": [], "source": [ - "inference = SNPE(prior=prior) " + "inference = SNPE(prior=prior)" ] }, { @@ -266,7 +266,7 @@ "outputs": [], "source": [ "theta_true = prior.sample((1,))\n", - "# generate our observation \n", + "# generate our observation\n", "x_obs = simulator(theta_true)" ] }, diff --git a/tutorials/01_gaussian_amortized.ipynb b/tutorials/01_gaussian_amortized.ipynb index 67b9e7833..69f798783 100644 --- a/tutorials/01_gaussian_amortized.ipynb +++ b/tutorials/01_gaussian_amortized.ipynb @@ -183,7 +183,7 @@ "# plot posterior samples\n", "_ = analysis.pairplot(\n", " posterior_samples_1, limits=[[-2, 2], [-2, 2], [-2, 2]], figsize=(5, 5),\n", - " labels=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"], \n", + " labels=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"],\n", " points=theta_1 # add ground truth thetas\n", ")" ] @@ -238,7 +238,7 @@ "# plot posterior samples\n", "_ = analysis.pairplot(\n", " posterior_samples_2, limits=[[-2, 2], [-2, 2], [-2, 2]], figsize=(5, 5),\n", - " labels=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"], \n", + " labels=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"],\n", " points=theta_2 # add ground truth thetas\n", ")" ] diff --git a/tutorials/17_importance_sampled_posteriors.ipynb b/tutorials/17_importance_sampled_posteriors.ipynb index f23f92d8d..7c17ebf5b 100644 --- a/tutorials/17_importance_sampled_posteriors.ipynb +++ b/tutorials/17_importance_sampled_posteriors.ipynb @@ -162,7 +162,7 @@ "class Simulator:\n", " def __init__(self):\n", " pass\n", - " \n", + "\n", " def log_likelihood(self, theta, x):\n", " return MultivariateNormal(theta, eye(2)).log_prob(x)\n", "\n", @@ -312,7 +312,7 @@ "source": [ "# get weighted samples\n", "theta_inferred_is = theta_inferred[torch.where(w > torch.rand(len(w)) * torch.max(w))]\n", - "# *Note*: we here perform rejection sampling, as the plotting function \n", + "# *Note*: we here perform rejection sampling, as the plotting function\n", "# used below does not support weighted samples. 
In general, with rejection\n", "# sampling the number of samples will be smaller than the effective sample\n", "# size unless we allow for duplicate samples.\n", @@ -323,8 +323,8 @@ "\n", "# plot\n", "fig, ax = marginal_plot(\n", - " [theta_inferred, theta_inferred_is, gt_samples], \n", - " limits=[[-5, 5], [-5, 5]], \n", + " [theta_inferred, theta_inferred_is, gt_samples],\n", + " limits=[[-5, 5], [-5, 5]],\n", " figsize=(5, 1.5),\n", " diag=\"kde\", # smooth histogram\n", ")\n", @@ -22243,8 +22243,8 @@ ], "source": [ "fig, ax = marginal_plot(\n", - " [theta_inferred_sir_2, theta_inferred_sir_32, gt_samples], \n", - " limits=[[-5, 5], [-5, 5]], \n", + " [theta_inferred_sir_2, theta_inferred_sir_32, gt_samples],\n", + " limits=[[-5, 5], [-5, 5]],\n", " figsize=(5, 1.5),\n", " diag=\"kde\", # smooth histogram\n", ")\n", @@ -22280,8 +22280,8 @@ ], "source": [ "fig, ax = marginal_plot(\n", - " [gt_samples, theta_inferred], \n", - " limits=[[-5, 5], [-5, 5]], \n", + " [gt_samples, theta_inferred],\n", + " limits=[[-5, 5], [-5, 5]],\n", " weights=[None, w],\n", " figsize=(5, 1.5),\n", " diag=\"kde\", # smooth histogram\n", @@ -22400,9 +22400,9 @@ "\n", "for i in range(len(observations)):\n", " fig, ax = marginal_plot(\n", - " [non_corrected_samples_for_all_observations[i], corrected_samples_for_all_observations[i], true_samples[i]], \n", - " limits=[[-5, 5], [-5, 5]], \n", - " points=theta_gt[i], \n", + " [non_corrected_samples_for_all_observations[i], corrected_samples_for_all_observations[i], true_samples[i]],\n", + " limits=[[-5, 5], [-5, 5]],\n", + " points=theta_gt[i],\n", " figsize=(5, 1.5),\n", " diag=\"kde\", # smooth histogram\n", " )\n", @@ -23967,9 +23967,9 @@ "\n", "for i in range(len(observations)):\n", " fig, ax = marginal_plot(\n", - " [non_corrected_samples_for_all_observations[i], corrected_samples_for_all_observations[i], true_samples[i]], \n", - " limits=[[-5, 5], [-5, 5]], \n", - " points=theta_gt[i], \n", + " [non_corrected_samples_for_all_observations[i], corrected_samples_for_all_observations[i], true_samples[i]],\n", + " limits=[[-5, 5], [-5, 5]],\n", + " points=theta_gt[i],\n", " figsize=(5, 1.5),\n", " diag=\"kde\", # smooth histogram\n", " )\n", From 4c01901ac5ed395c255bbcf7ae6e48a7411d116b Mon Sep 17 00:00:00 2001 From: Nastya Krouglova Date: Wed, 3 Apr 2024 20:23:15 +0200 Subject: [PATCH 05/14] merge --- sbi/neural_nets/flow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index 709690284..13b1629cd 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -26,7 +26,7 @@ from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device -def get_numel(batch_x: Tensor, batch_y: Tensor, embedding_net) -> Tuple[Tensor, Tensor]: +def get_numel(batch_x: Tensor, batch_y: Tensor, embedding_net) -> Tuple[int, int]: """ Get the number of elements in the input and output space. 
@@ -1067,7 +1067,7 @@ def build_zuko_flow( if z_score_x_bool: transform = ( transform, - standardizing_transform_zuko(batch_x, structured_x, backend="zuko"), + standardizing_transform_zuko(batch_x, structured_x), ) z_score_y_bool, structured_y = z_score_parser(z_score_y) From 34c8dff35256504414942303ad6871d94f91d44c Mon Sep 17 00:00:00 2001 From: Nastya Krouglova Date: Wed, 3 Apr 2024 20:29:19 +0200 Subject: [PATCH 06/14] merge --- sbi/neural_nets/flow.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index 13b1629cd..88b93ba96 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -10,9 +10,7 @@ from pyknos.nflows import distributions as distributions_ from pyknos.nflows import flows, transforms from pyknos.nflows.nn import nets -from pyknos.nflows.transforms.splines import ( - rational_quadratic, -) +from pyknos.nflows.transforms.splines import rational_quadratic_spline from torch import Tensor, nn, relu, tanh, tensor, uint8 from sbi.neural_nets.density_estimators import NFlowsFlow, ZukoFlow @@ -214,9 +212,9 @@ def build_maf_rqs( tail_bound: float = 3.0, dropout_probability: float = 0.0, use_batch_norm: bool = False, - min_bin_width: float = rational_quadratic.DEFAULT_MIN_BIN_WIDTH, - min_bin_height: float = rational_quadratic.DEFAULT_MIN_BIN_HEIGHT, - min_derivative: float = rational_quadratic.DEFAULT_MIN_DERIVATIVE, + min_bin_width: float = rational_quadratic_spline.DEFAULT_MIN_BIN_WIDTH, + min_bin_height: float = rational_quadratic_spline.DEFAULT_MIN_BIN_HEIGHT, + min_derivative: float = rational_quadratic_spline.DEFAULT_MIN_DERIVATIVE, **kwargs, ) -> NFlowsFlow: """Builds MAF p(x|y), where the diffeomorphisms are rational-quadratic From f7f8cd16502ca7f9acf215161cbb3b40f4a54212 Mon Sep 17 00:00:00 2001 From: Nastya Krouglova Date: Wed, 3 Apr 2024 20:36:27 +0200 Subject: [PATCH 07/14] merge --- sbi/neural_nets/flow.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index 88b93ba96..e267124f3 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -5,12 +5,18 @@ from typing import List, Optional, Sequence, Tuple, Union from warnings import warn +from functools import partial +from typing import List, Optional, Sequence, Tuple, Union +from warnings import warn + import torch import zuko from pyknos.nflows import distributions as distributions_ from pyknos.nflows import flows, transforms from pyknos.nflows.nn import nets -from pyknos.nflows.transforms.splines import rational_quadratic_spline +from pyknos.nflows.transforms.splines import ( + rational_quadratic, # pyright: ignore[reportAttributeAccessIssue] +) from torch import Tensor, nn, relu, tanh, tensor, uint8 from sbi.neural_nets.density_estimators import NFlowsFlow, ZukoFlow From 5405203d60f8d8b478a91cb10762143bcb1dc17c Mon Sep 17 00:00:00 2001 From: Nastya Krouglova Date: Wed, 3 Apr 2024 20:40:55 +0200 Subject: [PATCH 08/14] MERGE --- sbi/neural_nets/flow.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index e267124f3..1bbb5e477 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -5,17 +5,13 @@ from typing import List, Optional, Sequence, Tuple, Union from warnings import warn -from functools import partial -from typing import List, Optional, Sequence, Tuple, Union -from warnings import warn - import torch import zuko from pyknos.nflows import 
distributions as distributions_ from pyknos.nflows import flows, transforms from pyknos.nflows.nn import nets from pyknos.nflows.transforms.splines import ( - rational_quadratic, # pyright: ignore[reportAttributeAccessIssue] + rational_quadratic, # pyright: ignore[reportAttributeAccessIssue] ) from torch import Tensor, nn, relu, tanh, tensor, uint8 @@ -218,9 +214,9 @@ def build_maf_rqs( tail_bound: float = 3.0, dropout_probability: float = 0.0, use_batch_norm: bool = False, - min_bin_width: float = rational_quadratic_spline.DEFAULT_MIN_BIN_WIDTH, - min_bin_height: float = rational_quadratic_spline.DEFAULT_MIN_BIN_HEIGHT, - min_derivative: float = rational_quadratic_spline.DEFAULT_MIN_DERIVATIVE, + min_bin_width: float = rational_quadratic.DEFAULT_MIN_BIN_WIDTH, + min_bin_height: float = rational_quadratic.DEFAULT_MIN_BIN_HEIGHT, + min_derivative: float = rational_quadratic.DEFAULT_MIN_DERIVATIVE, **kwargs, ) -> NFlowsFlow: """Builds MAF p(x|y), where the diffeomorphisms are rational-quadratic From 41424ac5e37fef577f480931020a554b0833a877 Mon Sep 17 00:00:00 2001 From: Nastya Krouglova Date: Wed, 3 Apr 2024 20:51:45 +0200 Subject: [PATCH 09/14] remove cnf --- sbi/neural_nets/factory.py | 23 ++++++++++++++++++----- tests/density_estimator_test.py | 2 -- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/sbi/neural_nets/factory.py b/sbi/neural_nets/factory.py index d1eae3f84..f970e7414 100644 --- a/sbi/neural_nets/factory.py +++ b/sbi/neural_nets/factory.py @@ -17,7 +17,6 @@ build_maf_rqs, build_nsf, build_zuko_bpf, - build_zuko_cnf, build_zuko_gf, build_zuko_maf, build_zuko_naf, @@ -30,6 +29,24 @@ from sbi.neural_nets.mdn import build_mdn from sbi.neural_nets.mnle import build_mnle +model_builders = { + "mdn": build_mdn, + "made": build_made, + "maf": build_maf, + "maf_rqs": build_maf_rqs, + "nsf": build_nsf, + "mnle": build_mnle, + "zuko_nice": build_zuko_nice, + "zuko_maf": build_zuko_maf, + "zuko_nsf": build_zuko_nsf, + "zuko_ncsf": build_zuko_ncsf, + "zuko_sospf": build_zuko_sospf, + "zuko_naf": build_zuko_naf, + "zuko_unaf": build_zuko_unaf, + "zuko_gf": build_zuko_gf, + "zuko_bpf": build_zuko_bpf, +} + def classifier_nn( model: str, @@ -197,8 +214,6 @@ def build_fn(batch_theta, batch_x): return build_zuko_naf(batch_x=batch_x, batch_y=batch_theta, **kwargs) elif model == "zuko_unaf": return build_zuko_unaf(batch_x=batch_x, batch_y=batch_theta, **kwargs) - elif model == "zuko_cnf": - return build_zuko_cnf(batch_x=batch_x, batch_y=batch_theta, **kwargs) elif model == "zuko_gf": return build_zuko_gf(batch_x=batch_x, batch_y=batch_theta, **kwargs) elif model == "zuko_bpf": @@ -321,8 +336,6 @@ def build_fn(batch_theta, batch_x): return build_zuko_naf(batch_x=batch_theta, batch_y=batch_x, **kwargs) elif model == "zuko_unaf": return build_zuko_unaf(batch_x=batch_theta, batch_y=batch_x, **kwargs) - elif model == "zuko_cnf": - return build_zuko_cnf(batch_x=batch_theta, batch_y=batch_x, **kwargs) elif model == "zuko_gf": return build_zuko_gf(batch_x=batch_theta, batch_y=batch_x, **kwargs) elif model == "zuko_bpf": diff --git a/tests/density_estimator_test.py b/tests/density_estimator_test.py index cc7c97f7a..8bd8bbb27 100644 --- a/tests/density_estimator_test.py +++ b/tests/density_estimator_test.py @@ -16,7 +16,6 @@ build_maf_rqs, build_nsf, build_zuko_bpf, - build_zuko_cnf, build_zuko_gf, build_zuko_maf, build_zuko_naf, @@ -74,7 +73,6 @@ def get_batch_context(nsamples: int, condition_shape: tuple[int, ...]) -> torch. 
build_zuko_sospf, build_zuko_naf, build_zuko_unaf, - build_zuko_cnf, build_zuko_gf, build_zuko_bpf, ), From 327592219a7c77d099e60b4d9a7d92c43a805501 Mon Sep 17 00:00:00 2001 From: Nastya Krouglova Date: Wed, 3 Apr 2024 21:09:35 +0200 Subject: [PATCH 10/14] implement changes Jan --- sbi/neural_nets/factory.py | 76 ++++------------------------ tests/neural_nets_factory.py | 98 ++++++++---------------------------- 2 files changed, 30 insertions(+), 144 deletions(-) diff --git a/sbi/neural_nets/factory.py b/sbi/neural_nets/factory.py index f970e7414..40e9006ff 100644 --- a/sbi/neural_nets/factory.py +++ b/sbi/neural_nets/factory.py @@ -188,38 +188,10 @@ def likelihood_nn( ) def build_fn(batch_theta, batch_x): - if model == "mdn": - return build_mdn(batch_x=batch_x, batch_y=batch_theta, **kwargs) - elif model == "made": - return build_made(batch_x=batch_x, batch_y=batch_theta, **kwargs) - elif model == "maf": - return build_maf(batch_x=batch_x, batch_y=batch_theta, **kwargs) - elif model == "maf_rqs": - return build_maf_rqs(batch_x=batch_x, batch_y=batch_theta, **kwargs) - elif model == "nsf": - return build_nsf(batch_x=batch_x, batch_y=batch_theta, **kwargs) - elif model == "mnle": - return build_mnle(batch_x=batch_x, batch_y=batch_theta, **kwargs) - elif model == "zuko_nice": - return build_zuko_nice(batch_x=batch_x, batch_y=batch_theta, **kwargs) - elif model == "zuko_maf": - return build_zuko_maf(batch_x=batch_x, batch_y=batch_theta, **kwargs) - elif model == "zuko_nsf": - return build_zuko_nsf(batch_x=batch_x, batch_y=batch_theta, **kwargs) - elif model == "zuko_ncsf": - return build_zuko_ncsf(batch_x=batch_x, batch_y=batch_theta, **kwargs) - elif model == "zuko_sospf": - return build_zuko_sospf(batch_x=batch_x, batch_y=batch_theta, **kwargs) - elif model == "zuko_naf": - return build_zuko_naf(batch_x=batch_x, batch_y=batch_theta, **kwargs) - elif model == "zuko_unaf": - return build_zuko_unaf(batch_x=batch_x, batch_y=batch_theta, **kwargs) - elif model == "zuko_gf": - return build_zuko_gf(batch_x=batch_x, batch_y=batch_theta, **kwargs) - elif model == "zuko_bpf": - return build_zuko_bpf(batch_x=batch_x, batch_y=batch_theta, **kwargs) - else: - raise NotImplementedError + if model not in model_builders: + raise NotImplementedError("Model {model} in not implemented") + + return model_builders[model](batch_x=batch_x, batch_y=batch_theta, **kwargs) return build_fn @@ -307,42 +279,14 @@ def build_fn_snpe_a(batch_theta, batch_x, num_components): ) def build_fn(batch_theta, batch_x): - if model == "mdn": - # The naming might be a bit confusing. - # batch_x are the latent variables, batch_y the conditioned variables. - # batch_theta are the parameters and batch_x the observable variables. 
- return build_mdn(batch_x=batch_theta, batch_y=batch_x, **kwargs) - elif model == "made": - return build_made(batch_x=batch_theta, batch_y=batch_x, **kwargs) - elif model == "maf": - return build_maf(batch_x=batch_theta, batch_y=batch_x, **kwargs) - elif model == "maf_rqs": - return build_maf_rqs(batch_x=batch_theta, batch_y=batch_x, **kwargs) - elif model == "nsf": - return build_nsf(batch_x=batch_theta, batch_y=batch_x, **kwargs) - elif model == "mnle": - return build_mnle(batch_x=batch_theta, batch_y=batch_x, **kwargs) - elif model == "zuko_nice": - return build_zuko_nice(batch_x=batch_theta, batch_y=batch_x, **kwargs) - elif model == "zuko_maf": - return build_zuko_maf(batch_x=batch_theta, batch_y=batch_x, **kwargs) - elif model == "zuko_nsf": - return build_zuko_nsf(batch_x=batch_theta, batch_y=batch_x, **kwargs) - elif model == "zuko_ncsf": - return build_zuko_ncsf(batch_x=batch_theta, batch_y=batch_x, **kwargs) - elif model == "zuko_sospf": - return build_zuko_sospf(batch_x=batch_theta, batch_y=batch_x, **kwargs) - elif model == "zuko_naf": - return build_zuko_naf(batch_x=batch_theta, batch_y=batch_x, **kwargs) - elif model == "zuko_unaf": - return build_zuko_unaf(batch_x=batch_theta, batch_y=batch_x, **kwargs) - elif model == "zuko_gf": - return build_zuko_gf(batch_x=batch_theta, batch_y=batch_x, **kwargs) - elif model == "zuko_bpf": - return build_zuko_bpf(batch_x=batch_theta, batch_y=batch_x, **kwargs) - else: + if model not in model_builders: raise NotImplementedError + # The naming might be a bit confusing. + # batch_x are the latent variables, batch_y the conditioned variables. + # batch_theta are the parameters and batch_x the observable variables. + return model_builders[model](batch_x=batch_theta, batch_y=batch_x, **kwargs) + if model == "mdn_snpe_a": if num_components != 10: raise ValueError( diff --git a/tests/neural_nets_factory.py b/tests/neural_nets_factory.py index 6e6d8751f..af8c4d4a1 100644 --- a/tests/neural_nets_factory.py +++ b/tests/neural_nets_factory.py @@ -2,6 +2,24 @@ from sbi.utils.get_nn_models import classifier_nn, likelihood_nn, posterior_nn +models_to_test = [ + "mdn", + "made", + "maf", + "maf_rqs", + "nsf", + "mnle", + "zuko_bpf", + "zuko_gf", + "zuko_maf", + "zuko_naf", + "zuko_ncsf", + "zuko_nice", + "zuko_nsf", + "zuko_sospf", + "zuko_unaf", +] + @pytest.mark.parametrize( "model", ["linear", "mlp", "resnet"], ids=["linear", "mlp", "resnet"] @@ -12,90 +30,14 @@ def test_deprecated_import_classifier_nn(model: str): assert callable(build_fcn) -@pytest.mark.parametrize( - "model", - [ - "mdn", - "made", - "maf", - "maf_rqs", - "nsf", - "mnle", - "zuko_bpf", - "zuko_cnf", - "zuko_gf", - "zuko_maf", - "zuko_naf", - "zuko_ncsf", - "zuko_nice", - "zuko_nsf", - "zuko_sospf", - "zuko_unaf", - ], - ids=[ - "mdn", - "made", - "maf", - "maf_rqs", - "nsf", - "mnle", - "zuko_bpf", - "zuko_cnf", - "zuko_gf", - "zuko_maf", - "zuko_naf", - "zuko_ncsf", - "zuko_nice", - "zuko_nsf", - "zuko_sospf", - "zuko_unaf", - ], -) +@pytest.mark.parametrize("model", models_to_test, ids=models_to_test) def test_deprecated_import_likelihood_nn(model: str): with pytest.warns(DeprecationWarning): build_fcn = likelihood_nn(model) assert callable(build_fcn) -@pytest.mark.parametrize( - "model", - [ - "mdn", - "made", - "maf", - "maf_rqs", - "nsf", - "mnle", - "zuko_bpf", - "zuko_cnf", - "zuko_gf", - "zuko_maf", - "zuko_naf", - "zuko_ncsf", - "zuko_nice", - "zuko_nsf", - "zuko_sospf", - "zuko_unaf", - ], - ids=[ - "mdn", - "made", - "maf", - "maf_rqs", - "nsf", - "mnle", - "zuko_bpf", 
- "zuko_cnf", - "zuko_gf", - "zuko_maf", - "zuko_naf", - "zuko_ncsf", - "zuko_nice", - "zuko_nsf", - "zuko_sospf", - "zuko_unaf", - ], -) +@pytest.mark.parametrize("model", models_to_test, ids=models_to_test) def test_deprecated_import_posterior_nn(model: str): with pytest.warns(DeprecationWarning): build_fcn = posterior_nn(model) From 98dfdc2efcc27b9be05ca884fd9cef95b915b16e Mon Sep 17 00:00:00 2001 From: Nastya Krouglova <41705732+anastasiakrouglova@users.noreply.github.com> Date: Thu, 4 Apr 2024 15:54:03 +0200 Subject: [PATCH 11/14] Update sbi/neural_nets/factory.py Co-authored-by: Jan --- sbi/neural_nets/factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbi/neural_nets/factory.py b/sbi/neural_nets/factory.py index 40e9006ff..87168ff07 100644 --- a/sbi/neural_nets/factory.py +++ b/sbi/neural_nets/factory.py @@ -189,7 +189,7 @@ def likelihood_nn( def build_fn(batch_theta, batch_x): if model not in model_builders: - raise NotImplementedError("Model {model} in not implemented") + raise NotImplementedError(f"Model {model} in not implemented") return model_builders[model](batch_x=batch_x, batch_y=batch_theta, **kwargs) From 8375c6915a89da7c7753587aa7a544f061be5d8e Mon Sep 17 00:00:00 2001 From: Nastya Krouglova Date: Thu, 4 Apr 2024 16:16:30 +0200 Subject: [PATCH 12/14] resolve issues Jan --- sbi/neural_nets/factory.py | 2 +- tutorials/05_embedding_net.ipynb | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/sbi/neural_nets/factory.py b/sbi/neural_nets/factory.py index 87168ff07..5bddcc0f5 100644 --- a/sbi/neural_nets/factory.py +++ b/sbi/neural_nets/factory.py @@ -280,7 +280,7 @@ def build_fn_snpe_a(batch_theta, batch_x, num_components): def build_fn(batch_theta, batch_x): if model not in model_builders: - raise NotImplementedError + raise NotImplementedError(f"Model {model} in not implemented") # The naming might be a bit confusing. # batch_x are the latent variables, batch_y the conditioned variables. diff --git a/tutorials/05_embedding_net.ipynb b/tutorials/05_embedding_net.ipynb index 0bf13d5d4..9a3e7ff64 100644 --- a/tutorials/05_embedding_net.ipynb +++ b/tutorials/05_embedding_net.ipynb @@ -460,10 +460,6 @@ " x = F.relu(self.fc(x))\n", " return x\n", "\n", -<<<<<<< HEAD - "# instantiate the custom embedding_net\n", - "embedding_net_custom = SummaryNet()" -======= "\n", "embedding_net = SummaryNet()" ] @@ -624,7 +620,6 @@ " points_offdiag={\"markersize\": 6},\n", " figsize=(5, 5),\n", ")" ->>>>>>> bf65102c (Zuko density estimators (#1088)) ] } ], From f88edc541a51c057bb129cd1f5b5dd096e5ab1dc Mon Sep 17 00:00:00 2001 From: Jan Boelts Date: Fri, 5 Apr 2024 11:31:03 +0200 Subject: [PATCH 13/14] undo changes to tutorials folder. --- tutorials/00_getting_started.ipynb | 5 +- tutorials/05_embedding_net.ipynb | 162 +----------------- tutorials/07_conditional_distributions.ipynb | 2 +- .../17_importance_sampled_posteriors.ipynb | 28 +-- 4 files changed, 18 insertions(+), 179 deletions(-) diff --git a/tutorials/00_getting_started.ipynb b/tutorials/00_getting_started.ipynb index 507dc4207..d9b48e85f 100644 --- a/tutorials/00_getting_started.ipynb +++ b/tutorials/00_getting_started.ipynb @@ -98,10 +98,7 @@ "source": [ "# Other methods are \"SNLE\" or \"SNRE\".\n", "posterior = infer(simulator, prior, method=\"SNPE\", num_simulations=1000)\n", - "# Using `init_kwargs`, `train_kwargs` and `build_posterior_kwargs`,\n", - "# you can also pass additional keyword arguments to `__init__`, `train` and\n", - "# `build_posterior` of the inference method. 
But we recommend to use the\n", - "# flexible interface which is introduced in a later tutorial." + "# Using `init_kwargs`, `train_kwargs` and `build_posterior_kwargs`, you can also pass additional keyword arguments to `__init__`, `train` and `build_posterior` of the inference method. But we recommend to use the flexible interface which is introduced in a later tutorial." ] }, { diff --git a/tutorials/05_embedding_net.ipynb b/tutorials/05_embedding_net.ipynb index 9a3e7ff64..78bae6598 100644 --- a/tutorials/05_embedding_net.ipynb +++ b/tutorials/05_embedding_net.ipynb @@ -460,166 +460,8 @@ " x = F.relu(self.fc(x))\n", " return x\n", "\n", - "\n", - "embedding_net = SummaryNet()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## The inference procedure\n", - "\n", - "With the `embedding_net` defined and instantiated, we can follow the usual workflow of an inference procedure in `sbi`. The `embedding_net` object appears as an input argument when instantiating the neural density estimator with `sbi.neural_nets.posterior_nn`.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "# set prior distribution for the parameters\n", - "prior = utils.BoxUniform(\n", - " low=torch.tensor([0.0, 0.0]), high=torch.tensor([1.0, 2 * torch.pi])\n", - ")\n", - "\n", - "# make a SBI-wrapper on the simulator object for compatibility\n", - "prior, num_parameters, prior_returns_numpy = process_prior(prior)\n", - "simulator_wrapper = process_simulator(simulator_model, prior, prior_returns_numpy)\n", - "check_sbi_inputs(simulator_wrapper, prior)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "from sbi.neural_nets import posterior_nn\n", - "\n", - "# instantiate the neural density estimator\n", - "neural_posterior = posterior_nn(\n", - " model=\"maf\", embedding_net=embedding_net, hidden_features=10, num_transforms=2\n", - ")\n", - "\n", - "# setup the inference procedure with the SNPE-C procedure\n", - "inferer = SNPE(prior=prior, density_estimator=neural_posterior)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "a073de8ca9094ec9a4909f74ba837bfb", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Running 10000 simulations.: 0%| | 0/10000 [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# create the figure\n", - "fig, ax = analysis.pairplot(\n", - " samples,\n", - " points=true_parameter,\n", - " labels=[\"r\", r\"$\\theta$\"],\n", - " limits=[[0, 1], [0, 2 * torch.pi]],\n", - " points_colors=\"r\",\n", - " points_offdiag={\"markersize\": 6},\n", - " figsize=(5, 5),\n", - ")" + "# instantiate the custom embedding_net\n", + "embedding_net_custom = SummaryNet()" ] } ], diff --git a/tutorials/07_conditional_distributions.ipynb b/tutorials/07_conditional_distributions.ipynb index 9e718ca8b..a98549f88 100644 --- a/tutorials/07_conditional_distributions.ipynb +++ b/tutorials/07_conditional_distributions.ipynb @@ -26124,7 +26124,7 @@ "name": "stdout", "output_type": "stream", "text": [ - " Neural network successfully converged after 65 epochs." + " Neural network successfully converged after 118 epochs." 
] } ], diff --git a/tutorials/17_importance_sampled_posteriors.ipynb b/tutorials/17_importance_sampled_posteriors.ipynb index 7c17ebf5b..f23f92d8d 100644 --- a/tutorials/17_importance_sampled_posteriors.ipynb +++ b/tutorials/17_importance_sampled_posteriors.ipynb @@ -162,7 +162,7 @@ "class Simulator:\n", " def __init__(self):\n", " pass\n", - "\n", + " \n", " def log_likelihood(self, theta, x):\n", " return MultivariateNormal(theta, eye(2)).log_prob(x)\n", "\n", @@ -312,7 +312,7 @@ "source": [ "# get weighted samples\n", "theta_inferred_is = theta_inferred[torch.where(w > torch.rand(len(w)) * torch.max(w))]\n", - "# *Note*: we here perform rejection sampling, as the plotting function\n", + "# *Note*: we here perform rejection sampling, as the plotting function \n", "# used below does not support weighted samples. In general, with rejection\n", "# sampling the number of samples will be smaller than the effective sample\n", "# size unless we allow for duplicate samples.\n", @@ -323,8 +323,8 @@ "\n", "# plot\n", "fig, ax = marginal_plot(\n", - " [theta_inferred, theta_inferred_is, gt_samples],\n", - " limits=[[-5, 5], [-5, 5]],\n", + " [theta_inferred, theta_inferred_is, gt_samples], \n", + " limits=[[-5, 5], [-5, 5]], \n", " figsize=(5, 1.5),\n", " diag=\"kde\", # smooth histogram\n", ")\n", @@ -22243,8 +22243,8 @@ ], "source": [ "fig, ax = marginal_plot(\n", - " [theta_inferred_sir_2, theta_inferred_sir_32, gt_samples],\n", - " limits=[[-5, 5], [-5, 5]],\n", + " [theta_inferred_sir_2, theta_inferred_sir_32, gt_samples], \n", + " limits=[[-5, 5], [-5, 5]], \n", " figsize=(5, 1.5),\n", " diag=\"kde\", # smooth histogram\n", ")\n", @@ -22280,8 +22280,8 @@ ], "source": [ "fig, ax = marginal_plot(\n", - " [gt_samples, theta_inferred],\n", - " limits=[[-5, 5], [-5, 5]],\n", + " [gt_samples, theta_inferred], \n", + " limits=[[-5, 5], [-5, 5]], \n", " weights=[None, w],\n", " figsize=(5, 1.5),\n", " diag=\"kde\", # smooth histogram\n", @@ -22400,9 +22400,9 @@ "\n", "for i in range(len(observations)):\n", " fig, ax = marginal_plot(\n", - " [non_corrected_samples_for_all_observations[i], corrected_samples_for_all_observations[i], true_samples[i]],\n", - " limits=[[-5, 5], [-5, 5]],\n", - " points=theta_gt[i],\n", + " [non_corrected_samples_for_all_observations[i], corrected_samples_for_all_observations[i], true_samples[i]], \n", + " limits=[[-5, 5], [-5, 5]], \n", + " points=theta_gt[i], \n", " figsize=(5, 1.5),\n", " diag=\"kde\", # smooth histogram\n", " )\n", @@ -23967,9 +23967,9 @@ "\n", "for i in range(len(observations)):\n", " fig, ax = marginal_plot(\n", - " [non_corrected_samples_for_all_observations[i], corrected_samples_for_all_observations[i], true_samples[i]],\n", - " limits=[[-5, 5], [-5, 5]],\n", - " points=theta_gt[i],\n", + " [non_corrected_samples_for_all_observations[i], corrected_samples_for_all_observations[i], true_samples[i]], \n", + " limits=[[-5, 5], [-5, 5]], \n", + " points=theta_gt[i], \n", " figsize=(5, 1.5),\n", " diag=\"kde\", # smooth histogram\n", " )\n", From 583671dfd7c77292553374be8d839348068825ab Mon Sep 17 00:00:00 2001 From: Jan Boelts Date: Fri, 5 Apr 2024 11:32:31 +0200 Subject: [PATCH 14/14] sort dependencies. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d03e50101..027a83f08 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,8 +42,8 @@ dependencies = [ "tensorboard", "torch>=1.8.0", "tqdm", - "zuko>=1.1.0", "pymc>=5.0.0", + "zuko>=1.1.0", ] [project.optional-dependencies]