Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor test_gof into a separate file #1965

Merged
merged 2 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 0 additions & 43 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from numpyro.distributions.discrete import _to_probs_bernoulli, _to_probs_multinom
from numpyro.distributions.distribution import DistributionLike
from numpyro.distributions.flows import InverseAutoregressiveTransform
from numpyro.distributions.gof import InvalidTest, auto_goodness_of_fit
from numpyro.distributions.transforms import (
LowerCholeskyAffine,
PermuteTransform,
Expand All @@ -53,8 +52,6 @@
)
from numpyro.nn import AutoregressiveNN

# Lower bound that every goodness-of-fit result must exceed in test_gof.
TEST_FAILURE_RATE = 2e-5 # For all goodness-of-fit tests.


def my_kron(A, B):
D = A[..., :, None, :, None] * B[..., None, :, None, :]
Expand Down Expand Up @@ -1637,46 +1634,6 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params):
pytest.skip("cdf/icdf not implemented")


@pytest.mark.parametrize("jax_dist, sp_dist, params", CONTINUOUS + DIRECTIONAL)
def test_gof(jax_dist, sp_dist, params):
    """Goodness-of-fit check between a distribution's samples and its density.

    Draws samples from ``jax_dist(*params)``, evaluates ``exp(log_prob)`` on
    them, and feeds each batch element to ``auto_goodness_of_fit``. The
    returned value must exceed ``TEST_FAILURE_RATE``.

    ``sp_dist`` is part of the shared parametrization and is not used here.
    """
    dist_name = jax_dist.__name__

    # Distributions that are excluded up front.
    if "Improper" in dist_name:
        pytest.skip("distribution has improper .log_prob()")
    if any(tag in dist_name for tag in ("LKJ", "Wishart")):
        pytest.xfail("incorrect submanifold scaling")
    if jax_dist is dist.EulerMaruyama and jax_dist(*params).event_dim > 1:
        pytest.skip("EulerMaruyama skip test when event shape is non-trivial.")
    if jax_dist is dist.ZeroSumNormal:
        pytest.skip("skip gof test for ZeroSumNormal")

    # BetaProportion is run with twice as many draws as everything else.
    draw_count = 20000 if "BetaProportion" in dist_name else 10000

    key = random.PRNGKey(0)
    d = jax_dist(*params)
    draws = d.sample(key=key, sample_shape=(draw_count,))
    density = np.exp(d.log_prob(draws))

    # Only ProjectedNormal passes an explicit dimension to the gof routine.
    dim = draws.shape[-1] - 1 if jax_dist is dist.ProjectedNormal else None

    # Flatten all batch dimensions into one axis so each batch element can be
    # tested independently.
    density = density.reshape(draw_count, -1)
    draws = draws.reshape(density.shape + d.event_shape)
    if "Dirichlet" in dist_name:
        # The Dirichlet density is over all but one of the probs.
        draws = draws[..., :-1]

    for batch_idx in range(density.shape[1]):
        try:
            gof = auto_goodness_of_fit(
                draws[:, batch_idx], density[:, batch_idx], dim=dim
            )
        except InvalidTest:
            # pytest.skip raises, so the assertion below is never reached.
            pytest.skip("expensive test")
        assert gof > TEST_FAILURE_RATE


@pytest.mark.parametrize("jax_dist, sp_dist, params", CONTINUOUS + DISCRETE)
def test_independent_shape(jax_dist, sp_dist, params):
d = jax_dist(*params)
Expand Down
60 changes: 60 additions & 0 deletions test/test_gof.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
from test_distributions import CONTINUOUS, DIRECTIONAL

import jax.random as random

import numpyro.distributions as dist
from numpyro.distributions.gof import InvalidTest, auto_goodness_of_fit

# Lower bound that every goodness-of-fit result must exceed in test_gof.
TEST_FAILURE_RATE = 2e-5 # For all goodness-of-fit tests.


@pytest.mark.parametrize("jax_dist, sp_dist, params", CONTINUOUS + DIRECTIONAL)
def test_gof(jax_dist, sp_dist, params):
    """Goodness-of-fit check between a distribution's samples and its density.

    Draws samples from ``jax_dist(*params)``, evaluates ``exp(log_prob)`` on
    them, and feeds each batch element to ``auto_goodness_of_fit``. The
    returned value must exceed ``TEST_FAILURE_RATE``.

    ``sp_dist`` is part of the shared parametrization and is not used here.
    """
    dist_name = jax_dist.__name__

    # Distributions that are excluded up front.
    if "Improper" in dist_name:
        pytest.skip("distribution has improper .log_prob()")
    if "LKJ" in dist_name or "Wishart" in dist_name:
        pytest.xfail("incorrect submanifold scaling")
    if jax_dist is dist.EulerMaruyama and jax_dist(*params).event_dim > 1:
        pytest.skip("EulerMaruyama skip test when event shape is non-trivial.")
    if jax_dist is dist.ZeroSumNormal:
        pytest.skip("skip gof test for ZeroSumNormal")
    if jax_dist is dist.MatrixNormal:
        pytest.skip(
            "skip gof test for MatrixNormal, likely incorrect submanifold scaling"
        )

    # These two distributions are run with twice as many draws as the rest.
    needs_more_draws = (
        "BetaProportion" in dist_name or "SineBivariateVonMises" in dist_name
    )
    draw_count = 20000 if needs_more_draws else 10000

    key = random.PRNGKey(0)
    d = jax_dist(*params)
    draws = d.sample(key=key, sample_shape=(draw_count,))
    density = np.exp(d.log_prob(draws))

    # Only ProjectedNormal passes an explicit dimension to the gof routine.
    dim = draws.shape[-1] - 1 if jax_dist is dist.ProjectedNormal else None

    # Flatten all batch dimensions into one axis so each batch element can be
    # tested independently.
    density = density.reshape(draw_count, -1)
    draws = draws.reshape(density.shape + d.event_shape)
    if "Dirichlet" in dist_name:
        # The Dirichlet density is over all but one of the probs.
        draws = draws[..., :-1]

    for batch_idx in range(density.shape[1]):
        try:
            gof = auto_goodness_of_fit(
                draws[:, batch_idx], density[:, batch_idx], dim=dim
            )
        except InvalidTest:
            # pytest.skip raises, so the assertion below is never reached.
            pytest.skip("expensive test")
        assert gof > TEST_FAILURE_RATE