Skip to content

Commit

Permalink
pygrain support for etils.enp.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675149949
  • Loading branch information
Jan Hosang authored and The etils Authors committed Sep 16, 2024
1 parent 0b5ad73 commit f717193
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Changelog follow https://keepachangelog.com/ format.
## [Unreleased]

* `epy`: Add frozen dataclass support for `epy.ContextManager`
* `enp`: Add `ArraySpec` support for `grain.python.SharedMemoryArrays`.

## [1.9.4] - 2024-09-03

Expand Down
15 changes: 15 additions & 0 deletions etils/enp/array_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ def from_array(cls, array: Array) -> Optional[ArraySpec]:
elif _is_grain(array):
shape = array.shape
dtype = array.dtype
elif _is_pygrain(array):
shape = array.shape
dtype = array.dtype
elif _is_orbax(array):
shape = array.shape
dtype = array.dtype
Expand Down Expand Up @@ -143,6 +146,7 @@ def is_fake_array(array: Array) -> bool:
or isinstance(array, ArraySpec)
or _is_orbax(array)
or _is_grain(array)
or _is_pygrain(array)
or _is_flax_summarry(array)
or isinstance(array, array_types.ArrayAliasMeta)
)
Expand All @@ -164,6 +168,17 @@ def _is_grain(array: Array) -> bool:
return isinstance(array, grain.ArraySpec)


def _is_pygrain(array: Array) -> bool:
if (
'grain._src.python' not in sys.modules
and 'grain.python' not in sys.modules
):
return False
from grain._src.python import shared_memory_array # pylint: disable=g-import-not-at-top # pytype: disable=import-error

return isinstance(array, shared_memory_array.SharedMemoryArrayMetadata)


def _is_orbax(array: Array) -> bool:
if 'orbax.checkpoint' not in sys.modules:
return False
Expand Down

0 comments on commit f717193

Please sign in to comment.