diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index ade208754d..e8ecc42d05 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -600,12 +600,10 @@ def test_random_concrete_shape_subtensor_tuple(): assert jax_fn(np.ones((2, 3))).shape == (2,) -@pytest.mark.xfail( - reason="`size_at` should be specified as a static argument", strict=True -) def test_random_concrete_shape_graph_input(): + """JAX cannot JIT-compile random variables whose `size` argument is not static.""" rng = shared(np.random.RandomState(123)) size_at = at.scalar() out = at.random.normal(0, 1, size=size_at, rng=rng) - jax_fn = function([size_at], out, mode=jax_mode) - assert jax_fn(10).shape == (10,) + with pytest.raises(NotImplementedError, match=r".* concrete values .*"): + function([size_at], out, mode=jax_mode)