Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix npe iid handling #1262

Merged
merged 2 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions sbi/inference/posteriors/direct_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,19 @@
`sbi` v0.17.2 or older. If it is set, we instantly raise an error.
show_progress_bars: Whether to show sampling progress monitor.
"""

num_samples = torch.Size(sample_shape).numel()
x = self._x_else_default_x(x)
x = reshape_to_batch_event(
x, event_shape=self.posterior_estimator.condition_shape
)
if x.shape[0] > 1:
raise ValueError(
".sample() supports only `batchsize == 1`. If you intend "
"to sample multiple observations, use `.sample_batched()`. "
"If you intend to sample i.i.d. observations, set up the "
"posterior density estimator with an appropriate permutation "
"invariant embedding net."
)

max_sampling_batch_size = (
self.max_sampling_batch_size
Expand All @@ -132,7 +139,7 @@
max_sampling_batch_size=max_sampling_batch_size,
proposal_sampling_kwargs={"condition": x},
alternative_method="build_posterior(..., sample_with='mcmc')",
)[0]
)[0] # [0] to return only samples, not acceptance probabilities.

return samples[:, 0] # Remove batch dimension.

Expand Down Expand Up @@ -221,9 +228,14 @@
x_density_estimator = reshape_to_batch_event(
x, event_shape=self.posterior_estimator.condition_shape
)
assert (
x_density_estimator.shape[0] == 1
), ".log_prob() supports only `batchsize == 1`."
if x_density_estimator.shape[0] > 1:
raise ValueError(

Check warning on line 232 in sbi/inference/posteriors/direct_posterior.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/posteriors/direct_posterior.py#L232

Added line #L232 was not covered by tests
".log_prob() supports only `batchsize == 1`. If you intend "
"to evaluate given multiple observations, use `.log_prob_batched()`. "
"If you intend to evaluate given i.i.d. observations, set up the "
"posterior density estimator with an appropriate permutation "
"invariant embedding net."
)

self.posterior_estimator.eval()

Expand Down
15 changes: 0 additions & 15 deletions sbi/utils/sbiutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,21 +402,6 @@ def nle_nre_apt_msg_on_invalid_x(
)


def warn_on_batched_x(batch_size):
    """Warn if more than one `x` was passed.

    A batch of observations is interpreted as i.i.d. trials from the same
    underlying parameter, so the posterior is conditioned on the entire
    batch. This warning alerts users who may instead have intended batched
    (independent) sampling or evaluation.

    Args:
        batch_size: Number of observations in the leading (batch) dimension
            of `x`. No warning is emitted for `batch_size <= 1`.
    """
    if batch_size > 1:
        warnings.warn(
            f"An x with a batch size of {batch_size} was passed. "
            "Unless you are using `sample_batched` or `log_prob_batched`, this will "
            "be interpreted as a batch of independent and identically distributed data"
            " X={x_1, ..., x_n}, i.e., data generated based on the same underlying "
            "(unknown) parameter. The resulting posterior will be with respect to"
            " the entire batch, i.e., p(theta | X).",
            stacklevel=2,
        )


def check_warn_and_setstate(
state_dict: Dict,
key_name: str,
Expand Down
3 changes: 1 addition & 2 deletions sbi/utils/user_input_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from torch.distributions import Distribution, Uniform

from sbi.sbi_types import Array
from sbi.utils.sbiutils import warn_on_batched_x, within_support
from sbi.utils.sbiutils import within_support
from sbi.utils.torchutils import BoxUniform, atleast_2d
from sbi.utils.user_input_checks_utils import (
CustomPriorWrapper,
Expand Down Expand Up @@ -582,7 +582,6 @@ def process_x(x: Array, x_event_shape: Optional[torch.Size] = None) -> Tensor:
x = x.unsqueeze(0)

input_x_shape = x.shape
warn_on_batched_x(batch_size=input_x_shape[0])

if x_event_shape is not None:
# Number of trials can change for every new x, but single trial x shape must
Expand Down
2 changes: 1 addition & 1 deletion tests/posterior_nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
pytest.param(
2,
marks=pytest.mark.xfail(
raises=AssertionError,
raises=ValueError,
reason=".log_prob() supports only batch size 1 for x_o.",
),
),
Expand Down