Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 715616303
  • Loading branch information
niketkumar authored and Orbax Authors committed Jan 29, 2025
1 parent 1b39903 commit f63d713
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 17 deletions.
32 changes: 21 additions & 11 deletions checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1681,17 +1681,6 @@ def has(self, ty: Any) -> bool:
return False


GLOBAL_TYPE_HANDLER_REGISTRY = _TypeHandlerRegistryImpl(
(int, ScalarHandler()),
(float, ScalarHandler()),
(bytes, ScalarHandler()),
(np.number, ScalarHandler()),
(np.ndarray, NumpyHandler()),
(jax.Array, ArrayHandler()),
(str, StringHandler()),
)


def create_type_handler_registry(
*handlers: Tuple[Any, types.TypeHandler]
) -> types.TypeHandlerRegistry:
Expand All @@ -1707,6 +1696,22 @@ def create_type_handler_registry(
return _TypeHandlerRegistryImpl(*handlers)


_DEFAULT_TYPE_HANDLERS = tuple([
(int, ScalarHandler()),
(float, ScalarHandler()),
(bytes, ScalarHandler()),
(np.number, ScalarHandler()),
(np.ndarray, NumpyHandler()),
(jax.Array, ArrayHandler()),
(str, StringHandler()),
])


GLOBAL_TYPE_HANDLER_REGISTRY = create_type_handler_registry(
*_DEFAULT_TYPE_HANDLERS
)


def register_type_handler(
ty: Any,
handler: types.TypeHandler,
Expand Down Expand Up @@ -1744,6 +1749,11 @@ def has_type_handler(ty: Any) -> bool:
return GLOBAL_TYPE_HANDLER_REGISTRY.has(ty)


def supported_types() -> list[Any]:
"""Returns the default list of supported types."""
return [ty for ty, _ in _DEFAULT_TYPE_HANDLERS]


def register_standard_handlers_with_options(**kwargs):
"""Re-registers a select set of handlers with the given options.
Expand Down
43 changes: 37 additions & 6 deletions checkpoint/orbax/checkpoint/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
from __future__ import annotations

import asyncio
from collections.abc import Sequence
from concurrent import futures
import contextlib
import copy
import inspect
import time
import typing
Expand Down Expand Up @@ -48,6 +50,8 @@
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint._src.tree import utils as tree_utils

PyTree = Any


def sync_global_processes(name: str):
jax.experimental.multihost_utils.sync_global_devices(name)
Expand Down Expand Up @@ -514,22 +518,34 @@ def register_type_handler(ty, handler, func):
)


@contextlib.contextmanager
def global_type_handler_registry_context():
"""Context manager for changing the GLOBAL_TYPE_HANDLER_REGISTRY."""
original_type_handlers = copy.deepcopy(type_handlers._DEFAULT_TYPE_HANDLERS)
try:
yield
finally:
for original_type, original_handler in original_type_handlers:
type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY.add(
original_type, original_handler, override=True
)


@contextlib.contextmanager
def ocdbt_checkpoint_context(use_ocdbt: bool, ts_context: Any):
"""Use OCDBT driver within context."""
original_registry = list(
type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY._type_registry
)
original_type_handlers = copy.deepcopy(type_handlers._DEFAULT_TYPE_HANDLERS)
if use_ocdbt:
type_handlers.register_standard_handlers_with_options(
use_ocdbt=use_ocdbt, ts_context=ts_context
)
try:
yield
finally:
type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY._type_registry = (
original_registry
)
for original_type, original_handler in original_type_handlers:
type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY.add(
original_type, original_handler, override=True
)


def _get_test_wrapper_function(test_func):
Expand Down Expand Up @@ -724,3 +740,18 @@ def get_expected_chunk_shape(
if local_shape is None:
local_shape = arr.sharding.shard_shape(arr.shape)
return local_shape


def filter_metadata_fields(
pytree: PyTree, include_fields: Sequence[str]
) -> PyTree:
"""Returns a PyTree of dicts with keys in `include_fields`."""

def _include(metadata):
result = {}
for f in include_fields:
if hasattr(metadata, f):
result[f] = getattr(metadata, f)
return result

return jax.tree.map(_include, pytree)
1 change: 1 addition & 0 deletions checkpoint/orbax/checkpoint/type_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from orbax.checkpoint._src.serialization.type_handlers import has_type_handler
from orbax.checkpoint._src.serialization.type_handlers import register_standard_handlers_with_options
from orbax.checkpoint._src.serialization.type_handlers import register_type_handler
from orbax.checkpoint._src.serialization.type_handlers import supported_types

from orbax.checkpoint._src.serialization.type_handlers import is_ocdbt_checkpoint

Expand Down

0 comments on commit f63d713

Please sign in to comment.