Skip to content

Commit

Permalink
Test that exception is raised when size is not static
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Mar 9, 2023
1 parent 2fa3f58 commit bf58a94
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions tests/link/jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,12 +600,10 @@ def test_random_concrete_shape_subtensor_tuple():
assert jax_fn(np.ones((2, 3))).shape == (2,)


@pytest.mark.xfail(
reason="`size_at` should be specified as a static argument", strict=True
)
def test_random_concrete_shape_graph_input():
"""JAX cannot JIT-compile random variables whose `size` argument is not static."""
rng = shared(np.random.RandomState(123))
size_at = at.scalar()
out = at.random.normal(0, 1, size=size_at, rng=rng)
jax_fn = function([size_at], out, mode=jax_mode)
assert jax_fn(10).shape == (10,)
with pytest.raises(NotImplementedError, match=r".* concrete values .*"):
function([size_at], out, mode=jax_mode)

0 comments on commit bf58a94

Please sign in to comment.