fix: leakage correction for log prob batched (#1355)
* Fixing leakage correction inconsistency.

* Improving test to cover batched log_prob on bounded support priors

* Fixing test

* Formatting fix

* add type
manuelgloeckler authored Jan 10, 2025
1 parent 9152e93 commit a6a220d
Showing 3 changed files with 31 additions and 6 deletions.
4 changes: 1 addition & 3 deletions sbi/samplers/rejection/rejection.py
@@ -362,6 +362,4 @@ def accept_reject_sample(
         samples.shape[0] == num_samples
     ), "Number of accepted samples must match required samples."
 
-    # NOTE: Restriction prior does currently require a float as return for the
-    # acceptance rate, which is why we for now also return the minimum acceptance rate.
-    return samples, as_tensor(min_acceptance_rate)
+    return samples, as_tensor(acceptance_rate)
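
For context, a minimal sketch of why returning the full tensor matters (toy numbers, not sbi's actual code): with a batch of conditions, rejection sampling yields one acceptance rate per x, and the batched leakage correction needs each per-x rate rather than only the worst one.

```python
import torch

# Hypothetical per-x acceptance rates for a batch of three conditions x.
acceptance_rate = torch.tensor([0.92, 0.35, 0.78])

# Old behavior: only the scalar minimum was returned.
min_rate = acceptance_rate.min()  # 0.35 -- over-corrects x_0 and x_2

# New behavior: the full tensor is returned, so each condition can get
# its own leakage-correction term in log_prob_batched.
per_x_correction = torch.log(acceptance_rate)  # shape (3,)
print(per_x_correction)
```
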
5 changes: 5 additions & 0 deletions sbi/utils/restriction_estimator.py
@@ -692,6 +692,11 @@ def sample(
             max_sampling_batch_size=max_sampling_batch_size,
             alternative_method="sample_with='sir'",
         )
+        # NOTE: This currently requires a float acceptance rate. A previous
+        # version of accept_reject_sample returned a float; to support
+        # batched sampling, it now returns a tensor.
+        acceptance_rate = acceptance_rate.min().item()
+
         if save_acceptance_rate:
             self.acceptance_rate = torch.as_tensor(acceptance_rate)
         if print_rejected_frac:
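
A minimal sketch of the conversion this hunk performs (hypothetical values, not sbi's API): the restriction estimator keeps the old scalar contract by collapsing the new per-x tensor to its conservative minimum as a Python float.

```python
import torch

# Tensor of per-x acceptance rates now returned by accept_reject_sample.
rates = torch.tensor([0.9, 0.4, 0.7])

# Collapse to the conservative scalar the restriction prior still expects.
acceptance_rate = rates.min().item()
assert isinstance(acceptance_rate, float)

# Later stored as a 0-dim tensor, mirroring `save_acceptance_rate` above.
stored = torch.as_tensor(acceptance_rate)
```
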
28 changes: 25 additions & 3 deletions tests/posterior_nn_test.py
@@ -6,7 +6,7 @@
 import pytest
 import torch
 from torch import eye, ones, zeros
-from torch.distributions import MultivariateNormal
+from torch.distributions import Independent, MultivariateNormal, Uniform
 
 from sbi.inference import (
     NLE_A,
@@ -98,13 +98,20 @@ def test_importance_posterior_sample_log_prob(snplre_method: type):
 
 @pytest.mark.parametrize("snpe_method", [NPE_A, NPE_C])
 @pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2))
+@pytest.mark.parametrize("prior", ("mvn", "uniform"))
 def test_batched_sample_log_prob_with_different_x(
-    snpe_method: type, x_o_batch_dim: bool
+    snpe_method: type,
+    x_o_batch_dim: bool,
+    prior: str,
 ):
     num_dim = 2
     num_simulations = 1000
 
-    prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
+    # Also test a prior with bounded support, which invokes leakage correction.
+    if prior == "mvn":
+        prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
+    elif prior == "uniform":
+        prior = Independent(Uniform(-1.0 * ones(num_dim), 1.0 * ones(num_dim)), 1)
     simulator = diagonal_linear_gaussian
 
     inference = snpe_method(prior=prior)
@@ -116,6 +123,7 @@ def test_batched_sample_log_prob_with_different_x(
 
     posterior = DirectPosterior(posterior_estimator=posterior_estimator, prior=prior)
 
+    torch.manual_seed(0)
     samples = posterior.sample_batched((10,), x_o)
     batched_log_probs = posterior.log_prob_batched(samples, x_o)
 
@@ -126,6 +134,20 @@ def test_batched_sample_log_prob_with_different_x(
     ), "Sample shape wrong"
     assert batched_log_probs.shape == (10, max(x_o_batch_dim, 1)), "logprob shape wrong"
 
+    # Test consistency with non-batched log_prob.
+    # NOTE: The leakage factor is an MC estimate, so the tolerance is relaxed here.
+    if x_o_batch_dim == 0:
+        log_probs = posterior.log_prob(samples, x=x_o)
+        assert torch.allclose(
+            log_probs, batched_log_probs[:, 0], atol=1e-1, rtol=1e-1
+        ), "Batched log probs different from non-batched log probs"
+    else:
+        for idx in range(x_o_batch_dim):
+            log_probs = posterior.log_prob(samples[:, idx], x=x_o[idx])
+            assert torch.allclose(
+                log_probs, batched_log_probs[:, idx], atol=1e-1, rtol=1e-1
+            ), "Batched log probs different from non-batched log probs"
+
 
 @pytest.mark.mcmc
 @pytest.mark.parametrize("snlre_method", [NLE_A, NRE_A, NRE_B, NRE_C, NPE_C])
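
As background for the test above, a hedged sketch of the leakage correction being exercised (illustrative names, not sbi's internals): for a prior with bounded support, the density estimator q(theta|x) leaks probability mass outside the support, so the corrected log prob subtracts the log of a per-x Monte Carlo estimate of the in-support mass and assigns -inf to leaked samples.

```python
import torch

def leakage_corrected_log_prob(
    log_q: torch.Tensor,           # log q(theta | x) from the density estimator
    within_support: torch.Tensor,  # boolean mask: theta inside the prior support
    acceptance_rate: torch.Tensor, # MC estimate of q's in-support mass, per x
) -> torch.Tensor:
    # Renormalize by the in-support mass and zero out leaked samples.
    corrected = log_q - torch.log(acceptance_rate)
    return torch.where(
        within_support, corrected, torch.full_like(corrected, float("-inf"))
    )

# Toy usage with a 60% acceptance rate for a single x.
log_q = torch.tensor([-1.2, -0.8])
mask = torch.tensor([True, False])
print(leakage_corrected_log_prob(log_q, mask, torch.tensor(0.6)))
```

Because the acceptance rate is itself a Monte Carlo estimate, batched and non-batched corrections agree only up to sampling noise, which is why the test compares them with relaxed tolerances.
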
