Refactor metadata de/serialization.
PiperOrigin-RevId: 715155343
BlaziusMaximus authored and Orbax Authors committed Jan 14, 2025
1 parent 2a7e309 commit c391304
Showing 5 changed files with 237 additions and 227 deletions.
4 changes: 4 additions & 0 deletions checkpoint/orbax/checkpoint/_src/metadata/BUILD
@@ -98,6 +98,10 @@ py_library(
py_library(
name = "metadata_serialization_utils",
srcs = ["metadata_serialization_utils.py"],
deps = [
":checkpoint",
"//checkpoint/orbax/checkpoint/_src/logging:step_statistics",
],
)

py_library(
67 changes: 23 additions & 44 deletions checkpoint/orbax/checkpoint/_src/metadata/checkpoint_test.py
@@ -512,11 +512,10 @@ def test_deserialize_wrong_types_root_metadata(
@parameterized.parameters(
({'item_handlers': list()},),
({'item_handlers': {int(): None}},),
({'metrics': list()},),
({'metrics': [int()]},),
({'metrics': {int(): None}},),
({'performance_metrics': list()},),
({'performance_metrics': {int(): float()}},),
({'performance_metrics': {str(): int()}},),
({'init_timestamp_nsecs': float()},),
({'commit_timestamp_nsecs': float()},),
({'custom': list()},),
@@ -566,7 +565,7 @@ def test_deserialize_item_metadata_with_item_metadata_kwarg(
@parameterized.parameters(
('metrics', 1, dict, int),
)
def test_validate_type_wrong_type(
def test_deserialize_metrics_kwarg(
self, kwarg_name, kwarg_value, expected_type, wrong_type
):
with self.assertRaisesRegex(
@@ -576,74 +575,54 @@ def test_validate_type_wrong_type(
self.deserialize_metadata(StepMetadata, {}, **{kwarg_name: kwarg_value})

@parameterized.parameters(
({'metrics': 1}, 'metrics', dict, int),
({'performance_metrics': 1}, 'performance_metrics', dict, int),
({'init_timestamp_nsecs': 'a'}, 'init_timestamp_nsecs', int, str),
({'commit_timestamp_nsecs': 'a'}, 'commit_timestamp_nsecs', int, str),
({'custom': 1}, 'custom', dict, int),
({'metrics': 1}, dict, int),
({'init_timestamp_nsecs': 'a'}, int, str),
({'commit_timestamp_nsecs': 'a'}, int, str),
({'custom': 1}, dict, int),
)
def test_validate_field_wrong_type(
self, step_metadata, field_name, expected_field_type, wrong_field_type
self, step_metadata, expected_field_type, wrong_field_type
):
with self.assertRaisesRegex(
ValueError,
f'Metadata field "{field_name}" must be of type '
f'{expected_field_type}, got {wrong_field_type}.',
f'Object must be of type {expected_field_type}, got {wrong_field_type}',
):
self.deserialize_metadata(StepMetadata, step_metadata)

@parameterized.parameters(
({'item_handlers': 1}, 'item_handlers', int),
({'item_handlers': 1}, int),
({'performance_metrics': 1}, int),
)
def test_validate_field_sequence_wrong_type(
self, step_metadata, field_name, wrong_field_type
self, step_metadata, wrong_field_type
):
with self.assertRaisesRegex(
ValueError,
f'Metadata field "{field_name}" must be any one of types '
f'Object must be any one of types '
r'\[.+]'
f', got {wrong_field_type}.',
):
self.deserialize_metadata(StepMetadata, step_metadata)

@parameterized.parameters(
({'item_handlers': {1: 'a'}}, 'item_handlers', str, int),
({'metrics': {1: 'a'}}, 'metrics', str, int),
(
{'performance_metrics': {1: 1.0}},
'performance_metrics',
str,
int,
),
({'custom': {1: 'a'}}, 'custom', str, int),
({'item_handlers': {1: 'a'}}, str, int),
({'metrics': {1: 'a'}}, str, int),
({'performance_metrics': {1: 1.0}}, str, int),
({'custom': {1: 'a'}}, str, int),
)
def test_validate_dict_entry_wrong_key_type(
self, step_metadata, field_name, expected_key_type, wrong_key_type
self, step_metadata, expected_key_type, wrong_key_type
):
with self.assertRaisesRegex(
ValueError,
f'Metadata field "{field_name}" keys must be of type '
f'{expected_key_type}, got {wrong_key_type}.',
):
self.deserialize_metadata(StepMetadata, step_metadata)

@parameterized.parameters(
({'performance_metrics': {'a': 1}}, 'performance_metrics', float, int),
)
def test_validate_dict_entry_wrong_value_type(
self, step_metadata, field_name, expected_value_type, wrong_value_type
):
with self.assertRaisesRegex(
ValueError,
f'Metadata field "{field_name}" values must be of type '
f'{expected_value_type}, got {wrong_value_type}.',
f'Object must be of type {expected_key_type}, got {wrong_key_type}.',
):
self.deserialize_metadata(StepMetadata, step_metadata)

@parameterized.parameters(
({'item_handlers': {'a': 'b'}},),
({'performance_metrics': {'a': 1.0}},),
({'user_metadata': {'a': 1}, 'init_timestamp_nsecs': 1},),
({'custom': {'a': 1}, 'init_timestamp_nsecs': 1},),
)
def test_serialize_for_update_valid_kwargs(
self, kwargs: dict[str, Any]
@@ -656,13 +635,13 @@ def test_serialize_for_update_valid_kwargs(
@parameterized.parameters(
({'item_handlers': list()},),
({'item_handlers': {int(): None}},),
({'metrics': list()},),
({'metrics': [int()]},),
({'metrics': {int(): None}},),
({'performance_metrics': list()},),
({'init_timestamp_nsecs': float()},),
({'commit_timestamp_nsecs': float()},),
({'user_metadata': list()},),
({'user_metadata': {int(): None}},),
({'custom': list()},),
({'custom': {int(): None}},),
)
def test_serialize_for_update_wrong_types(
self, kwargs: dict[str, Any]
@@ -675,7 +654,7 @@ def test_serialize_for_update_with_unknown_kwargs(self):
ValueError, 'Provided metadata contains unknown key blah'
):
step_metadata_serialization.serialize_for_update(
user_metadata={'a': 1},
custom={'a': 1},
blah=123,
)
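
For reference, the updated tests assert the generic error text produced by the shared type check rather than the old per-field wording. A minimal sketch of the behaviour they exercise (assuming an installed orbax-checkpoint and calling the utilities module directly, whereas the tests go through a deserialization helper):

from orbax.checkpoint._src.metadata import metadata_serialization_utils as utils

try:
  # 'metrics' must be a dict, so passing an int trips the shared type check.
  utils.validate_and_process_metrics(1)
except ValueError as e:
  print(e)  # Object must be of type <class 'dict'>, got <class 'int'>.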

checkpoint/orbax/checkpoint/_src/metadata/metadata_serialization_utils.py
@@ -14,62 +14,148 @@

"""Utilities for serializing and deserializing metadata."""

from typing import Any, Sequence
import dataclasses
from typing import Any, Optional, Sequence

from absl import logging
from orbax.checkpoint._src.logging import step_statistics
from orbax.checkpoint._src.metadata import checkpoint

def validate_type(obj: Any, field_type: type[Any]):
if not isinstance(obj, field_type):
raise ValueError(f'Object must be of type {field_type}, got {type(obj)}.')

CompositeCheckpointHandlerTypeStrs = (
checkpoint.CompositeCheckpointHandlerTypeStrs
)
CheckpointHandlerTypeStr = checkpoint.CheckpointHandlerTypeStr
CompositeItemMetadata = checkpoint.CompositeItemMetadata
SingleItemMetadata = checkpoint.SingleItemMetadata
StepStatistics = step_statistics.SaveStepStatistics


def validate_field(
    obj: Any,
    field_name: str,
    field_type: type[Any] | Sequence[type[Any]]
):
  """Validates a single field in a dictionary.

  field_type can optionally be a sequence of types, in which case the field
  must be of any one of the types in the sequence.

  Args:
    obj: The object to validate.
    field_name: The name of the field to validate.
    field_type: The type (or sequence of types) of the field to validate.
  """
  if field_name not in obj or obj[field_name] is None:
    return
  field = obj[field_name]
  if isinstance(field_type, Sequence):
    if not any(isinstance(field, f_type) for f_type in field_type):
      raise ValueError(
          f'Metadata field "{field_name}" must be any one of '
          f'types {list(field_type)}, got {type(field)}.'
      )
  elif not isinstance(field, field_type):
    raise ValueError(
        f'Metadata field "{field_name}" must be of type {field_type}, '
        f'got {type(field)}.'
    )


def _validate_type(obj: Any, field_type: type[Any] | Sequence[type[Any]]):
  if isinstance(field_type, Sequence):
    if not any(isinstance(obj, f_type) for f_type in field_type):
      raise ValueError(
          f'Object must be any one of types {list(field_type)}, got '
          f'{type(obj)}.'
      )
  elif not isinstance(obj, field_type):
    raise ValueError(f'Object must be of type {field_type}, got {type(obj)}.')


def validate_and_process_item_handlers(
item_handlers: Any,
) -> CompositeCheckpointHandlerTypeStrs | CheckpointHandlerTypeStr | None:
"""Validates and processes item_handlers field."""
if item_handlers is None:
return None

_validate_type(item_handlers, [dict, str])
if isinstance(item_handlers, CompositeCheckpointHandlerTypeStrs):
for k in item_handlers or {}:
_validate_type(k, str)
return item_handlers
elif isinstance(item_handlers, CheckpointHandlerTypeStr):
return item_handlers


def validate_and_process_item_metadata(
item_metadata: Any,
) -> CompositeItemMetadata | SingleItemMetadata | None:
"""Validates and processes item_metadata field."""
if item_metadata is None:
return None

if isinstance(item_metadata, CompositeItemMetadata):
_validate_type(item_metadata, dict)
for k in item_metadata:
_validate_type(k, str)
return item_metadata
else:
return item_metadata


def validate_and_process_metrics(
metrics: Any,
additional_metrics: Optional[Any] = None
) -> dict[str, Any]:
"""Validates and processes metrics field."""
metrics = metrics or {}

_validate_type(metrics, dict)
for k in metrics:
_validate_type(k, str)
validated_metrics = metrics

if additional_metrics is not None:
_validate_type(additional_metrics, dict)
for k, v in additional_metrics.items():
_validate_type(k, str)
validated_metrics[k] = v

return validated_metrics


def validate_dict_entry(
    dict_field: dict[Any, Any],
    dict_field_name: str,
    key: Any,
    key_type: type[Any],
    value_type: type[Any] | None = None,
):
  """Validates a single entry in a dictionary field."""
  if not isinstance(key, key_type):
    raise ValueError(
        f'Metadata field "{dict_field_name}" keys must be of type {key_type}, '
        f'got {type(key)}.'
    )
  if value_type is not None:
    dict_field = dict_field[dict_field_name]
    if not isinstance(dict_field[key], value_type):
      raise ValueError(
          f'Metadata field "{dict_field_name}" values must be of '
          f'type {value_type}, got {type(dict_field[key])}.'
      )


def validate_and_process_performance_metrics(
performance_metrics: Any,
) -> dict[str, float]:
"""Validates and processes performance_metrics field."""
if performance_metrics is None:
return {}

_validate_type(performance_metrics, [dict, StepStatistics])
if isinstance(performance_metrics, StepStatistics):
performance_metrics = dataclasses.asdict(performance_metrics)

for k in performance_metrics:
_validate_type(k, str)

return {
metric: val
for metric, val in performance_metrics.items()
if isinstance(val, float)
}


def validate_and_process_init_timestamp_nsecs(
init_timestamp_nsecs: Any,
) -> int | None:
"""Validates and processes init_timestamp_nsecs field."""
if init_timestamp_nsecs is None:
return None

_validate_type(init_timestamp_nsecs, int)
return init_timestamp_nsecs


def validate_and_process_commit_timestamp_nsecs(
commit_timestamp_nsecs: Any,
) -> int | None:
"""Validates and processes commit_timestamp_nsecs field."""
if commit_timestamp_nsecs is None:
return None

_validate_type(commit_timestamp_nsecs, int)
return commit_timestamp_nsecs


def validate_and_process_custom(custom: Any) -> dict[str, Any]:
"""Validates and processes custom field."""
if custom is None:
return {}

_validate_type(custom, dict)
for k in custom:
_validate_type(k, str)
return custom


def process_unknown_key(key: str, metadata_dict: dict[str, Any]) -> Any:
  if 'custom' in metadata_dict and metadata_dict['custom']:
    raise ValueError(
        'Provided metadata contains unknown key %s, and the custom field '
        'is already defined.' % key
    )
  logging.warning(
      'Provided metadata contains unknown key %s. Adding it to custom.', key
  )
  return metadata_dict[key]
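
As a rough usage sketch of the refactored per-field helpers above (the deserialization code that actually wires them into StepMetadata is not part of this diff, so the calls are illustrative only; the example keys and values are made up, and an installed orbax-checkpoint is assumed):

from orbax.checkpoint._src.metadata import metadata_serialization_utils as utils

# Merge regular metrics with additional metrics; all keys must be strings.
metrics = utils.validate_and_process_metrics({'loss': 0.1}, {'accuracy': 0.9})
# {'loss': 0.1, 'accuracy': 0.9}

# Accepts a dict (or SaveStepStatistics) and keeps only float-valued entries.
perf = utils.validate_and_process_performance_metrics(
    {'save_secs': 1.5, 'step': 100}
)
# {'save_secs': 1.5}

init_ns = utils.validate_and_process_init_timestamp_nsecs(1736812800000000000)
custom = utils.validate_and_process_custom({'run_name': 'exp-1'})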