From b4b1927c3bfe9475a30ad6719a722485e6370e02 Mon Sep 17 00:00:00 2001 From: janfb Date: Fri, 4 Jun 2021 11:12:25 +0200 Subject: [PATCH] test abc with uniform prior. --- sbi/simulators/linear_gaussian.py | 2 +- tests/abc_test.py | 35 +++++++++++++++++++++++-------- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/sbi/simulators/linear_gaussian.py b/sbi/simulators/linear_gaussian.py index 4135a1e3b..7b029854f 100644 --- a/sbi/simulators/linear_gaussian.py +++ b/sbi/simulators/linear_gaussian.py @@ -195,7 +195,7 @@ def samples_true_posterior_linear_gaussian_uniform_prior( while num_remaining > 0: candidate_samples = posterior.sample(sample_shape=(num_remaining,)) - is_in_prior = within_support(prior.log_prob, candidate_samples) + is_in_prior = within_support(prior, candidate_samples) # accept if in prior if is_in_prior.sum(): samples.append(candidate_samples[is_in_prior, :]) diff --git a/tests/abc_test.py b/tests/abc_test.py index 297443e61..d4c8522a1 100644 --- a/tests/abc_test.py +++ b/tests/abc_test.py @@ -5,8 +5,10 @@ from sbi.inference import ABC, SMC from sbi.simulators.linear_gaussian import ( linear_gaussian, + samples_true_posterior_linear_gaussian_uniform_prior, true_posterior_linear_gaussian_mvn_prior, ) +from sbi.utils import BoxUniform from tests.test_utils import check_c2st @@ -64,21 +66,34 @@ def test_mcabc_sass_lra(lra, sass_expansion_degree, set_seed): @pytest.mark.parametrize("num_dim", (1, 2)) +@pytest.mark.parametrize("prior_type", ("uniform", "gaussian")) def test_smcabc_inference_on_linear_gaussian( - num_dim, lra=False, sass=False, sass_expansion_degree=1 + num_dim, prior_type: str, lra=False, sass=False, sass_expansion_degree=1 ): x_o = zeros((1, num_dim)) num_samples = 1000 likelihood_shift = -1.0 * ones(num_dim) likelihood_cov = 0.3 * eye(num_dim) - prior_mean = zeros(num_dim) - prior_cov = eye(num_dim) - prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) - gt_posterior = true_posterior_linear_gaussian_mvn_prior( - x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov - ) - target_samples = gt_posterior.sample((num_samples,)) + if prior_type == "gaussian": + prior_mean = zeros(num_dim) + prior_cov = eye(num_dim) + prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) + gt_posterior = true_posterior_linear_gaussian_mvn_prior( + x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov + ) + target_samples = gt_posterior.sample((num_samples,)) + elif prior_type == "uniform": + prior = BoxUniform(-ones(num_dim), ones(num_dim)) + target_samples = samples_true_posterior_linear_gaussian_uniform_prior( + x_o[0], + likelihood_shift, + likelihood_cov, + prior, + num_samples, + ) + else: + raise ValueError("Wrong prior string.") def simulator(theta): return linear_gaussian(theta, likelihood_shift, likelihood_cov) @@ -98,7 +113,9 @@ def simulator(theta): sass_expansion_degree=sass_expansion_degree, ) - check_c2st(phat.sample((num_samples,)), target_samples, alg="SMCABC") + check_c2st( + phat.sample((num_samples,)), target_samples, alg=f"SMCABC-{prior_type}-prior" + ) @pytest.mark.slow