diff --git a/tests/test_smc_ess.py b/tests/test_smc_ess.py index ed4b19520..e8ec5fb64 100644 --- a/tests/test_smc_ess.py +++ b/tests/test_smc_ess.py @@ -16,7 +16,7 @@ class SMCEffectiveSampleSizeTest(chex.TestCase): @chex.all_variants(with_pmap=False) - @parameterized.parameters([100, 1000, 5000]) + @parameterized.parameters([1000, 5000]) def test_ess(self, N): log_ess_fn = self.variant(functools.partial(ess.ess, log=True)) ess_fn = self.variant(functools.partial(ess.ess, log=False)) @@ -39,7 +39,7 @@ def test_ess(self, N): ) @chex.all_variants(with_pmap=False) - @parameterized.parameters(itertools.product([0.25, 0.5], [100, 1000, 5000])) + @parameterized.parameters(itertools.product([0.25, 0.5], [1000, 5000])) def test_ess_solver(self, target_ess, N): potential_fn = lambda pytree: -univariate_logpdf(pytree, scale=0.1) potential = jax.vmap(lambda x: potential_fn(x), in_axes=[0]) @@ -47,7 +47,7 @@ def test_ess_solver(self, target_ess, N): self.ess_solver_test_case(potential, particles, target_ess, N, 1.0) @chex.all_variants(with_pmap=False) - @parameterized.parameters(itertools.product([0.25, 0.5], [100, 1000, 5000])) + @parameterized.parameters(itertools.product([0.25, 0.5], [1000, 5000])) def test_ess_solver_multivariate(self, target_ess, N): """ Posterior with more than one variable. Let's assume we want to @@ -63,7 +63,7 @@ def test_ess_solver_multivariate(self, target_ess, N): self.ess_solver_test_case(potential, particles, target_ess, N, 10.0) @chex.all_variants(with_pmap=False) - @parameterized.parameters(itertools.product([0.25, 0.5], [100, 1000, 5000])) + @parameterized.parameters(itertools.product([0.25, 0.5], [1000, 5000])) def test_ess_solver_posterior_signature(self, target_ess, N): """ Posterior with more than one variable. Let's assume we want to