diff --git a/haiku/_src/BUILD b/haiku/_src/BUILD index 918bf3687..9133b7ed8 100644 --- a/haiku/_src/BUILD +++ b/haiku/_src/BUILD @@ -960,6 +960,7 @@ hk_py_test( # pip: absl/testing:absltest # pip: absl/testing:parameterized # pip: jax + # pip: jax:extend # pip: numpy ], ) diff --git a/haiku/_src/random_test.py b/haiku/_src/random_test.py index ad891dd73..0240cb165 100644 --- a/haiku/_src/random_test.py +++ b/haiku/_src/random_test.py @@ -23,7 +23,7 @@ from haiku._src import random from haiku._src import transform import jax -from jax import prng +import jax.extend as jex import jax.numpy as jnp import numpy as np @@ -97,7 +97,7 @@ def count_splits(_, num): num = tuple(num) if isinstance(num, Sequence) else (num,) return jnp.zeros((*num, 13), np.uint32) - differently_shaped_prng_impl = prng.PRNGImpl( + differently_shaped_prng_impl = jex.random.PRNGImpl( # Testing a different key shape to make sure it's accepted by Haiku key_shape=(13,), seed=lambda _: jnp.zeros((13,), np.uint32), @@ -109,7 +109,7 @@ def count_splits(_, num): init, _ = transform.transform(base.next_rng_key) if do_jit: init = jax.jit(init) - key = prng.seed_with_impl(differently_shaped_prng_impl, 42) + key = jex.random.seed_with_impl(differently_shaped_prng_impl, 42) init(key) self.assertEqual(count, 1) diff --git a/requirements-jax.txt b/requirements-jax.txt index adcce858f..a305fd19e 100644 --- a/requirements-jax.txt +++ b/requirements-jax.txt @@ -1,2 +1,2 @@ -jax>=0.4.13 -jaxlib>=0.4.13 +jax>=0.4.16 +jaxlib>=0.4.16