Skip to content

Commit

Permalink
haiku.base: handle both new-style and old-style PRNG keys
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 534172296
  • Loading branch information
Jake VanderPlas authored and copybara-github committed May 22, 2023
1 parent bf933a8 commit dc9bfd3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
13 changes: 9 additions & 4 deletions haiku/_src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,9 +994,11 @@ def assert_is_prng_key(key: PRNGKey):
# device-to-host copy.
make_error = lambda: ValueError( # pylint: disable=g-long-lambda
f"The provided key is not a JAX PRNGKey but a {type(key)}:\n{key}")
if jax.config.jax_enable_custom_prng:
if not isinstance(key, jax.random.KeyArray):
raise make_error()
if (jax.config.jax_enable_custom_prng and
not isinstance(key, jax.random.PRNGKeyArray)):
raise make_error()

if isinstance(key, jax.random.PRNGKeyArray):
if key.shape:
raise ValueError(
"Provided key did not have expected shape and/or dtype: "
Expand Down Expand Up @@ -1042,7 +1044,10 @@ class PRNGSequence(Iterator[PRNGKey]):

def __init__(self, key_or_seed: Union[PRNGKey, int, PRNGSequenceState]):
"""Creates a new :class:`PRNGSequence`."""
if isinstance(key_or_seed, tuple):
if isinstance(key_or_seed, jax.random.PRNGKeyArray):
self._key = jnp.asarray(key_or_seed)
self._subkeys = collections.deque()
elif isinstance(key_or_seed, tuple):
key, subkeys = key_or_seed
assert_is_prng_key(key)
for subkey in subkeys:
Expand Down
14 changes: 10 additions & 4 deletions haiku/_src/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
import jax.numpy as jnp
import numpy as np

# TODO(jakevdp): remove attr check once jax>=0.4.11 is required
KEY_FUNCS = [jax.random.PRNGKey]
if hasattr(jax.random, "key"):
KEY_FUNCS.append(jax.random.key)

# TODO(tomhennigan) Improve test coverage.

custom_state_creator = functools.partial(
Expand Down Expand Up @@ -597,18 +602,19 @@ def test_new_state_in_apply(self):
self.assertEqual(ctx.collect_state(), {"~": {"count": 1}})

@parameterized.product(
seed=[42, 28], wrap_seed=[True, False], jitted=[True, False])
def test_prng_sequence(self, seed, wrap_seed, jitted):
seed=[42, 28], wrap_seed=[True, False], jitted=[True, False],
key_func=KEY_FUNCS)
def test_prng_sequence(self, seed, wrap_seed, jitted, key_func):
def create_random_values(key_or_seed):
key_seq = base.PRNGSequence(key_or_seed)
return (jax.random.normal(next(key_seq), []),
jax.random.normal(next(key_seq), []))
# Values using our sequence.
key_or_seed = jax.random.PRNGKey(seed) if wrap_seed else seed
key_or_seed = key_func(seed) if wrap_seed else seed
seq_v1, seq_v2 = (jax.jit(create_random_values)(key_or_seed)
if jitted else create_random_values(key_or_seed))
# Generate values using manual splitting.
key = jax.random.PRNGKey(seed)
key = key_func(seed)
key, temp_key = jax.random.split(key)
raw_v1 = jax.random.normal(temp_key, [])
_, temp_key = jax.random.split(key)
Expand Down

0 comments on commit dc9bfd3

Please sign in to comment.