From d044a6fc40601b3dd64dda6105326d118592949b Mon Sep 17 00:00:00 2001 From: Iurii Kemaev Date: Tue, 21 Mar 2023 08:09:16 -0700 Subject: [PATCH] Add ArrayDeviceTree and ArrayNumpyTree pytypes + improve ArrayTree. PiperOrigin-RevId: 518275284 --- chex/_src/pytypes.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/chex/_src/pytypes.py b/chex/_src/pytypes.py index 65cc8d51..da747ad9 100644 --- a/chex/_src/pytypes.py +++ b/chex/_src/pytypes.py @@ -14,7 +14,8 @@ # ============================================================================== """Type definitions to use for type annotations.""" -from typing import Any, Union +from typing import Any, Iterable, Mapping, Union + import jax import jax.numpy as jnp import numpy as np @@ -44,11 +45,14 @@ ] # A tree of generic arrays. - -# Should be Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]. -# Setting to Any for now to not break the existing code that depends on -# dynamically registered jax pytrees. -ArrayTree = Any +# Note that it does not support dynamically registered pytrees. +ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']] +ArrayDeviceTree = Union[ + ArrayDevice, Iterable['ArrayDeviceTree'], Mapping[Any, 'ArrayDeviceTree'] +] +ArrayNumpyTree = Union[ + ArrayNumpy, Iterable['ArrayNumpyTree'], Mapping[Any, 'ArrayNumpyTree'] +] # Other types. Scalar = Union[float, int]