Skip to content

Commit

Permalink
filter out tests waiting for next tfp release (#1817)
Browse files Browse the repository at this point in the history
* filter oout tests waiting for next tfp release

* add warning mathc

* fix warnings

* put warnings up

* move warning to top init

* use tree_map

* ignore yet another test
  • Loading branch information
juanitorduz authored Jun 22, 2024
1 parent 40565d0 commit 9785376
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 2 deletions.
7 changes: 7 additions & 0 deletions numpyro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
# SPDX-License-Identifier: Apache-2.0

import logging
import warnings

warnings.filterwarnings(
"ignore", message=".*Attempting to hash a tracer.*", category=FutureWarning
)

# ruff: noqa: E402

from numpyro import compat, diagnostics, distributions, handlers, infer, ops, optim
from numpyro.distributions.distribution import enable_validation, validation_enabled
Expand Down
4 changes: 2 additions & 2 deletions numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def body_fn(wrapped_carry, x, prefix=None):
# return early if length = unroll_steps
if length == unroll_steps:
return wrapped_carry, (PytreeTrace({}), y0s)
wrapped_carry = device_put(wrapped_carry)
wrapped_carry = tree_map(device_put, wrapped_carry)
wrapped_carry, (pytree_trace, ys) = lax.scan(
body_fn, wrapped_carry, xs_, length - unroll_steps, reverse
)
Expand Down Expand Up @@ -324,7 +324,7 @@ def body_fn(wrapped_carry, x):

return (i + 1, rng_key, carry), (PytreeTrace(trace), y)

wrapped_carry = device_put((0, rng_key, init))
wrapped_carry = tree_map(device_put, (0, rng_key, init))
last_carry, (pytree_trace, ys) = lax.scan(
body_fn, wrapped_carry, xs, length=length, reverse=reverse
)
Expand Down
1 change: 1 addition & 0 deletions numpyro/infer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0


from numpyro.infer.barker import BarkerMH
from numpyro.infer.elbo import (
ELBO,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ known-jax = ["flax", "haiku", "jax", "optax", "tensorflow_probability"]
addopts = ["-v", "--color=yes"]
filterwarnings = [
"error",
"ignore:.*Attempting to hash a tracer:FutureWarning",
"ignore:numpy.ufunc size changed,:RuntimeWarning",
"ignore:Using a non-tuple sequence:FutureWarning",
"ignore:jax.tree_structure is deprecated:FutureWarning",
Expand Down
5 changes: 5 additions & 0 deletions test/contrib/test_tfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def f(x):
assert res.scale == 1


@pytest.mark.skip(reason="Waiting for the next tfp release")
@pytest.mark.filterwarnings("ignore:can't resolve package")
def test_transformed_distributions():
from tensorflow_probability.substrates.jax import (
Expand Down Expand Up @@ -113,6 +114,7 @@ def make_kernel_fn(target_log_prob_fn):
)


@pytest.mark.skip(reason="Waiting for the next tfp release")
@pytest.mark.parametrize(
"kernel, kwargs",
[
Expand Down Expand Up @@ -166,6 +168,7 @@ def model(data):
assert_allclose(jnp.mean(samples["loc"], 0), true_coef, atol=0.05)


@pytest.mark.skip(reason="Waiting for the next tfp release")
@pytest.mark.parametrize(
"kernel, kwargs",
[
Expand Down Expand Up @@ -243,6 +246,7 @@ def test_sample_tfp_distributions():

# test that sampling from unwrapped tensorflow_probability distributions works as
# expected using numpyro.sample primitive
@pytest.mark.skip(reason="Waiting for the next tfp release")
@pytest.mark.parametrize(
"dist,args",
[
Expand Down Expand Up @@ -270,6 +274,7 @@ def test_sample_unwrapped_tfp_distributions(dist, args):


# test mixture distributions
@pytest.mark.skip(reason="Waiting for the next tfp release")
def test_sample_unwrapped_mixture_same_family():
from tensorflow_probability.substrates.jax import distributions as tfd

Expand Down

0 comments on commit 9785376

Please sign in to comment.