Skip to content

Commit

Permalink
fix handling of multi dimensional x.
Browse files Browse the repository at this point in the history
* add assert to SNLE and SNRE that multiD x does not work.
* add method for SNPE to check for appropriate embedding net and give debug hint.
* adapt tests.
  • Loading branch information
janfb committed Oct 8, 2020
1 parent b7d520a commit c0579a3
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 35 deletions.
7 changes: 4 additions & 3 deletions sbi/inference/snle/snle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from copy import deepcopy
from typing import Any, Callable, Dict, Optional, Union

import numpy as np
import torch
from torch import Tensor, nn, optim
from torch import optim
from torch.nn.utils import clip_grad_norm_
from torch.utils import data
from torch.utils.data.sampler import SubsetRandomSampler
Expand All @@ -17,7 +16,6 @@
from sbi import utils as utils
from sbi.inference import NeuralInference
from sbi.inference.posteriors.likelihood_based_posterior import LikelihoodBasedPosterior
from sbi.types import ScalarFloat
from sbi.utils import check_estimator_arg, x_shape_from_simulation


Expand Down Expand Up @@ -149,6 +147,9 @@ def __call__(
# can `sample()` and `log_prob()`. The network is accessible via `.net`.
if self._posterior is None or retrain_from_scratch_each_round:
x_shape = x_shape_from_simulation(x)
assert (
len(x_shape) < 3
), "SNLE cannot handle multi-dimensional simulator output."
self._posterior = LikelihoodBasedPosterior(
method_family="snle",
neural_net=self._build_neural_net(theta, x),
Expand Down
14 changes: 8 additions & 6 deletions sbi/inference/snpe/snpe_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@

from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Union, cast
from warnings import warn
from typing import Any, Callable, Dict, Optional, Union, cast

import numpy as np
import torch
from torch import Tensor, nn, ones, optim
from torch import Tensor, ones, optim
from torch.nn.utils import clip_grad_norm_
from torch.utils import data
from torch.utils.data.sampler import SubsetRandomSampler
Expand All @@ -18,8 +16,11 @@
from sbi import utils as utils
from sbi.inference import NeuralInference
from sbi.inference.posteriors.direct_posterior import DirectPosterior
from sbi.types import ScalarFloat
from sbi.utils import check_estimator_arg, x_shape_from_simulation
from sbi.utils import (
check_estimator_arg,
x_shape_from_simulation,
test_posterior_net_for_multi_d_x,
)


class PosteriorEstimator(NeuralInference, ABC):
Expand Down Expand Up @@ -197,6 +198,7 @@ def __call__(
mcmc_parameters=self._mcmc_parameters,
rejection_sampling_parameters=self._rejection_sampling_parameters,
)
test_posterior_net_for_multi_d_x(self._posterior.net, theta, x)

# Copy MCMC init parameters for latest sample init
if hasattr(proposal, "_mcmc_init_params"):
Expand Down
14 changes: 9 additions & 5 deletions sbi/inference/snre/snre_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
from copy import deepcopy
from typing import Any, Callable, Dict, Optional, Union

import numpy as np
import torch
from torch import Tensor, eye, nn, ones, optim
from torch import Tensor, eye, ones, optim
from torch.nn.utils import clip_grad_norm_
from torch.utils import data
from torch.utils.data.sampler import SubsetRandomSampler
Expand All @@ -13,9 +12,11 @@
from sbi import utils as utils
from sbi.inference.base import NeuralInference
from sbi.inference.posteriors.ratio_based_posterior import RatioBasedPosterior
from sbi.types import ScalarFloat
from sbi.utils import check_estimator_arg, clamp_and_warn, x_shape_from_simulation
from sbi.utils.torchutils import ensure_theta_batched, ensure_x_batched
from sbi.utils import (
check_estimator_arg,
clamp_and_warn,
x_shape_from_simulation,
)


class RatioEstimator(NeuralInference, ABC):
Expand Down Expand Up @@ -155,6 +156,9 @@ def __call__(
# can `sample()` and `log_prob()`. The network is accessible via `.net`.
if self._posterior is None or retrain_from_scratch_each_round:
x_shape = x_shape_from_simulation(x)
assert (
len(x_shape) < 3
), "For now, SNRE cannot handle multi-dimensional simulator output, see issue #360."
self._posterior = RatioBasedPosterior(
method_family=self.__class__.__name__.lower(),
neural_net=self._build_neural_net(theta, x),
Expand Down
25 changes: 25 additions & 0 deletions sbi/user_input/user_input_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,3 +527,28 @@ def check_estimator_arg(estimator: Union[str, Callable]) -> None:
"The passed density estimator / classifier must be a string or a function "
f"returning a nn.Module, but is {type(estimator)}"
)


def test_posterior_net_for_multi_d_x(net: nn.Module, theta: Tensor, x: Tensor) -> None:
"""Test log prob method of the net.
This is done to make sure the net can handle multidimensional inputs via an
embedding net. If not, it usually fails with a RuntimeError. Here we catch the
error, append a debug hint and raise it again.
"""

try:
# torch.nn.functional needs at least two inputs here.
net.log_prob(theta[:2], x[:2])
except RuntimeError as rte:
ndims = x.ndim
if ndims > 2:
message = f"""Debug hint: The simulated data x has {ndims-1} dimensions.
With default settings, sbi cannot deal with multidimensional simulations.
Make sure to use an embedding net that reduces the dimensionality, e.g., a
CNN in case of images, or change the simulator to return one-dimensional x.
"""
else:
message = ""

raise RuntimeError(rte, message)
5 changes: 4 additions & 1 deletion sbi/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# flake8: noqa
from sbi.user_input.user_input_checks import check_estimator_arg
from sbi.user_input.user_input_checks import (
check_estimator_arg,
test_posterior_net_for_multi_d_x,
)
from sbi.user_input.user_input_checks_utils import MultipleIndependent
from sbi.utils.get_nn_models import classifier_nn, likelihood_nn, posterior_nn
from sbi.utils.io import get_data_root, get_log_root, get_project_root
Expand Down
9 changes: 0 additions & 9 deletions sbi/utils/sbiutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,12 @@
from torch import nn as nn
from torch import ones, zeros
from tqdm.auto import tqdm
import warnings


def x_shape_from_simulation(batch_x: Tensor) -> torch.Size:
ndims = batch_x.ndim
assert ndims >= 2, "Simulated data must be a batch with at least two dimensions."

# Warn in case of multi-dimensional x.
if ndims > 2:
warnings.warn(
f"""The simulated data x has {ndims-1} dimensions. With default settings,
sbi cannot deal with multidimensional simulations. Make sure to use an
embedding net that reduces the dimensionality, e.g., a CNN in case of
images, or change the simulator to return one-dimensional x."""
)
return batch_x[0].unsqueeze(0).shape


Expand Down
36 changes: 25 additions & 11 deletions tests/multidimensional_x_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
import pytest
from sbi import utils as utils
from sbi.inference import SNPE, prepare_for_sbi
from sbi.inference import SNPE, SNLE, SNRE, prepare_for_sbi
import torch.nn as nn
import torch.nn.functional as F

Expand Down Expand Up @@ -39,13 +39,25 @@ def forward(self, x):


@pytest.mark.parametrize(
"embedding",
"embedding, method",
(
pytest.param(nn.Identity, marks=pytest.mark.xfail(reason="Invalid embedding.")),
pytest.param(CNNEmbedding),
pytest.param(
nn.Identity, SNPE, marks=pytest.mark.xfail(reason="Invalid embedding.")
),
pytest.param(
nn.Identity,
SNRE,
marks=pytest.mark.xfail(reason="SNLE cannot handle multiD x."),
),
pytest.param(
CNNEmbedding,
SNLE,
marks=pytest.mark.xfail(reason="SNRE cannot handle multiD x."),
),
pytest.param(CNNEmbedding, SNPE),
),
)
def test_inference_with_2d_x(embedding):
def test_inference_with_2d_x(embedding, method):

num_dim = 2
num_samples = 10
Expand All @@ -58,12 +70,14 @@ def test_inference_with_2d_x(embedding):
theta_o = torch.ones(1, num_dim)
x_o = simulator(theta_o)

infer = SNPE(
simulator,
prior,
show_progress_bars=False,
density_estimator=utils.posterior_nn(model="mdn", embedding_net=embedding(),),
)
if method == SNPE:
net_provider = utils.posterior_nn(model="mdn", embedding_net=embedding(),)
elif method == SNLE:
net_provider = utils.likelihood_nn(model="mdn", embedding_net=embedding())
else:
net_provider = utils.classifier_nn(model="mlp", embedding_net_x=embedding(),)

infer = method(simulator, prior, 1, 1, net_provider, show_progress_bars=False,)

posterior = infer(
num_simulations=num_simulations, training_batch_size=100, max_num_epochs=10
Expand Down

0 comments on commit c0579a3

Please sign in to comment.