diff --git a/test/test_distributions.py b/test/test_distributions.py index 8369dde0b..c702b3fb3 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1838,6 +1838,11 @@ def test_log_prob_gradient(jax_dist, sp_dist, params): pytest.skip("we have separated tests for LKJCholesky distribution") if jax_dist is _ImproperWrapper: pytest.skip("no param for ImproperUniform to test for log_prob gradient") + if ( + jax_dist in [dist.DoublyTruncatedPowerLaw] + and jnp.result_type(float) == jnp.float32 + ): + pytest.skip("DoublyTruncatedPowerLaw is tested with x64 only.") rng_key = random.PRNGKey(0) value = jax_dist(*params).sample(rng_key)