Skip to content

Commit

Permalink
test abc with uniform prior.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Jun 4, 2021
1 parent f242f2b commit b4b1927
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 10 deletions.
2 changes: 1 addition & 1 deletion sbi/simulators/linear_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def samples_true_posterior_linear_gaussian_uniform_prior(

while num_remaining > 0:
candidate_samples = posterior.sample(sample_shape=(num_remaining,))
is_in_prior = within_support(prior.log_prob, candidate_samples)
is_in_prior = within_support(prior, candidate_samples)
# accept if in prior
if is_in_prior.sum():
samples.append(candidate_samples[is_in_prior, :])
Expand Down
35 changes: 26 additions & 9 deletions tests/abc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from sbi.inference import ABC, SMC
from sbi.simulators.linear_gaussian import (
linear_gaussian,
samples_true_posterior_linear_gaussian_uniform_prior,
true_posterior_linear_gaussian_mvn_prior,
)
from sbi.utils import BoxUniform
from tests.test_utils import check_c2st


Expand Down Expand Up @@ -64,21 +66,34 @@ def test_mcabc_sass_lra(lra, sass_expansion_degree, set_seed):


@pytest.mark.parametrize("num_dim", (1, 2))
@pytest.mark.parametrize("prior_type", ("uniform", "gaussian"))
def test_smcabc_inference_on_linear_gaussian(
num_dim, lra=False, sass=False, sass_expansion_degree=1
num_dim, prior_type: str, lra=False, sass=False, sass_expansion_degree=1
):
x_o = zeros((1, num_dim))
num_samples = 1000
likelihood_shift = -1.0 * ones(num_dim)
likelihood_cov = 0.3 * eye(num_dim)

prior_mean = zeros(num_dim)
prior_cov = eye(num_dim)
prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)
gt_posterior = true_posterior_linear_gaussian_mvn_prior(
x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov
)
target_samples = gt_posterior.sample((num_samples,))
if prior_type == "gaussian":
prior_mean = zeros(num_dim)
prior_cov = eye(num_dim)
prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)
gt_posterior = true_posterior_linear_gaussian_mvn_prior(
x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov
)
target_samples = gt_posterior.sample((num_samples,))
elif prior_type == "uniform":
prior = BoxUniform(-ones(num_dim), ones(num_dim))
target_samples = samples_true_posterior_linear_gaussian_uniform_prior(
x_o[0],
likelihood_shift,
likelihood_cov,
prior,
num_samples,
)
else:
raise ValueError("Wrong prior string.")

def simulator(theta):
return linear_gaussian(theta, likelihood_shift, likelihood_cov)
Expand All @@ -98,7 +113,9 @@ def simulator(theta):
sass_expansion_degree=sass_expansion_degree,
)

check_c2st(phat.sample((num_samples,)), target_samples, alg="SMCABC")
check_c2st(
phat.sample((num_samples,)), target_samples, alg=f"SMCABC-{prior_type}-prior"
)


@pytest.mark.slow
Expand Down

0 comments on commit b4b1927

Please sign in to comment.