From c412811726c4c30451e71fa8b068b9a1f632f454 Mon Sep 17 00:00:00 2001 From: Etienne Pot Date: Sat, 23 Nov 2024 11:01:54 -0800 Subject: [PATCH] Fix orbax error after new release PiperOrigin-RevId: 699508213 --- CHANGELOG.md | 1 + etils/enp/array_spec.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 93fe4979..9cb74e90 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ Changelog follow https://keepachangelog.com/ format. * `enp`: * Make `enp.testing.parametrize_xnp()` import only requested xnp modules. + * Fix orbax error when inspecting specs of an orbax checkpoint. * `ecolab`: * `ecolab.inspect`: Proto are better displayed (hide attributes `DESCRIPTOR`, `Extensions` in sub-section) diff --git a/etils/enp/array_spec.py b/etils/enp/array_spec.py index 0633faa1..91324b89 100644 --- a/etils/enp/array_spec.py +++ b/etils/enp/array_spec.py @@ -182,12 +182,12 @@ def _is_pygrain(array: Array) -> bool: def _is_orbax(array: Array) -> bool: if 'orbax.checkpoint' not in sys.modules: return False - from orbax import checkpoint as ocp # pylint: disable=g-import-not-at-top # pytype: disable=import-error + from orbax.checkpoint._src.serialization import type_handlers # pylint: disable=g-import-not-at-top # pytype: disable=import-error return isinstance( array, ( - ocp.type_handlers.ArrayMetadata, - ocp.type_handlers.ScalarMetadata, + type_handlers.ArrayMetadata, + type_handlers.ScalarMetadata, ), )