diff --git a/chex/_src/pytypes.py b/chex/_src/pytypes.py index 65cc8d51..e6326f56 100644 --- a/chex/_src/pytypes.py +++ b/chex/_src/pytypes.py @@ -14,7 +14,8 @@ # ============================================================================== """Type definitions to use for type annotations.""" -from typing import Any, Union +from typing import Any, Iterable, Mapping, Union + import jax import jax.numpy as jnp import numpy as np @@ -44,11 +45,13 @@ ] # A tree of generic arrays. - -# Should be Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]. -# Setting to Any for now to not break the existing code that depends on -# dynamically registered jax pytrees. -ArrayTree = Any +ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']] +ArrayDeviceTree = Union[ + ArrayDevice, Iterable['ArrayDeviceTree'], Mapping[Any, 'ArrayDeviceTree'] +] +ArrayNumpyTree = Union[ + ArrayNumpy, Iterable['ArrayNumpyTree'], Mapping[Any, 'ArrayNumpyTree'] +] # Other types. Scalar = Union[float, int]