Skip to content

Commit

Permalink
lower bound for num_simulations
Browse files Browse the repository at this point in the history
  • Loading branch information
Julia Linhart committed Feb 7, 2022
1 parent 2f8cda2 commit 9ddd73f
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion tests/inference_with_NaN_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def linear_gaussian_nan(
num_rounds = 2

for r in range(num_rounds):
theta, x = simulate_for_sbi(simulator, proposals[-1], 2000)
theta, x = simulate_for_sbi(simulator, proposals[-1], 1000)
restriction_estimator.append_simulations(theta, x)
if r < num_rounds - 1:
_ = restriction_estimator.train()
Expand Down
4 changes: 2 additions & 2 deletions tests/linearGaussian_snle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_c2st_and_map_snl_on_linearGaussian_different(
set_seed: fixture for manual seeding
"""
num_samples = 500
num_simulations = 3100
num_simulations = 3000
trials_to_test = [1]

# likelihood_mean will be likelihood_shift+theta
Expand Down Expand Up @@ -326,7 +326,7 @@ def test_api_snl_sampling_methods(
num_dim = 2
num_samples = 10
num_trials = 2
num_simulations = 3100
num_simulations = 1000
x_o = zeros((num_trials, num_dim))
# Test for multiple chains is cheap when vectorized.
num_chains = 3 if sampling_method == "slice_np_vectorized" else 1
Expand Down
2 changes: 1 addition & 1 deletion tests/linearGaussian_snpe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def test_c2st_multi_round_snpe_on_linearGaussian(method_str: str, set_seed):
).set_default_x(x_o)
elif method_str == "snpe_c":
inference = SNPE_C(**creation_args)
theta, x = simulate_for_sbi(simulator, prior, 1000, simulation_batch_size=50)
theta, x = simulate_for_sbi(simulator, prior, 900, simulation_batch_size=50)
posterior_estimator = inference.append_simulations(theta, x).train()
posterior1 = DirectPosterior(
prior=prior, posterior_estimator=posterior_estimator
Expand Down

0 comments on commit 9ddd73f

Please sign in to comment.