Skip to content

Commit

Permalink
haiku: avoid calls to deprecated jax.random APIs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 575590991
  • Loading branch information
Jake VanderPlas authored and copybara-github committed Oct 22, 2023
1 parent 7a6faba commit 86a00ea
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
1 change: 1 addition & 0 deletions haiku/_src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,7 @@ hk_py_test(
# pip: absl/testing:absltest
# pip: absl/testing:parameterized
# pip: jax
# pip: jax:extend
# pip: numpy
],
)
Expand Down
6 changes: 3 additions & 3 deletions haiku/_src/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from haiku._src import random
from haiku._src import transform
import jax
from jax import prng
import jax.extend as jex
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -97,7 +97,7 @@ def count_splits(_, num):
num = tuple(num) if isinstance(num, Sequence) else (num,)
return jnp.zeros((*num, 13), np.uint32)

differently_shaped_prng_impl = prng.PRNGImpl(
differently_shaped_prng_impl = jex.random.PRNGImpl(
# Testing a different key shape to make sure it's accepted by Haiku
key_shape=(13,),
seed=lambda _: jnp.zeros((13,), np.uint32),
Expand All @@ -109,7 +109,7 @@ def count_splits(_, num):
init, _ = transform.transform(base.next_rng_key)
if do_jit:
init = jax.jit(init)
key = prng.seed_with_impl(differently_shaped_prng_impl, 42)
key = jex.random.seed_with_impl(differently_shaped_prng_impl, 42)
init(key)
self.assertEqual(count, 1)

Expand Down
4 changes: 2 additions & 2 deletions requirements-jax.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
jax>=0.4.13
jaxlib>=0.4.13
jax>=0.4.16
jaxlib>=0.4.16

0 comments on commit 86a00ea

Please sign in to comment.