Skip to content

Commit

Permalink
chore: explicit enabling/disabling of 64bit floating point numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
Qazalbash committed Sep 11, 2024
1 parent 21b50b5 commit f1da2d5
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,6 +1212,8 @@ def test_dist_shape(jax_dist_cls, sp_dist, params, prepend_shape):
# Enable 64bit support for higher accuracy
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)
rng_key = random.PRNGKey(0)
expected_shape = prepend_shape + jax_dist.batch_shape + jax_dist.event_shape
samples = jax_dist.sample(key=rng_key, sample_shape=prepend_shape)
Expand Down Expand Up @@ -1260,6 +1262,8 @@ def test_dist_shape(jax_dist_cls, sp_dist, params, prepend_shape):
def test_infer_shapes(jax_dist, sp_dist, params):
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)
shapes = []
for param in params:
if param is None:
Expand Down Expand Up @@ -1287,6 +1291,8 @@ def test_has_rsample(jax_dist, sp_dist, params):
jax_dist = jax_dist(*params)
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)
masked_dist = jax_dist.mask(False)
indept_dist = jax_dist.expand_by([2]).to_event(1)
transf_dist = dist.TransformedDistribution(jax_dist, biject_to(constraints.real))
Expand Down Expand Up @@ -1343,6 +1349,8 @@ def test_sample_gradient(jax_dist, sp_dist, params):

if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)

dist_args = [
p
Expand Down Expand Up @@ -1441,6 +1449,8 @@ def test_jit_log_likelihood(jax_dist, sp_dist, params):

if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)

rng_key = random.PRNGKey(0)
samples = jax_dist(*params).sample(key=rng_key, sample_shape=(2, 3))
Expand All @@ -1461,6 +1471,8 @@ def log_likelihood(*params):
def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit):
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)
jit_fn = _identity if not jit else jax.jit
jax_dist = jax_dist(*params)

Expand Down Expand Up @@ -1526,6 +1538,8 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit):
def test_entropy_scipy(jax_dist, sp_dist, params):
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)

jax_dist = jax_dist(*params)

Expand All @@ -1549,6 +1563,8 @@ def test_entropy_scipy(jax_dist, sp_dist, params):
def test_entropy_samples(jax_dist, sp_dist, params):
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)

jax_dist = jax_dist(*params)

Expand Down Expand Up @@ -1596,6 +1612,8 @@ def test_mixture_log_prob():
def test_cdf_and_icdf(jax_dist, sp_dist, params):
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)
d = jax_dist(*params)
if d.event_dim > 0:
pytest.skip("skip testing cdf/icdf methods of multivariate distributions")
Expand Down Expand Up @@ -1650,6 +1668,8 @@ def test_gof(jax_dist, sp_dist, params):
pytest.skip("skip gof test for ZeroSumNormal")
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)

num_samples = 10000
if "BetaProportion" in jax_dist.__name__:
Expand Down Expand Up @@ -1682,6 +1702,8 @@ def test_gof(jax_dist, sp_dist, params):
def test_independent_shape(jax_dist, sp_dist, params):
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)
d = jax_dist(*params)
batch_shape, event_shape = d.batch_shape, d.event_shape
shape = batch_shape + event_shape
Expand Down Expand Up @@ -1869,6 +1891,8 @@ def test_log_prob_gradient(jax_dist, sp_dist, params):
pytest.skip("no param for ImproperUniform to test for log_prob gradient")
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)

rng_key = random.PRNGKey(0)
value = jax_dist(*params).sample(rng_key)
Expand Down Expand Up @@ -2101,6 +2125,8 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):

if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)

dist_args = [p for p in inspect.getfullargspec(jax_dist.__init__)[0][1:]]

Expand Down Expand Up @@ -2782,6 +2808,8 @@ def test_generated_sample_distribution(
)
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)

jax_dist = jax_dist(*params)
if sp_dist and not jax_dist.event_shape and not jax_dist.batch_shape:
Expand Down Expand Up @@ -2826,6 +2854,8 @@ def test_zero_inflated_enumerate_support():
def test_expand(jax_dist, sp_dist, params, prepend_shape, sample_shape):
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)
jax_dist = jax_dist(*params)
new_batch_shape = prepend_shape + jax_dist.batch_shape
expanded_dist = jax_dist.expand(new_batch_shape)
Expand Down Expand Up @@ -2970,6 +3000,8 @@ def f(x, data):
def test_dist_pytree(jax_dist, sp_dist, params):
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)

def f(x):
return jax_dist(*params)
Expand Down Expand Up @@ -3257,6 +3289,8 @@ def _tree_equal(t1, t2):
def test_vmap_dist(jax_dist, sp_dist, params):
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)
param_names = list(inspect.signature(jax_dist).parameters.keys())
vmappable_param_idxs = _get_vmappable_dist_init_params(jax_dist)
vmappable_param_idxs = vmappable_param_idxs[: len(params)]
Expand Down

0 comments on commit f1da2d5

Please sign in to comment.