Skip to content

Commit

Permalink
Expand StepMetadata.item_handlers type to include non-composite cases.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707324273
  • Loading branch information
BlaziusMaximus authored and Orbax Authors committed Dec 18, 2024
1 parent a54e71c commit 6faf7af
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,13 @@ def metadata(self, directory: epath.Path) -> StepMetadata:
saved_metadata = step_metadata_serialization.deserialize(
serialized_metadata or {}
)
item_handlers = saved_metadata.item_handlers or {}
if saved_metadata.item_handlers is not None:
assert isinstance(saved_metadata.item_handlers, dict)
item_handlers: dict[str, checkpoint.CheckpointHandlerTypeStr] = (
saved_metadata.item_handlers
)
else:
item_handlers: dict[str, checkpoint.CheckpointHandlerTypeStr] = {}
item_metadata = dict(saved_metadata.item_metadata or {})
assert item_handlers.keys() == item_metadata.keys()

Expand Down
11 changes: 8 additions & 3 deletions checkpoint/orbax/checkpoint/_src/metadata/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import dataclasses
import json
import threading
from typing import Any, Protocol, TypeVar
from typing import Any, Mapping, Protocol, TypeAlias, TypeVar

from absl import logging
from etils import epath
Expand All @@ -31,6 +31,8 @@
_LEGACY_ROOT_METADATA_FILENAME = 'metadata'

ItemMetadata = composite.Composite
CompositeCheckpointHandlerTypeStrs: TypeAlias = Mapping
CheckpointHandlerTypeStr = str
StepStatistics = step_statistics.SaveStepStatistics
SerializedMetadata = TypeVar('SerializedMetadata', bound=dict[str, Any])

Expand Down Expand Up @@ -67,7 +69,8 @@ class StepMetadata:
Attributes:
format: The checkpoint file format. Users should specify the format
explicitly when using something non-standard.
item_handlers: Map of item name to its checkpoint handler.
item_handlers: Map of item name to its checkpoint handler, or a single
checkpoint handler for non-composite checkpoints.
item_metadata: Map of item name to its metadata.
metrics: User-provided metrics (accuracy, loss, etc.)
performance_metrics: Performance metrics (time, memory, etc.)
Expand All @@ -79,7 +82,9 @@ class StepMetadata:
"""

format: str | None = None
item_handlers: dict[str, str] = dataclasses.field(default_factory=dict)
item_handlers: (
dict[str, CheckpointHandlerTypeStr] | CheckpointHandlerTypeStr | None
) = None
item_metadata: ItemMetadata | None = None
metrics: dict[str, Any] = dataclasses.field(default_factory=dict)
performance_metrics: StepStatistics = dataclasses.field(
Expand Down
9 changes: 9 additions & 0 deletions checkpoint/orbax/checkpoint/_src/metadata/checkpoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,15 @@ def test_unknown_key_in_metadata(
):
self.deserialize_metadata(metadata_class, serialized_metadata)

@parameterized.named_parameters(
    ('single', 'a_handler'),
    ('composite', {'a': 'a_handler'}),
)
def test_deserialize_item_handlers(self, item_handlers):
  """Checks `item_handlers` survives deserialization unchanged.

  Covers both accepted shapes: a single handler type string and a mapping
  of item name to handler type string.
  """
  deserialized = self.deserialize_metadata(
      StepMetadata, {'item_handlers': item_handlers}
  )
  self.assertEqual(deserialized.item_handlers, item_handlers)


if __name__ == '__main__':
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,39 @@

"""Utilities for serializing and deserializing metadata."""

from typing import Any
from typing import Any, Sequence


def validate_field(
obj: Any,
field_name: str,
field_type: type[Any]
field_type: type[Any] | Sequence[type[Any]]
):
"""Validates a single field in a dictionary.
field_type can optionally be a sequence of types, in which case the field
must be of any one of the types in the sequence.
Args:
obj: The object to validate.
field_name: The name of the field to validate.
field_type: The type (or sequence of types) of the field to validate.
"""
if field_name not in obj or obj[field_name] is None:
return
field = obj[field_name]
if not isinstance(field, field_type):
raise ValueError(
'StepMetadata {} must be of type {}, got {}.'.format(
field_name, field_type, type(field)
)
)
if isinstance(field_type, Sequence):
if not any(isinstance(field, f_type) for f_type in field_type):
raise ValueError(
f'StepMetadata field "{field_name}" must be of type {field_type}, '
f'got {type(field)}.'
)
else:
if not isinstance(field, field_type):
raise ValueError(
f'StepMetadata field "{field_name}" must be of type {field_type}, '
f'got {type(field)}.'
)


def validate_dict_entry(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
SerializedMetadata = checkpoint.SerializedMetadata
StepMetadata = checkpoint.StepMetadata
ItemMetadata = checkpoint.ItemMetadata
CompositeCheckpointHandlerTypeStrs = (
checkpoint.CompositeCheckpointHandlerTypeStrs
)
CheckpointHandlerTypeStr = checkpoint.CheckpointHandlerTypeStr
StepStatistics = step_statistics.SaveStepStatistics


Expand Down Expand Up @@ -79,12 +83,14 @@ def deserialize(
utils.validate_field(metadata_dict, 'format', str)
validated_metadata_dict['format'] = metadata_dict.get('format', None)

utils.validate_field(metadata_dict, 'item_handlers', dict)
for k in metadata_dict.get('item_handlers', {}) or {}:
utils.validate_dict_entry(metadata_dict, 'item_handlers', k, str)
validated_metadata_dict['item_handlers'] = metadata_dict.get(
'item_handlers', {}
)
utils.validate_field(metadata_dict, 'item_handlers', [dict, str])
item_handlers = metadata_dict.get('item_handlers')
if isinstance(item_handlers, CompositeCheckpointHandlerTypeStrs):
for k in metadata_dict.get('item_handlers', {}) or {}:
utils.validate_dict_entry(metadata_dict, 'item_handlers', k, str)
validated_metadata_dict['item_handlers'] = item_handlers
elif isinstance(item_handlers, CheckpointHandlerTypeStr):
validated_metadata_dict['item_handlers'] = item_handlers

utils.validate_field(metadata_dict, 'item_metadata', dict)
for k in metadata_dict.get('item_metadata', {}) or {}:
Expand Down

0 comments on commit 6faf7af

Please sign in to comment.