diff --git a/test/test_distributions.py b/test/test_distributions.py index 9170be113..7729b12e8 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1212,6 +1212,8 @@ def test_dist_shape(jax_dist_cls, sp_dist, params, prepend_shape): # Enable 64bit support for higher accuracy if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]: numpyro.enable_x64() + else: + numpyro.enable_x64(False) rng_key = random.PRNGKey(0) expected_shape = prepend_shape + jax_dist.batch_shape + jax_dist.event_shape samples = jax_dist.sample(key=rng_key, sample_shape=prepend_shape) @@ -1260,6 +1262,8 @@ def test_dist_shape(jax_dist_cls, sp_dist, params, prepend_shape): def test_infer_shapes(jax_dist, sp_dist, params): if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]: numpyro.enable_x64() + else: + numpyro.enable_x64(False) shapes = [] for param in params: if param is None: @@ -1287,6 +1291,8 @@ def test_has_rsample(jax_dist, sp_dist, params): jax_dist = jax_dist(*params) if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]: numpyro.enable_x64() + else: + numpyro.enable_x64(False) masked_dist = jax_dist.mask(False) indept_dist = jax_dist.expand_by([2]).to_event(1) transf_dist = dist.TransformedDistribution(jax_dist, biject_to(constraints.real)) @@ -1343,6 +1349,8 @@ def test_sample_gradient(jax_dist, sp_dist, params): if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]: numpyro.enable_x64() + else: + numpyro.enable_x64(False) dist_args = [ p @@ -1441,6 +1449,8 @@ def test_jit_log_likelihood(jax_dist, sp_dist, params): if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]: numpyro.enable_x64() + else: + numpyro.enable_x64(False) rng_key = random.PRNGKey(0) samples = jax_dist(*params).sample(key=rng_key, sample_shape=(2, 3)) @@ -1461,6 +1471,8 @@ def log_likelihood(*params): def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit): if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]: numpyro.enable_x64() + else: + numpyro.enable_x64(False) jit_fn = _identity if not jit else jax.jit jax_dist = jax_dist(*params) @@ -1526,6 +1538,8 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit): def test_entropy_scipy(jax_dist, sp_dist, params): if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]: numpyro.enable_x64() + else: + numpyro.enable_x64(False) jax_dist = jax_dist(*params) @@ -1549,6 +1563,8 @@ def test_entropy_scipy(jax_dist, sp_dist, params): def test_entropy_samples(jax_dist, sp_dist, params): if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]: numpyro.enable_x64() + else: + numpyro.enable_x64(False) jax_dist = jax_dist(*params) @@ -1596,6 +1612,8 @@ def test_mixture_log_prob(): def test_cdf_and_icdf(jax_dist, sp_dist, params): if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]: numpyro.enable_x64() + else: + numpyro.enable_x64(False) d = jax_dist(*params) if d.event_dim > 0: pytest.skip("skip testing cdf/icdf methods of multivariate distributions") @@ -1650,6 +1668,8 @@ def test_gof(jax_dist, sp_dist, params): pytest.skip("skip gof test for ZeroSumNormal") if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]: numpyro.enable_x64() + else: + numpyro.enable_x64(False) num_samples = 10000 if "BetaProportion" in jax_dist.__name__: @@ -1682,6 +1702,8 @@ def test_gof(jax_dist, sp_dist, params): def test_independent_shape(jax_dist, sp_dist, params): if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]: numpyro.enable_x64() + else: + numpyro.enable_x64(False) d = jax_dist(*params) batch_shape, event_shape = d.batch_shape, d.event_shape shape = batch_shape + event_shape @@ -1869,6 +1891,8 @@ def test_log_prob_gradient(jax_dist, sp_dist, params): pytest.skip("no param for ImproperUniform to test for log_prob gradient") if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]: numpyro.enable_x64() + else: + numpyro.enable_x64(False) rng_key = random.PRNGKey(0) value = jax_dist(*params).sample(rng_key) @@ -2101,6 +2125,8 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape): if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]: numpyro.enable_x64() + else: + numpyro.enable_x64(False) dist_args = [p for p in inspect.getfullargspec(jax_dist.__init__)[0][1:]] @@ -2782,6 +2808,8 @@ def test_generated_sample_distribution( ) if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]: numpyro.enable_x64() + else: + numpyro.enable_x64(False) jax_dist = jax_dist(*params) if sp_dist and not jax_dist.event_shape and not jax_dist.batch_shape: @@ -2826,6 +2854,8 @@ def test_zero_inflated_enumerate_support(): def test_expand(jax_dist, sp_dist, params, prepend_shape, sample_shape): if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]: numpyro.enable_x64() + else: + numpyro.enable_x64(False) jax_dist = jax_dist(*params) new_batch_shape = prepend_shape + jax_dist.batch_shape expanded_dist = jax_dist.expand(new_batch_shape) @@ -2970,6 +3000,8 @@ def f(x, data): def test_dist_pytree(jax_dist, sp_dist, params): if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]: numpyro.enable_x64() + else: + numpyro.enable_x64(False) def f(x): return jax_dist(*params) @@ -3257,6 +3289,8 @@ def _tree_equal(t1, t2): def test_vmap_dist(jax_dist, sp_dist, params): if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]: numpyro.enable_x64() + else: + numpyro.enable_x64(False) param_names = list(inspect.signature(jax_dist).parameters.keys()) vmappable_param_idxs = _get_vmappable_dist_init_params(jax_dist) vmappable_param_idxs = vmappable_param_idxs[: len(params)]