Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor metadata de/serialization. #1493

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions checkpoint/orbax/checkpoint/_src/metadata/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ py_library(
py_library(
name = "metadata_serialization_utils",
srcs = ["metadata_serialization_utils.py"],
deps = [
":checkpoint",
"//checkpoint/orbax/checkpoint/_src/logging:step_statistics",
],
)

py_library(
Expand Down
67 changes: 23 additions & 44 deletions checkpoint/orbax/checkpoint/_src/metadata/checkpoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,11 +512,10 @@ def test_deserialize_wrong_types_root_metadata(
@parameterized.parameters(
({'item_handlers': list()},),
({'item_handlers': {int(): None}},),
({'metrics': list()},),
({'metrics': [int()]},),
({'metrics': {int(): None}},),
({'performance_metrics': list()},),
({'performance_metrics': {int(): float()}},),
({'performance_metrics': {str(): int()}},),
({'init_timestamp_nsecs': float()},),
({'commit_timestamp_nsecs': float()},),
({'custom': list()},),
Expand Down Expand Up @@ -566,7 +565,7 @@ def test_deserialize_item_metadata_with_item_metadata_kwarg(
@parameterized.parameters(
('metrics', 1, dict, int),
)
def test_validate_type_wrong_type(
def test_deserialize_metrics_kwarg(
self, kwarg_name, kwarg_value, expected_type, wrong_type
):
with self.assertRaisesRegex(
Expand All @@ -576,74 +575,54 @@ def test_validate_type_wrong_type(
self.deserialize_metadata(StepMetadata, {}, **{kwarg_name: kwarg_value})

@parameterized.parameters(
    ({'metrics': 1}, dict, int),
    ({'init_timestamp_nsecs': 'a'}, int, str),
    ({'commit_timestamp_nsecs': 'a'}, int, str),
    ({'custom': 1}, dict, int),
)
def test_validate_field_wrong_type(
    self, step_metadata, expected_field_type, wrong_field_type
):
  """Deserializing a field of the wrong type raises the generic type error."""
  with self.assertRaisesRegex(
      ValueError,
      f'Object must be of type {expected_field_type}, got {wrong_field_type}',
  ):
    self.deserialize_metadata(StepMetadata, step_metadata)

@parameterized.parameters(
    ({'item_handlers': 1}, int),
    ({'performance_metrics': 1}, int),
)
def test_validate_field_sequence_wrong_type(
    self, step_metadata, wrong_field_type
):
  """Fields accepting several types report the full list of allowed types."""
  with self.assertRaisesRegex(
      ValueError,
      f'Object must be any one of types '
      r'\[.+]'
      f', got {wrong_field_type}.',
  ):
    self.deserialize_metadata(StepMetadata, step_metadata)

@parameterized.parameters(
    ({'item_handlers': {1: 'a'}}, str, int),
    ({'metrics': {1: 'a'}}, str, int),
    ({'performance_metrics': {1: 1.0}}, str, int),
    ({'custom': {1: 'a'}}, str, int),
)
def test_validate_dict_entry_wrong_key_type(
    self, step_metadata, expected_key_type, wrong_key_type
):
  """Dict-valued fields must have str keys; non-str keys raise ValueError."""
  with self.assertRaisesRegex(
      ValueError,
      f'Object must be of type {expected_key_type}, got {wrong_key_type}.',
  ):
    self.deserialize_metadata(StepMetadata, step_metadata)

@parameterized.parameters(
({'item_handlers': {'a': 'b'}},),
({'performance_metrics': {'a': 1.0}},),
({'user_metadata': {'a': 1}, 'init_timestamp_nsecs': 1},),
({'custom': {'a': 1}, 'init_timestamp_nsecs': 1},),
)
def test_serialize_for_update_valid_kwargs(
self, kwargs: dict[str, Any]
Expand All @@ -656,13 +635,13 @@ def test_serialize_for_update_valid_kwargs(
@parameterized.parameters(
({'item_handlers': list()},),
({'item_handlers': {int(): None}},),
({'metrics': list()},),
({'metrics': [int()]},),
({'metrics': {int(): None}},),
({'performance_metrics': list()},),
({'init_timestamp_nsecs': float()},),
({'commit_timestamp_nsecs': float()},),
({'user_metadata': list()},),
({'user_metadata': {int(): None}},),
({'custom': list()},),
({'custom': {int(): None}},),
)
def test_serialize_for_update_wrong_types(
self, kwargs: dict[str, Any]
Expand All @@ -675,7 +654,7 @@ def test_serialize_for_update_with_unknown_kwargs(self):
ValueError, 'Provided metadata contains unknown key blah'
):
step_metadata_serialization.serialize_for_update(
user_metadata={'a': 1},
custom={'a': 1},
blah=123,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,62 +14,148 @@

"""Utilities for serializing and deserializing metadata."""

from typing import Any, Sequence
import dataclasses
from typing import Any, Optional, Sequence

from absl import logging
from orbax.checkpoint._src.logging import step_statistics
from orbax.checkpoint._src.metadata import checkpoint

def validate_type(obj: Any, field_type: type[Any]):
if not isinstance(obj, field_type):
raise ValueError(f'Object must be of type {field_type}, got {type(obj)}.')

# Aliases re-exported from the checkpoint metadata and step-statistics
# modules, so users of this utils module need not import those directly.
# NOTE(review): presumably CompositeCheckpointHandlerTypeStrs is a str-keyed
# mapping and CheckpointHandlerTypeStr is a plain str (the validators below
# check them with isinstance against dict/str) — confirm against
# orbax.checkpoint._src.metadata.checkpoint.
CompositeCheckpointHandlerTypeStrs = (
    checkpoint.CompositeCheckpointHandlerTypeStrs
)
CheckpointHandlerTypeStr = checkpoint.CheckpointHandlerTypeStr
CompositeItemMetadata = checkpoint.CompositeItemMetadata
SingleItemMetadata = checkpoint.SingleItemMetadata
StepStatistics = step_statistics.SaveStepStatistics


def validate_field(
obj: Any,
field_name: str,
field_type: type[Any] | Sequence[type[Any]]
):
"""Validates a single field in a dictionary.

field_type can optionally be a sequence of types, in which case the field
must be of any one of the types in the sequence.

Args:
obj: The object to validate.
field_name: The name of the field to validate.
field_type: The type (or sequence of types) of the field to validate.
"""
if field_name not in obj or obj[field_name] is None:
return
field = obj[field_name]
def _validate_type(obj: Any, field_type: type[Any] | Sequence[type[Any]]):
if isinstance(field_type, Sequence):
if not any(isinstance(field, f_type) for f_type in field_type):
if not any(isinstance(obj, f_type) for f_type in field_type):
raise ValueError(
f'Metadata field "{field_name}" must be any one of '
f'types {list(field_type)}, got {type(field)}.'
f'Object must be any one of types {list(field_type)}, got '
f'{type(obj)}.'
)
elif not isinstance(field, field_type):
raise ValueError(
f'Metadata field "{field_name}" must be of type {field_type}, '
f'got {type(field)}.'
)
elif not isinstance(obj, field_type):
raise ValueError(f'Object must be of type {field_type}, got {type(obj)}.')


def validate_and_process_item_handlers(
    item_handlers: Any,
) -> CompositeCheckpointHandlerTypeStrs | CheckpointHandlerTypeStr | None:
  """Checks that item_handlers is a str or a str-keyed dict and returns it.

  Returns None when the input is None; raises ValueError (via the type
  validator) when it is neither a dict nor a str.
  """
  if item_handlers is None:
    return None
  _validate_type(item_handlers, [dict, str])
  if isinstance(item_handlers, CompositeCheckpointHandlerTypeStrs):
    # Composite form: each key names an item and must itself be a string.
    for item_name in item_handlers or {}:
      _validate_type(item_name, str)
    return item_handlers
  if isinstance(item_handlers, CheckpointHandlerTypeStr):
    return item_handlers
  return None


def validate_and_process_item_metadata(
    item_metadata: Any,
) -> CompositeItemMetadata | SingleItemMetadata | None:
  """Returns item_metadata, validating dict keys when it is composite.

  Non-composite (single-item) metadata is passed through unvalidated.
  """
  if item_metadata is None:
    return None
  if not isinstance(item_metadata, CompositeItemMetadata):
    # Single-item metadata may be of any type; return it as-is.
    return item_metadata
  _validate_type(item_metadata, dict)
  for item_name in item_metadata:
    _validate_type(item_name, str)
  return item_metadata


def validate_and_process_metrics(
    metrics: Any,
    additional_metrics: Optional[Any] = None
) -> dict[str, Any]:
  """Validates the metrics field and merges in optional extra metrics.

  Args:
    metrics: A str-keyed dict of metrics; None/falsy is treated as empty.
    additional_metrics: Optional str-keyed dict merged over `metrics`
      (its entries win on key collision).

  Returns:
    A new dict containing `metrics` updated with `additional_metrics`.

  Raises:
    ValueError: If either argument is not a dict or contains a non-str key.
  """
  metrics = metrics or {}

  _validate_type(metrics, dict)
  for metric_name in metrics:
    _validate_type(metric_name, str)

  # Copy before merging so the caller's dict is never mutated in place
  # (previously `validated_metrics` aliased `metrics` and wrote into it).
  validated_metrics = dict(metrics)

  if additional_metrics is not None:
    _validate_type(additional_metrics, dict)
    for metric_name, value in additional_metrics.items():
      _validate_type(metric_name, str)
      validated_metrics[metric_name] = value

  return validated_metrics


def validate_and_process_performance_metrics(
    performance_metrics: Any,
) -> dict[str, float]:
  """Validates and processes the performance_metrics field.

  Accepts a str-keyed dict or a StepStatistics dataclass (converted to a
  dict via dataclasses.asdict). Only float-valued entries are retained.

  Args:
    performance_metrics: A dict, a StepStatistics instance, or None.

  Returns:
    A dict of metric name to float value; {} when the input is None.

  Raises:
    ValueError: If the input is neither a dict nor StepStatistics, or if
      any key is not a str.
  """
  if performance_metrics is None:
    return {}

  _validate_type(performance_metrics, [dict, StepStatistics])
  if isinstance(performance_metrics, StepStatistics):
    performance_metrics = dataclasses.asdict(performance_metrics)

  for metric_name in performance_metrics:
    _validate_type(metric_name, str)

  # Non-float values (e.g. bools, strings, None) are silently dropped.
  return {
      metric: val
      for metric, val in performance_metrics.items()
      if isinstance(val, float)
  }


def validate_and_process_init_timestamp_nsecs(
    init_timestamp_nsecs: Any,
) -> int | None:
  """Returns init_timestamp_nsecs after checking it is an int (or None)."""
  if init_timestamp_nsecs is not None:
    _validate_type(init_timestamp_nsecs, int)
  return init_timestamp_nsecs


def validate_and_process_commit_timestamp_nsecs(
    commit_timestamp_nsecs: Any,
) -> int | None:
  """Returns commit_timestamp_nsecs after checking it is an int (or None)."""
  if commit_timestamp_nsecs is not None:
    _validate_type(commit_timestamp_nsecs, int)
  return commit_timestamp_nsecs


def validate_and_process_custom(custom: Any) -> dict[str, Any]:
  """Returns custom as a str-keyed dict; None becomes an empty dict."""
  if custom is None:
    return {}
  _validate_type(custom, dict)
  for key in custom:
    _validate_type(key, str)
  return custom


def process_unknown_key(key: str, metadata_dict: dict[str, Any]) -> Any:
  """Handles a metadata key that is not a known StepMetadata field.

  The unknown key's value is funneled into the `custom` field. This is only
  permitted when the caller has not already supplied a non-empty `custom`;
  otherwise the two sources of custom data would silently conflict.

  Args:
    key: The unknown metadata key.
    metadata_dict: The full metadata mapping that contains `key`.

  Returns:
    The value stored under `key` in `metadata_dict`.

  Raises:
    ValueError: If `custom` is already defined and non-empty.
  """
  if 'custom' in metadata_dict and metadata_dict['custom']:
    raise ValueError(
        'Provided metadata contains unknown key %s, and the custom field '
        'is already defined.' % key
    )
  logging.warning(
      'Provided metadata contains unknown key %s. Adding it to custom.', key
  )
  return metadata_dict[key]
Loading
Loading