From b7ba49d3d1a0eb58abe53cc657ff7127837788a8 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 13 Sep 2023 12:49:29 -0700 Subject: [PATCH] alias chex.PRNGKey to jax.Array Starting with jax v0.4.16 and going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see [JEP 9263](https://github.com/google/jax/pull/17297) for details) Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations far more strict than they were previously. PiperOrigin-RevId: 565133147 --- chex/_src/pytypes.py | 2 +- requirements/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/chex/_src/pytypes.py b/chex/_src/pytypes.py index 4996f432..66d41a8a 100644 --- a/chex/_src/pytypes.py +++ b/chex/_src/pytypes.py @@ -51,7 +51,7 @@ Scalar = Union[float, int] Numeric = Union[Array, Scalar] Shape = jax.core.Shape -PRNGKey = Union[jax.random.KeyArray, jax.Array] +PRNGKey = jax.Array PyTreeDef = jax.tree_util.PyTreeDef Device = jax.Device ArrayDType = type(jnp.float32) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 34e84f1a..3b349c67 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,6 +1,6 @@ absl-py>=0.9.0 typing_extensions>=4.2.0 -jax>=0.4.6 +jax>=0.4.16 jaxlib>=0.1.37 numpy>=1.24.1 toolz>=0.9.0