Skip to content

Commit

Permalink
Add an xfail for newer JAX versions that change sampler size behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Dec 15, 2021
1 parent e4ffb02 commit d648969
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions tests/link/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1207,6 +1207,10 @@ def test_extra_ops_omni():
compare_jax_and_py(fgraph, [])


@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.26"),
reason="JAX samplers require concrete/static shape values?",
)
@pytest.mark.parametrize(
"at_dist, dist_params, rng, size",
[
Expand Down

0 comments on commit d648969

Please sign in to comment.