fix: add batch dim to inv transform, add test.
janfb committed Aug 2, 2023
1 parent 8769e25 commit e840cce
Showing 2 changed files with 115 additions and 38 deletions.
7 changes: 5 additions & 2 deletions sbi/utils/sbiutils.py
@@ -349,7 +349,10 @@ def warn_on_iid_x(num_trials):


def check_warn_and_setstate(
state_dict: Dict, key_name: str, replacement_value: Any, warning_msg: str = ""
state_dict: Dict,
key_name: str,
replacement_value: Any,
warning_msg: str = "",
) -> Tuple[Dict, str]:
"""
Check if `key_name` is in `state_dict` and add it if not.
@@ -881,7 +884,7 @@ def gradient_ascent(
)
best_theta_iter = optimize_inits[ # type: ignore
torch.argmax(log_probs_of_optimized)
]
].view(1, -1)
best_log_prob_iter = potential_fn(
theta_transform.inv(best_theta_iter)
)
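Note on the change above (an illustration, not part of the diff): gradient_ascent picks the single best candidate out of a batch of optimized initializations, and that indexing drops the leading batch dimension. Both theta_transform.inv and the potential function operate on batches of parameters, so the candidate is reshaped to (1, num_dim) via .view(1, -1) before it is evaluated. A minimal sketch of the shape handling, using torch's identity transform as a stand-in for the actual parameter transform:

import torch
from torch.distributions.transforms import identity_transform

num_dim = 3
best_theta_iter = torch.randn(num_dim)       # shape (num_dim,): batch dim lost by indexing
batched_theta = best_theta_iter.view(1, -1)  # shape (1, num_dim): batch dim restored

# Transforms and potential functions expect a leading batch dimension,
# so the inverse transform preserves the (batch, num_dim) layout.
theta_transform = identity_transform
assert theta_transform.inv(batched_theta).shape == (1, num_dim)
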
146 changes: 110 additions & 36 deletions tests/linearGaussian_snle_test.py
@@ -5,9 +5,8 @@

import pytest
from torch import eye, ones, zeros
from torch.distributions import MultivariateNormal
from torch.distributions import HalfNormal, MultivariateNormal

from sbi import utils as utils
from sbi.inference import (
SNLE,
ImportanceSamplingPosterior,
@@ -25,7 +24,7 @@
samples_true_posterior_linear_gaussian_uniform_prior,
true_posterior_linear_gaussian_mvn_prior,
)
from sbi.utils import likelihood_nn
from sbi.utils import BoxUniform, likelihood_nn, process_prior

from .test_utils import check_c2st, get_prob_outside_uniform_prior

@@ -48,9 +47,13 @@ def test_api_snl_on_linearGaussian(num_dim: int):

simulator, prior = prepare_for_sbi(diagonal_linear_gaussian, prior)
density_estimator = likelihood_nn("maf", num_transforms=3)
inference = SNLE(density_estimator=density_estimator, show_progress_bars=False)
inference = SNLE(
density_estimator=density_estimator, show_progress_bars=False
)

theta, x = simulate_for_sbi(simulator, prior, 1000, simulation_batch_size=50)
theta, x = simulate_for_sbi(
simulator, prior, 1000, simulation_batch_size=50
)
likelihood_estimator = inference.append_simulations(theta, x).train(
training_batch_size=100
)
@@ -92,23 +95,30 @@ def test_c2st_snl_on_linearGaussian(density_estimator="maf"):
prior_mean = zeros(theta_dim)
prior_cov = eye(theta_dim)
prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)
target_samples = samples_true_posterior_linear_gaussian_mvn_prior_different_dims(
x_o,
likelihood_shift,
likelihood_cov,
prior_mean,
prior_cov,
num_discarded_dims=discard_dims,
num_samples=num_samples,
target_samples = (
samples_true_posterior_linear_gaussian_mvn_prior_different_dims(
x_o,
likelihood_shift,
likelihood_cov,
prior_mean,
prior_cov,
num_discarded_dims=discard_dims,
num_samples=num_samples,
)
)
simulator, prior = prepare_for_sbi(
lambda theta: linear_gaussian(
theta, likelihood_shift, likelihood_cov, num_discarded_dims=discard_dims
theta,
likelihood_shift,
likelihood_cov,
num_discarded_dims=discard_dims,
),
prior,
)
density_estimator = likelihood_nn(model=density_estimator, num_transforms=3)
inference = SNLE(density_estimator=density_estimator, show_progress_bars=False)
inference = SNLE(
density_estimator=density_estimator, show_progress_bars=False
)

theta, x = simulate_for_sbi(
simulator, prior, num_simulations, simulation_batch_size=50
@@ -134,7 +144,9 @@ def test_c2st_snl_on_linearGaussian(density_estimator="maf"):
@pytest.mark.slow
@pytest.mark.parametrize("num_dim", (1, 2))
@pytest.mark.parametrize("prior_str", ("uniform", "gaussian"))
def test_c2st_and_map_snl_on_linearGaussian_different(num_dim: int, prior_str: str):
def test_c2st_and_map_snl_on_linearGaussian_different(
num_dim: int, prior_str: str
):
"""Test SNL on linear Gaussian, comparing to ground truth posterior via c2st.
Args:
@@ -156,13 +168,16 @@ def test_c2st_and_map_snl_on_linearGaussian_different(num_dim: int, prior_str: str):
prior_cov = eye(num_dim)
prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)
else:
prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim))
prior = BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim))

simulator, prior = prepare_for_sbi(
lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior
lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov),
prior,
)
density_estimator = likelihood_nn("maf", num_transforms=3)
inference = SNLE(density_estimator=density_estimator, show_progress_bars=False)
inference = SNLE(
density_estimator=density_estimator, show_progress_bars=False
)

theta, x = simulate_for_sbi(
simulator, prior, num_simulations, simulation_batch_size=10000
@@ -178,12 +193,14 @@ def test_c2st_and_map_snl_on_linearGaussian_different(num_dim: int, prior_str: str):
)
target_samples = gt_posterior.sample((num_samples,))
elif prior_str == "uniform":
target_samples = samples_true_posterior_linear_gaussian_uniform_prior(
x_o,
likelihood_shift,
likelihood_cov,
prior=prior,
num_samples=num_samples,
target_samples = (
samples_true_posterior_linear_gaussian_uniform_prior(
x_o,
likelihood_shift,
likelihood_cov,
prior=prior,
num_samples=num_samples,
)
)
else:
raise ValueError(f"Wrong prior_str: '{prior_str}'.")
@@ -204,18 +221,24 @@ def test_c2st_and_map_snl_on_linearGaussian_different(num_dim: int, prior_str: str):

# Check performance based on c2st accuracy.
check_c2st(
samples, target_samples, alg=f"snle_a-{prior_str}-prior-{num_trials}-trials"
samples,
target_samples,
alg=f"snle_a-{prior_str}-prior-{num_trials}-trials",
)

map_ = posterior.map(
num_init_samples=1_000, init_method="proposal", show_progress_bars=False
num_init_samples=1_000,
init_method="proposal",
show_progress_bars=False,
)

# TODO: we do not have a test for SNL log_prob(). This is because the output
# TODO: density is not normalized, so KLd does not make sense.
if prior_str == "uniform":
# Check whether the returned probability outside of the support is zero.
posterior_prob = get_prob_outside_uniform_prior(posterior, prior, num_dim)
posterior_prob = get_prob_outside_uniform_prior(
posterior, prior, num_dim
)
assert (
posterior_prob == 0.0
), "The posterior probability outside of the prior support is not zero"
@@ -225,6 +248,43 @@ def test_c2st_and_map_snl_on_linearGaussian_different(num_dim: int, prior_str: str):
assert ((map_ - gt_posterior.mean) ** 2).sum() < 0.5


@pytest.mark.parametrize("use_transform", (True, False))
def test_map_with_multiple_independent_prior(use_transform):
"""Test whether map works with multiple independent priors, see issue #841, #650."""

dim = 2
prior, *_ = process_prior(
[
BoxUniform(low=-ones(dim), high=ones(dim)),
HalfNormal(scale=ones(1) * 2),
]
)

def simulator(theta):
return theta[:, 2:] * torch.randn_like(theta[:, :2]) + theta[:, :2]

num_simulations = 1000
theta = prior.sample((num_simulations,))
x = simulator(theta)
x_o = zeros((1, dim))

trainer = SNLE(prior).append_simulations(theta, x)
likelihood_estimator = trainer.train(max_num_epochs=5)

potential_fn, parameter_transform = likelihood_estimator_based_potential(
likelihood_estimator,
prior,
x_o=x_o,
)
posterior = MCMCPosterior(
potential_fn,
proposal=prior,
theta_transform=parameter_transform if use_transform else None,
)
posterior.map()
posterior.set_default_x(x_o).map(num_iter=10)
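
For context on the new test_map_with_multiple_independent_prior above (a sketch, not taken from the diff): process_prior combines the independent priors into one joint prior that, under the assumed layout, concatenates parameters in the order they are passed, so each theta row holds the two BoxUniform dimensions (the simulator mean) followed by the HalfNormal noise scale. A minimal check of that assumed layout:

from torch import ones
from torch.distributions import HalfNormal

from sbi.utils import BoxUniform, process_prior

dim = 2
prior, *_ = process_prior(
    [
        BoxUniform(low=-ones(dim), high=ones(dim)),
        HalfNormal(scale=ones(1) * 2),
    ]
)

# One column per parameter: two from the BoxUniform, one from the HalfNormal
# scale that the test's simulator uses as its noise level.
theta = prior.sample((10,))
assert theta.shape == (10, dim + 1)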


@pytest.mark.slow
@pytest.mark.parametrize("num_trials", (1, 3))
def test_c2st_multi_round_snl_on_linearGaussian(num_trials: int):
@@ -248,7 +308,8 @@ def test_c2st_multi_round_snl_on_linearGaussian(num_trials: int):
target_samples = gt_posterior.sample((num_samples,))

simulator, prior = prepare_for_sbi(
lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior
lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov),
prior,
)
inference = SNLE(show_progress_bars=False)

@@ -268,7 +329,10 @@ def test_c2st_multi_round_snl_on_linearGaussian(num_trials: int):
)

theta, x = simulate_for_sbi(
simulator, posterior1, num_simulations_per_round, simulation_batch_size=50
simulator,
posterior1,
num_simulations_per_round,
simulation_batch_size=50,
)
likelihood_estimator = inference.append_simulations(theta, x).train()
potential_fn, theta_transform = likelihood_estimator_based_potential(
@@ -311,7 +375,8 @@ def test_c2st_multi_round_snl_on_linearGaussian_vi(num_trials: int):
target_samples = gt_posterior.sample((num_samples,))

simulator, prior = prepare_for_sbi(
lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior
lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov),
prior,
)
inference = SNLE(show_progress_bars=False)

@@ -329,7 +394,10 @@ def test_c2st_multi_round_snl_on_linearGaussian_vi(num_trials: int):
posterior1.train()

theta, x = simulate_for_sbi(
simulator, posterior1, num_simulations_per_round, simulation_batch_size=50
simulator,
posterior1,
num_simulations_per_round,
simulation_batch_size=50,
)
likelihood_estimator = inference.append_simulations(theta, x).train()
potential_fn, theta_transform = likelihood_estimator_based_potential(
@@ -406,9 +474,11 @@ def test_api_snl_sampling_methods(
sample_with = "vi"

if prior_str == "gaussian":
prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
prior = MultivariateNormal(
loc=zeros(num_dim), covariance_matrix=eye(num_dim)
)
else:
prior = utils.BoxUniform(-1.0 * ones(num_dim), ones(num_dim))
prior = BoxUniform(-1.0 * ones(num_dim), ones(num_dim))

# Why do we have this if-case? Only the `MCMCPosterior` uses the `init_strategy`.
# Thus, we would not like to run, e.g., VI with all init_strategies, but only once
@@ -427,7 +497,9 @@ def test_api_snl_sampling_methods(
prior=prior, likelihood_estimator=likelihood_estimator, x_o=x_o
)
if sample_with == "rejection":
posterior = RejectionPosterior(potential_fn=potential_fn, proposal=prior)
posterior = RejectionPosterior(
potential_fn=potential_fn, proposal=prior
)
elif (
"slice" in sampling_method
or "nuts" in sampling_method
@@ -450,7 +522,9 @@ def test_api_snl_sampling_methods(
)
else:
posterior = VIPosterior(
potential_fn, theta_transform=theta_transform, vi_method=sampling_method
potential_fn,
theta_transform=theta_transform,
vi_method=sampling_method,
)
posterior.train(max_num_iters=10)

