From 42c9d751cb48123e06279f36c7de44f8bee52321 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Thu, 30 Jan 2025 11:27:50 -0500 Subject: [PATCH 1/2] refactor test_gof in a separate file --- test/test_distributions.py | 43 ----------------------------- test/test_gof.py | 56 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 43 deletions(-) create mode 100644 test/test_gof.py diff --git a/test/test_distributions.py b/test/test_distributions.py index 003c20b9c..f08c9ebfc 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -35,7 +35,6 @@ from numpyro.distributions.discrete import _to_probs_bernoulli, _to_probs_multinom from numpyro.distributions.distribution import DistributionLike from numpyro.distributions.flows import InverseAutoregressiveTransform -from numpyro.distributions.gof import InvalidTest, auto_goodness_of_fit from numpyro.distributions.transforms import ( LowerCholeskyAffine, PermuteTransform, @@ -53,8 +52,6 @@ ) from numpyro.nn import AutoregressiveNN -TEST_FAILURE_RATE = 2e-5 # For all goodness-of-fit tests. - def my_kron(A, B): D = A[..., :, None, :, None] * B[..., None, :, None, :] @@ -1637,46 +1634,6 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params): pytest.skip("cdf/icdf not implemented") -@pytest.mark.parametrize("jax_dist, sp_dist, params", CONTINUOUS + DIRECTIONAL) -def test_gof(jax_dist, sp_dist, params): - if "Improper" in jax_dist.__name__: - pytest.skip("distribution has improper .log_prob()") - if "LKJ" in jax_dist.__name__ or "Wishart" in jax_dist.__name__: - pytest.xfail("incorrect submanifold scaling") - if jax_dist is dist.EulerMaruyama: - d = jax_dist(*params) - if d.event_dim > 1: - pytest.skip("EulerMaruyama skip test when event shape is non-trivial.") - if jax_dist is dist.ZeroSumNormal: - pytest.skip("skip gof test for ZeroSumNormal") - - num_samples = 10000 - if "BetaProportion" in jax_dist.__name__: - num_samples = 20000 - rng_key = random.PRNGKey(0) - d = jax_dist(*params) - samples = d.sample(key=rng_key, sample_shape=(num_samples,)) - probs = np.exp(d.log_prob(samples)) - - dim = None - if jax_dist is dist.ProjectedNormal: - dim = samples.shape[-1] - 1 - - # Test each batch independently. - probs = probs.reshape(num_samples, -1) - samples = samples.reshape(probs.shape + d.event_shape) - if "Dirichlet" in jax_dist.__name__: - # The Dirichlet density is over all but one of the probs. - samples = samples[..., :-1] - for b in range(probs.shape[1]): - try: - gof = auto_goodness_of_fit(samples[:, b], probs[:, b], dim=dim) - except InvalidTest: - pytest.skip("expensive test") - else: - assert gof > TEST_FAILURE_RATE - - @pytest.mark.parametrize("jax_dist, sp_dist, params", CONTINUOUS + DISCRETE) def test_independent_shape(jax_dist, sp_dist, params): d = jax_dist(*params) diff --git a/test/test_gof.py b/test/test_gof.py new file mode 100644 index 000000000..613ceb01a --- /dev/null +++ b/test/test_gof.py @@ -0,0 +1,56 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +from test_distributions import CONTINUOUS, DIRECTIONAL + +import jax.random as random + +import numpyro.distributions as dist +from numpyro.distributions.gof import InvalidTest, auto_goodness_of_fit + +TEST_FAILURE_RATE = 2e-5 # For all goodness-of-fit tests. + + +@pytest.mark.parametrize("jax_dist, sp_dist, params", CONTINUOUS + DIRECTIONAL) +def test_gof(jax_dist, sp_dist, params): + if "Improper" in jax_dist.__name__: + pytest.skip("distribution has improper .log_prob()") + if "LKJ" in jax_dist.__name__ or "Wishart" in jax_dist.__name__: + pytest.xfail("incorrect submanifold scaling") + if jax_dist is dist.EulerMaruyama: + d = jax_dist(*params) + if d.event_dim > 1: + pytest.skip("EulerMaruyama skip test when event shape is non-trivial.") + if jax_dist is dist.ZeroSumNormal: + pytest.skip("skip gof test for ZeroSumNormal") + + num_samples = 10000 + if any( + name in jax_dist.__name__ + for name in ["BetaProportion", "MatrixNormal", "SineBivariateVonMises"] + ): + num_samples = 20000 + rng_key = random.PRNGKey(0) + d = jax_dist(*params) + samples = d.sample(key=rng_key, sample_shape=(num_samples,)) + probs = np.exp(d.log_prob(samples)) + + dim = None + if jax_dist is dist.ProjectedNormal: + dim = samples.shape[-1] - 1 + + # Test each batch independently. + probs = probs.reshape(num_samples, -1) + samples = samples.reshape(probs.shape + d.event_shape) + if "Dirichlet" in jax_dist.__name__: + # The Dirichlet density is over all but one of the probs. + samples = samples[..., :-1] + for b in range(probs.shape[1]): + try: + gof = auto_goodness_of_fit(samples[:, b], probs[:, b], dim=dim) + except InvalidTest: + pytest.skip("expensive test") + else: + assert gof > TEST_FAILURE_RATE From 266f90aea4b60b766f31aecac8d4ce64d5177762 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Thu, 30 Jan 2025 12:49:31 -0500 Subject: [PATCH 2/2] skip MatrixNormal for gof test --- test/test_gof.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/test_gof.py b/test/test_gof.py index 613ceb01a..01c59f56e 100644 --- a/test/test_gof.py +++ b/test/test_gof.py @@ -25,11 +25,15 @@ def test_gof(jax_dist, sp_dist, params): pytest.skip("EulerMaruyama skip test when event shape is non-trivial.") if jax_dist is dist.ZeroSumNormal: pytest.skip("skip gof test for ZeroSumNormal") + if jax_dist is dist.MatrixNormal: + pytest.skip( + "skip gof test for MatrixNormal, likely incorrect submanifold scaling" + ) num_samples = 10000 if any( name in jax_dist.__name__ - for name in ["BetaProportion", "MatrixNormal", "SineBivariateVonMises"] + for name in ["BetaProportion", "SineBivariateVonMises"] ): num_samples = 20000 rng_key = random.PRNGKey(0)