Skip to content

Commit

Permalink
Update handling of typed PRNG keys
Browse files Browse the repository at this point in the history
This follows the recommendations of [JEP 9263](jax-ml/jax#17297)
  • Loading branch information
jakevdp committed Sep 8, 2023
1 parent bec1c80 commit 9a1496f
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,19 +1189,24 @@ def _is_valid_variables(variables: VariableDict) -> bool:

def _is_valid_rng(rng: Array):
"""Checks whether rng is a valid JAX PRNGKey, also handling custom prngs."""
# New-style JAX KeyArrays have a base type.
if jax_config.jax_enable_custom_prng: # type: ignore[attr-defined]
if not isinstance(rng, jax.random.KeyArray):
return False
# Old-style JAX PRNGKeys are plain uint32 arrays.
else:
if not isinstance(rng, (np.ndarray, jnp.ndarray)):
return False
if (
rng.shape != random.default_prng_impl().key_shape
or rng.dtype != jnp.uint32
):
return False
# This check is valid for either new-style or old-style PRNG keys
if not isinstance(rng, (np.ndarray, jnp.ndarray)):
return False

# Handle new-style typed PRNG keys
if hasattr(jax.dtypes, 'prng_key'): # JAX 0.4.14 or newer
if jax.dtypes.issubdtype(rng.dtype, jax.dtypes.prng_key):
return True
elif hasattr(jax.random, 'KeyArray'): # Previous JAX versions
if isinstance(rng, jax.random.KeyArray):
return True

# Handle old-style raw PRNG keys
if (
rng.shape != random.default_prng_impl().key_shape
or rng.dtype != jnp.uint32
):
return False
return True


Expand Down

0 comments on commit 9a1496f

Please sign in to comment.