From c0579a3cc4e8b2415be12d1e7df01a0703650a40 Mon Sep 17 00:00:00 2001
From: janfb
Date: Wed, 7 Oct 2020 17:08:44 +0200
Subject: [PATCH] fix handling of multi-dimensional x.

* add asserts to SNLE and SNRE stating that multi-dimensional x is not supported.
* add a method for SNPE that checks for an appropriate embedding net and gives a debug hint.
* adapt tests.
---
 sbi/inference/snle/snle_base.py     |  7 +++---
 sbi/inference/snpe/snpe_base.py     | 14 ++++++-----
 sbi/inference/snre/snre_base.py     | 14 +++++++----
 sbi/user_input/user_input_checks.py | 25 ++++++++++++++++++++
 sbi/utils/__init__.py               |  5 +++-
 sbi/utils/sbiutils.py               |  9 --------
 tests/multidimensional_x_test.py    | 36 ++++++++++++++++++++---------
 7 files changed, 75 insertions(+), 35 deletions(-)

diff --git a/sbi/inference/snle/snle_base.py b/sbi/inference/snle/snle_base.py
index 81eb7937b..b9ff6f521 100644
--- a/sbi/inference/snle/snle_base.py
+++ b/sbi/inference/snle/snle_base.py
@@ -6,9 +6,8 @@
 from copy import deepcopy
 from typing import Any, Callable, Dict, Optional, Union
 
-import numpy as np
 import torch
-from torch import Tensor, nn, optim
+from torch import optim
 from torch.nn.utils import clip_grad_norm_
 from torch.utils import data
 from torch.utils.data.sampler import SubsetRandomSampler
@@ -17,7 +16,6 @@
 from sbi import utils as utils
 from sbi.inference import NeuralInference
 from sbi.inference.posteriors.likelihood_based_posterior import LikelihoodBasedPosterior
-from sbi.types import ScalarFloat
 from sbi.utils import check_estimator_arg, x_shape_from_simulation
 
 
@@ -149,6 +147,9 @@ def __call__(
         # can `sample()` and `log_prob()`. The network is accessible via `.net`.
         if self._posterior is None or retrain_from_scratch_each_round:
             x_shape = x_shape_from_simulation(x)
+            assert (
+                len(x_shape) < 3
+            ), "SNLE cannot handle multi-dimensional simulator output."
             self._posterior = LikelihoodBasedPosterior(
                 method_family="snle",
                 neural_net=self._build_neural_net(theta, x),
diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py
index b6fb2bdcf..eca5a4f6b 100644
--- a/sbi/inference/snpe/snpe_base.py
+++ b/sbi/inference/snpe/snpe_base.py
@@ -4,12 +4,10 @@
 
 from abc import ABC, abstractmethod
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Union, cast
-from warnings import warn
+from typing import Any, Callable, Dict, Optional, Union, cast
 
-import numpy as np
 import torch
-from torch import Tensor, nn, ones, optim
+from torch import Tensor, ones, optim
 from torch.nn.utils import clip_grad_norm_
 from torch.utils import data
 from torch.utils.data.sampler import SubsetRandomSampler
@@ -18,8 +16,11 @@
 from sbi import utils as utils
 from sbi.inference import NeuralInference
 from sbi.inference.posteriors.direct_posterior import DirectPosterior
-from sbi.types import ScalarFloat
-from sbi.utils import check_estimator_arg, x_shape_from_simulation
+from sbi.utils import (
+    check_estimator_arg,
+    x_shape_from_simulation,
+    test_posterior_net_for_multi_d_x,
+)
 
 
 class PosteriorEstimator(NeuralInference, ABC):
@@ -197,6 +198,7 @@ def __call__(
             mcmc_parameters=self._mcmc_parameters,
             rejection_sampling_parameters=self._rejection_sampling_parameters,
         )
+        test_posterior_net_for_multi_d_x(self._posterior.net, theta, x)
 
         # Copy MCMC init parameters for latest sample init
         if hasattr(proposal, "_mcmc_init_params"):
diff --git a/sbi/inference/snre/snre_base.py b/sbi/inference/snre/snre_base.py
index 3241fb90b..97fab609a 100644
--- a/sbi/inference/snre/snre_base.py
+++ b/sbi/inference/snre/snre_base.py
@@ -2,9 +2,8 @@
 from copy import deepcopy
 from typing import Any, Callable, Dict, Optional, Union
 
-import numpy as np
 import torch
-from torch import Tensor, eye, nn, ones, optim
+from torch import Tensor, eye, ones, optim
 from torch.nn.utils import clip_grad_norm_
 from torch.utils import data
 from torch.utils.data.sampler import SubsetRandomSampler
@@ -13,9 +12,11 @@
 from sbi import utils as utils
 from sbi.inference.base import NeuralInference
 from sbi.inference.posteriors.ratio_based_posterior import RatioBasedPosterior
-from sbi.types import ScalarFloat
-from sbi.utils import check_estimator_arg, clamp_and_warn, x_shape_from_simulation
-from sbi.utils.torchutils import ensure_theta_batched, ensure_x_batched
+from sbi.utils import (
+    check_estimator_arg,
+    clamp_and_warn,
+    x_shape_from_simulation,
+)
 
 
 class RatioEstimator(NeuralInference, ABC):
@@ -155,6 +156,9 @@ def __call__(
         # can `sample()` and `log_prob()`. The network is accessible via `.net`.
         if self._posterior is None or retrain_from_scratch_each_round:
             x_shape = x_shape_from_simulation(x)
+            assert (
+                len(x_shape) < 3
+            ), "For now, SNRE cannot handle multi-dimensional simulator output, see issue #360."
             self._posterior = RatioBasedPosterior(
                 method_family=self.__class__.__name__.lower(),
                 neural_net=self._build_neural_net(theta, x),
diff --git a/sbi/user_input/user_input_checks.py b/sbi/user_input/user_input_checks.py
index 76662171d..313ba78e3 100644
--- a/sbi/user_input/user_input_checks.py
+++ b/sbi/user_input/user_input_checks.py
@@ -527,3 +527,28 @@ def check_estimator_arg(estimator: Union[str, Callable]) -> None:
         "The passed density estimator / classifier must be a string or a function "
         f"returning a nn.Module, but is {type(estimator)}"
     )
+
+
+def test_posterior_net_for_multi_d_x(net: nn.Module, theta: Tensor, x: Tensor) -> None:
+    """Test log prob method of the net.
+
+    This is done to make sure the net can handle multidimensional inputs via an
+    embedding net. If not, it usually fails with a RuntimeError. Here we catch the
+    error, append a debug hint and raise it again.
+    """
+
+    try:
+        # torch.nn.functional needs at least two inputs here.
+        net.log_prob(theta[:2], x[:2])
+    except RuntimeError as rte:
+        ndims = x.ndim
+        if ndims > 2:
+            message = f"""Debug hint: The simulated data x has {ndims-1} dimensions.
+            With default settings, sbi cannot deal with multidimensional simulations.
+            Make sure to use an embedding net that reduces the dimensionality, e.g., a
+            CNN in case of images, or change the simulator to return one-dimensional x.
+            """
+        else:
+            message = ""
+
+        raise RuntimeError(rte, message)
diff --git a/sbi/utils/__init__.py b/sbi/utils/__init__.py
index 192d6866e..04ef82bf1 100644
--- a/sbi/utils/__init__.py
+++ b/sbi/utils/__init__.py
@@ -1,5 +1,8 @@
 # flake8: noqa
-from sbi.user_input.user_input_checks import check_estimator_arg
+from sbi.user_input.user_input_checks import (
+    check_estimator_arg,
+    test_posterior_net_for_multi_d_x,
+)
 from sbi.user_input.user_input_checks_utils import MultipleIndependent
 from sbi.utils.get_nn_models import classifier_nn, likelihood_nn, posterior_nn
 from sbi.utils.io import get_data_root, get_log_root, get_project_root
diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py
index e46271770..65069407b 100644
--- a/sbi/utils/sbiutils.py
+++ b/sbi/utils/sbiutils.py
@@ -17,21 +17,12 @@
 from torch import nn as nn
 from torch import ones, zeros
 from tqdm.auto import tqdm
-import warnings
 
 
 def x_shape_from_simulation(batch_x: Tensor) -> torch.Size:
     ndims = batch_x.ndim
     assert ndims >= 2, "Simulated data must be a batch with at least two dimensions."
-    # Warn in case of multi-dimensional x.
-    if ndims > 2:
-        warnings.warn(
-            f"""The simulated data x has {ndims-1} dimensions. With default settings,
-            sbi cannot deal with multidimensional simulations. Make sure to use an
-            embedding net that reduces the dimensionality, e.g., a CNN in case of
-            images, or change the simulator to return one-dimensional x."""
-        )
     return batch_x[0].unsqueeze(0).shape
diff --git a/tests/multidimensional_x_test.py b/tests/multidimensional_x_test.py
index 6f5469763..866122d39 100644
--- a/tests/multidimensional_x_test.py
+++ b/tests/multidimensional_x_test.py
@@ -4,7 +4,7 @@
 import torch
 import pytest
 
 from sbi import utils as utils
-from sbi.inference import SNPE, prepare_for_sbi
+from sbi.inference import SNPE, SNLE, SNRE, prepare_for_sbi
 import torch.nn as nn
 import torch.nn.functional as F
@@ -39,13 +39,25 @@ def forward(self, x):
 
 
 @pytest.mark.parametrize(
-    "embedding",
+    "embedding, method",
     (
-        pytest.param(nn.Identity, marks=pytest.mark.xfail(reason="Invalid embedding.")),
-        pytest.param(CNNEmbedding),
+        pytest.param(
+            nn.Identity, SNPE, marks=pytest.mark.xfail(reason="Invalid embedding.")
+        ),
+        pytest.param(
+            nn.Identity,
+            SNRE,
+            marks=pytest.mark.xfail(reason="SNRE cannot handle multiD x."),
+        ),
+        pytest.param(
+            CNNEmbedding,
+            SNLE,
+            marks=pytest.mark.xfail(reason="SNLE cannot handle multiD x."),
+        ),
+        pytest.param(CNNEmbedding, SNPE),
     ),
 )
-def test_inference_with_2d_x(embedding):
+def test_inference_with_2d_x(embedding, method):
 
     num_dim = 2
     num_samples = 10
@@ -58,12 +70,14 @@ def test_inference_with_2d_x(embedding, method):
     theta_o = torch.ones(1, num_dim)
     x_o = simulator(theta_o)
 
-    infer = SNPE(
-        simulator,
-        prior,
-        show_progress_bars=False,
-        density_estimator=utils.posterior_nn(model="mdn", embedding_net=embedding(),),
-    )
+    if method == SNPE:
+        net_provider = utils.posterior_nn(model="mdn", embedding_net=embedding(),)
+    elif method == SNLE:
+        net_provider = utils.likelihood_nn(model="mdn", embedding_net=embedding())
+    else:
+        net_provider = utils.classifier_nn(model="mlp", embedding_net_x=embedding(),)
+
+    infer = method(simulator, prior, 1, 1, net_provider, show_progress_bars=False,)
 
     posterior = infer(
        num_simulations=num_simulations, training_batch_size=100, max_num_epochs=10
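
Note: the following is an illustrative usage sketch, not part of the patch, of the behaviour this change enables. It mirrors tests/multidimensional_x_test.py: SNPE accepts image-like 2D x once the density estimator is built with an embedding net that maps x to a 1D feature vector, while a missing embedding net now triggers the debug hint from test_posterior_net_for_multi_d_x. The toy simulator, the BoxUniform prior bounds, and the FlattenEmbedding module are assumptions made for the example; the sbi calls (prepare_for_sbi, utils.posterior_nn, SNPE, and the infer(...) arguments) are the ones exercised by the test above.

import torch
import torch.nn as nn

from sbi import utils
from sbi.inference import SNPE, prepare_for_sbi


def toy_simulator(theta: torch.Tensor) -> torch.Tensor:
    # Illustrative simulator (not from the patch): return a batch of 8x8
    # "images" whose pixel mean depends on theta.
    theta = theta.reshape(-1, 2)
    return theta.mean(dim=1).reshape(-1, 1, 1) + 0.1 * torch.randn(theta.shape[0], 8, 8)


class FlattenEmbedding(nn.Module):
    # Minimal embedding net: flatten multi-dimensional x to a 1D feature vector.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.reshape(x.shape[0], -1)


prior = utils.BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
simulator, prior = prepare_for_sbi(toy_simulator, prior)

density_estimator = utils.posterior_nn(model="mdn", embedding_net=FlattenEmbedding())
infer = SNPE(simulator, prior, density_estimator=density_estimator, show_progress_bars=False)

# With nn.Identity instead of FlattenEmbedding, SNPE now raises a RuntimeError
# carrying the debug hint added in test_posterior_net_for_multi_d_x; SNLE and
# SNRE reject multi-dimensional x outright via the new asserts.
posterior = infer(num_simulations=500, training_batch_size=50, max_num_epochs=10)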