Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap 1D PyTorch distributions #1286

Merged
merged 1 commit into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sbi/utils/user_input_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sbi.utils.user_input_checks_utils import (
CustomPriorWrapper,
MultipleIndependent,
OneDimPriorWrapper,
PytorchReturnTypeWrapper,
)

Expand Down Expand Up @@ -220,6 +221,11 @@ def process_pytorch_prior(prior: Distribution) -> Tuple[Distribution, int, bool]
# This will fail for float64 priors.
check_prior_return_type(prior)

# Potentially required wrapper if the prior returns an additional sample dimension
# for `.log_prob()`.
if prior.log_prob(prior.sample(torch.Size((10,)))).shape == torch.Size([10, 1]):
prior = OneDimPriorWrapper(prior, validate_args=False)

theta_numel = prior.sample().numel()

return prior, theta_numel, False
Expand Down
66 changes: 66 additions & 0 deletions sbi/utils/user_input_checks_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,3 +373,69 @@
support = constraints.interval(lower_bound, upper_bound)

return support


class OneDimPriorWrapper(Distribution):
michaeldeistler marked this conversation as resolved.
Show resolved Hide resolved
"""Wrap batched 1D distributions to get rid of the batch dim of `.log_prob()`.

1D pytorch distributions such as `torch.distributions.Exponential`, `.Uniform`, or
`.Normal` do not, by default return __any__ sample or batch dimension. E.g.:
```python
dist = torch.distributions.Exponential(torch.tensor(3.0))
dist.sample((10,)).shape # (10,)
```

`sbi` will raise an error that the sample dimension is missing. A simple solution is
to add a batch dimension to `dist` as follows:
```python
dist = torch.distributions.Exponential(torch.tensor([3.0]))
dist.sample((10,)).shape # (10, 1)
```

Unfortunately, this `dist` will return the batch dimension also for `.log_prob():
```python
dist = torch.distributions.Exponential(torch.tensor([3.0]))
samples = dist.sample((10,))
dist.log_prob(samples).shape # (10, 1)
```

This will lead to unexpected errors in `sbi`. The point of this class is to wrap
those batched 1D distributions to get rid of their batch dimension in `.log_prob()`.
"""

def __init__(
self,
prior: Distribution,
validate_args=None,
) -> None:
super().__init__(
batch_shape=prior.batch_shape,
event_shape=prior.event_shape,
validate_args=(
prior._validate_args if validate_args is None else validate_args
),
)
self.prior = prior

def sample(self, *args, **kwargs) -> Tensor:
return self.prior.sample(*args, **kwargs)

def log_prob(self, *args, **kwargs) -> Tensor:
"""Override the log_prob method to get rid of the additional batch dimension."""
return self.prior.log_prob(*args, **kwargs)[..., 0]

@property
def arg_constraints(self) -> Dict[str, constraints.Constraint]:
return self.prior.arg_constraints

Check warning on line 429 in sbi/utils/user_input_checks_utils.py

View check run for this annotation

Codecov / codecov/patch

sbi/utils/user_input_checks_utils.py#L429

Added line #L429 was not covered by tests

@property
def support(self):
return self.prior.support

@property
def mean(self) -> Tensor:
return self.prior.mean

Check warning on line 437 in sbi/utils/user_input_checks_utils.py

View check run for this annotation

Codecov / codecov/patch

sbi/utils/user_input_checks_utils.py#L437

Added line #L437 was not covered by tests

@property
def variance(self) -> Tensor:
return self.prior.variance

Check warning on line 441 in sbi/utils/user_input_checks_utils.py

View check run for this annotation

Codecov / codecov/patch

sbi/utils/user_input_checks_utils.py#L441

Added line #L441 was not covered by tests
20 changes: 18 additions & 2 deletions tests/user_input_checks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@
import torch
from pyknos.mdn.mdn import MultivariateGaussianMDN
from torch import Tensor, eye, nn, ones, zeros
from torch.distributions import Beta, Distribution, Gamma, MultivariateNormal, Uniform
from torch.distributions import (
Beta,
Distribution,
Exponential,
Gamma,
MultivariateNormal,
Uniform,
)

from sbi.inference import NPE_A, NPE_C, simulate_for_sbi
from sbi.inference.posteriors.direct_posterior import DirectPosterior
Expand All @@ -27,6 +34,7 @@
from sbi.utils.user_input_checks_utils import (
CustomPriorWrapper,
MultipleIndependent,
OneDimPriorWrapper,
PytorchReturnTypeWrapper,
)

Expand Down Expand Up @@ -93,6 +101,11 @@ def matrix_simulator(theta):
BoxUniform(zeros(3, dtype=torch.float64), ones(3, dtype=torch.float64)),
dict(),
),
(
michaeldeistler marked this conversation as resolved.
Show resolved Hide resolved
OneDimPriorWrapper,
Exponential(torch.tensor([3.0])),
dict(),
),
),
)
def test_prior_wrappers(wrapper, prior, kwargs):
Expand All @@ -118,6 +131,9 @@ def test_prior_wrappers(wrapper, prior, kwargs):
# Test transform
mcmc_transform(prior)

# For 1D priors, the `log_prob()` should not have a batch dim.
assert len(prior.log_prob(prior.sample((10,))).shape) == 1


def test_reinterpreted_batch_dim_prior():
"""Test whether the right warning and error are raised for reinterpreted priors."""
Expand Down Expand Up @@ -268,7 +284,6 @@ def test_prepare_sbi_problem(simulator: Callable, prior):
prior: prior as defined by the user (pytorch, scipy, custom)
x_shape: shape of data as defined by the user.
"""

prior, _, prior_returns_numpy = process_prior(prior)
simulator = process_simulator(simulator, prior, prior_returns_numpy)
check_sbi_inputs(simulator, prior)
Expand Down Expand Up @@ -308,6 +323,7 @@ def test_prepare_sbi_problem(simulator: Callable, prior):
MultivariateNormal(zeros(2), eye(2)),
),
),
(diagonal_linear_gaussian, Exponential(torch.tensor([3.0]))),
),
)
def test_inference_with_user_sbi_problems(
Expand Down