diff --git a/CHANGELOG.md b/CHANGELOG.md index cb103a7f..c40132c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ Changelog follow https://keepachangelog.com/ format. ## [Unreleased] * `epy`: Add frozen dataclass support for `epy.ContextManager` +* `enp`: Add `ArraySpec` support for `grain.python.SharedMemoryArrays`. ## [1.9.4] - 2024-09-03 diff --git a/etils/enp/array_spec.py b/etils/enp/array_spec.py index 9caa12d4..0633faa1 100644 --- a/etils/enp/array_spec.py +++ b/etils/enp/array_spec.py @@ -115,6 +115,9 @@ def from_array(cls, array: Array) -> Optional[ArraySpec]: elif _is_grain(array): shape = array.shape dtype = array.dtype + elif _is_pygrain(array): + shape = array.shape + dtype = array.dtype elif _is_orbax(array): shape = array.shape dtype = array.dtype @@ -143,6 +146,7 @@ def is_fake_array(array: Array) -> bool: or isinstance(array, ArraySpec) or _is_orbax(array) or _is_grain(array) + or _is_pygrain(array) or _is_flax_summarry(array) or isinstance(array, array_types.ArrayAliasMeta) ) @@ -164,6 +168,17 @@ def _is_grain(array: Array) -> bool: return isinstance(array, grain.ArraySpec) +def _is_pygrain(array: Array) -> bool: + if ( + 'grain._src.python' not in sys.modules + and 'grain.python' not in sys.modules + ): + return False + from grain._src.python import shared_memory_array # pylint: disable=g-import-not-at-top # pytype: disable=import-error + + return isinstance(array, shared_memory_array.SharedMemoryArrayMetadata) + + def _is_orbax(array: Array) -> bool: if 'orbax.checkpoint' not in sys.modules: return False