Skip to content

Commit

Permalink
Add ArrayDeviceTree and ArrayNumpyTree pytypes + improve ArrayTree.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 518275284
  • Loading branch information
hbq1 authored and ChexDev committed Mar 22, 2023
1 parent 48f60f0 commit 109da53
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions chex/_src/pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
# ==============================================================================
"""Type definitions to use for type annotations."""

from typing import Any, Union
from typing import Any, Iterable, Mapping, Union

import jax
import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -44,11 +45,13 @@
]

# A tree of generic arrays.

# Should be Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']].
# Setting to Any for now to not break the existing code that depends on
# dynamically registered jax pytrees.
ArrayTree = Any
ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]
ArrayDeviceTree = Union[
ArrayDevice, Iterable['ArrayDeviceTree'], Mapping[Any, 'ArrayDeviceTree']
]
ArrayNumpyTree = Union[
ArrayNumpy, Iterable['ArrayNumpyTree'], Mapping[Any, 'ArrayNumpyTree']
]

# Other types.
Scalar = Union[float, int]
Expand Down

0 comments on commit 109da53

Please sign in to comment.