Skip to content

Commit 83207b6

Browse files
cpgaffney1Orbax Authors
authored andcommitted
Eliminate usages of ParamInfo.path in favor of parent_dir / name.
PiperOrigin-RevId: 834378892
1 parent 3b0e3fa commit 83207b6

19 files changed

+42
-44
lines changed

checkpoint/orbax/checkpoint/_src/handlers/array_checkpoint_handler.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,8 @@ def restore(
153153
# checkpoints lacking PYTREE_METADATA_FILE is no longer needed.
154154
restore_args = args.restore_args or type_handlers.RestoreArgs()
155155

156-
checkpoint_path = directory / self._checkpoint_name
157156
info = type_handlers.ParamInfo(
158157
name=self._checkpoint_name,
159-
path=checkpoint_path,
160158
parent_dir=directory,
161159
skip_deserialize=False,
162160
is_ocdbt_checkpoint=type_handlers.is_ocdbt_checkpoint(directory),

checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,6 @@ def _param_info(keypath, name, value):
472472
return ParamInfo(
473473
name=name,
474474
keypath=keypath,
475-
path=(directory / name),
476475
parent_dir=directory,
477476
skip_deserialize=skip_deserialize,
478477
is_ocdbt_checkpoint=use_ocdbt,

checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,6 @@ def _get_param_info(
268268
skip_deserialize = meta_or_value.skip_deserialize
269269
return ParamInfo(
270270
name=name,
271-
path=directory / name,
272271
parent_dir=directory,
273272
skip_deserialize=skip_deserialize,
274273
is_ocdbt_checkpoint=is_ocdbt_checkpoint,
@@ -283,7 +282,9 @@ def _get_param_info(
283282
if partial_restore:
284283
for key, meta in flat_structure.items():
285284
if key not in flat_item:
286-
flat_param_infos[key] = ParamInfo(skip_deserialize=True)
285+
flat_param_infos[key] = ParamInfo(
286+
name='', parent_dir=directory, skip_deserialize=True
287+
)
287288
flat_input_restore_args[key] = RestoreArgs()
288289
else:
289290
flat_param_infos[key] = _get_param_info(flat_param_names[key], meta)
@@ -322,7 +323,9 @@ def _get_param_info(
322323
# Specified `use_fallback`, but key was also present in the
323324
# checkpoint. This means we should skip loading, since it will be
324325
# overridden with a new value.
325-
flat_param_infos[input_key] = ParamInfo(skip_deserialize=True)
326+
flat_param_infos[input_key] = ParamInfo(
327+
name='', parent_dir=directory, skip_deserialize=True
328+
)
326329
flat_input_restore_args[input_key] = RestoreArgs()
327330
else:
328331
# Specified `use_fallback`, but `transforms_default_to_original`
@@ -343,12 +346,16 @@ def _get_param_info(
343346
else:
344347
# Take the value from the user-provided `item`, ignoring any value
345348
# in the checkpoint.
346-
flat_param_infos[input_key] = ParamInfo(skip_deserialize=True)
349+
flat_param_infos[input_key] = ParamInfo(
350+
name='', parent_dir=directory, skip_deserialize=True
351+
)
347352
flat_input_restore_args[input_key] = RestoreArgs()
348353
else:
349354
# No match, restoration not required since it will be dropped from the
350355
# output.
351-
flat_param_infos[input_key] = ParamInfo(skip_deserialize=True)
356+
flat_param_infos[input_key] = ParamInfo(
357+
name='', parent_dir=directory, skip_deserialize=True
358+
)
352359
flat_input_restore_args[input_key] = RestoreArgs()
353360

354361
restore_args = tree_utils.from_flat_dict(

checkpoint/orbax/checkpoint/_src/metadata/tree.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,6 @@ def as_custom_metadata(
485485
param_name = '.'.join(keypath)
486486
flat_param_infos[keypath] = types.ParamInfo(
487487
name=param_name,
488-
path=directory / param_name,
489488
parent_dir=directory,
490489
skip_deserialize=value_meta.skip_deserialize,
491490
is_ocdbt_checkpoint=use_ocdbt,

checkpoint/orbax/checkpoint/_src/metadata/tree_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from absl.testing import absltest
1818
from absl.testing import parameterized
1919
import chex
20+
from etils import epath
2021
import jax
2122
import numpy as np
2223
from orbax.checkpoint._src.metadata import tree as tree_metadata_lib
@@ -43,6 +44,8 @@ def _to_param_infos(
4344
return jax.tree.map(
4445
# Other properties are not relevant.
4546
lambda x: types.ParamInfo(
47+
name='',
48+
parent_dir=epath.Path(''),
4649
value_typestr=type_handler_registry.get_param_typestr(
4750
x,
4851
type_handler_registry.GLOBAL_TYPE_HANDLER_REGISTRY,

checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ async def _validate_non_ocdbt_files(
681681
):
682682
await asyncio.gather(*[
683683
ts_utils.assert_parameter_files_exist( # pylint: disable=protected-access
684-
info.path, metadata_key, info.use_zarr3
684+
info.parent_dir / info.name, metadata_key, info.use_zarr3
685685
)
686686
for info in infos
687687
])

checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -583,8 +583,8 @@ def _get_json_tspec(
583583
raise_array_data_missing_error: bool = True,
584584
) -> dict[str, Any]:
585585
"""Gets Tensorstore spec in JSON format."""
586-
if info.path is None:
587-
raise ValueError('Must construct serialization path.')
586+
if info.name is None or info.parent_dir is None:
587+
raise ValueError('Must provide info.name and info.parent_dir.')
588588
parent_dir = info.parent_dir
589589
assert parent_dir is not None
590590
directory = parent_dir.as_posix()
@@ -691,8 +691,8 @@ def build_array_write_spec(
691691
ext_metadata: dict[str, Any] | None = None,
692692
) -> ArrayWriteSpec:
693693
"""Gets ArrayWriteSpec for writing."""
694-
if info.path is None:
695-
raise ValueError('Must construct serialization path.')
694+
if info.name is None or info.parent_dir is None:
695+
raise ValueError('Must provide info.name and info.parent_dir.')
696696
parent_dir = info.parent_dir
697697
assert parent_dir is not None
698698
directory = parent_dir.as_posix()

checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from typing import Any, Dict, Optional, Sequence, Tuple, TypeAlias, Union
2323

2424
from absl import logging
25-
from etils import epath
2625
import jax
2726
import numpy as np
2827
from orbax.checkpoint._src.futures import future
@@ -202,7 +201,7 @@ async def deserialize(
202201
for info, arg in zip(infos, args):
203202
if not info.is_ocdbt_checkpoint:
204203
await ts_utils.assert_parameter_files_exist(
205-
info.path, self._metadata_key, info.use_zarr3
204+
info.parent_dir / info.name, self._metadata_key, info.use_zarr3
206205
)
207206
# Use OCDBT flag from the existing checkpoint.
208207
use_ocdbt = info.is_ocdbt_checkpoint
@@ -310,8 +309,8 @@ def _get_json_tspec(
310309
info: types.ParamInfo,
311310
) -> Dict[str, Any]:
312311
"""Gets Tensorstore spec in JSON format."""
313-
if info.path is None:
314-
raise ValueError('Must construct serialization path.')
312+
if info.parent_dir is None:
313+
raise ValueError('Must provide info.parent_dir.')
315314
directory = (info.parent_dir / self._filename).as_posix()
316315
kvstore_tspec = ts_utils.build_kvstore_tspec(directory, use_ocdbt=False)
317316
tspec = {
@@ -384,11 +383,9 @@ async def deserialize(
384383
"""See superclass documentation."""
385384
del args
386385
types.check_input_arguments(infos)
387-
directory = epath.Path(infos[0].path).parent
388386
open_futures = []
389387

390388
for info in infos:
391-
info.path = epath.Path(directory / self._filename)
392389
tspec = self._get_json_tspec(info)
393390
open_future = ts.open(
394391
tspec, open=True, read=True, context=self._ts_context

checkpoint/orbax/checkpoint/_src/serialization/types.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def check_input_arguments(*args):
6060
raise ValueError('Found input args with mismatched lengths.')
6161

6262

63-
@dataclasses.dataclass
63+
@dataclasses.dataclass(kw_only=True)
6464
class ParamInfo:
6565
"""Information describing a parameter in a PyTree.
6666
@@ -70,13 +70,12 @@ class ParamInfo:
7070
7171
name:
7272
Name of the parameter.
73-
path:
74-
A path providing a location where file(s) should be saved. The path is
75-
assumed to be a directory.
7673
parent_dir:
7774
A path providing location where all files under the same checkpoint should
7875
be saved under. All `ParamInfo` provided to a given TypeHandler should have
7976
the same `parent_dir`. The parent_dir is assumed to be a directory.
77+
path:
78+
Do not provide directly. Automatically set to `parent_dir / name`.
8079
skip_deserialize:
8180
If specified, skips deserialization of the given parameter using the
8281
TypeHandler. This may be for multiple different reasons, including that the
@@ -115,10 +114,10 @@ class ParamInfo:
115114
is_prioritized_key_fn: See `IsPrioritizedKeyFn` definition.
116115
"""
117116

118-
name: Optional[str] = None
119-
keypath: Optional[Tuple[Any, ...]] = None
117+
name: str
118+
parent_dir: epath.Path
120119
path: Optional[epath.Path] = None
121-
parent_dir: Optional[epath.Path] = None
120+
keypath: Optional[Tuple[Any, ...]] = None
122121
skip_deserialize: Optional[bool] = None
123122
byte_limiter: Optional[limits.ByteLimiter] = None
124123
device_host_byte_limiter: Optional[limits.ByteLimiter] = None
@@ -133,6 +132,10 @@ class ParamInfo:
133132
write_shape: arrays_types.Shape | None = None
134133
is_prioritized_key_fn: Optional[IsPrioritizedKeyFn] = None
135134

135+
def __post_init__(self):
136+
if self.path is None:
137+
self.path = self.parent_dir / self.name
138+
136139

137140
@dataclasses.dataclass
138141
class SaveArgs:

checkpoint/orbax/checkpoint/_src/testing/benchmarks/array_handler_benchmark.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,6 @@ def test_fn(self, test_context: core.TestContext) -> core.TestResult:
151151
)
152152
sharded_array = test_context.pytree['array']
153153
array_name = 'array'
154-
array_path = test_context.path / array_name
155154

156155
ts_context = ts_utils.get_ts_context(use_ocdbt=options.use_ocdbt)
157156
value_typestr = type_handler_registry.get_param_typestr(
@@ -162,7 +161,6 @@ def test_fn(self, test_context: core.TestContext) -> core.TestResult:
162161

163162
param_info = type_handlers.ParamInfo(
164163
name=array_name,
165-
path=array_path,
166164
parent_dir=test_context.path,
167165
use_zarr3=options.use_zarr3,
168166
is_ocdbt_checkpoint=options.use_ocdbt,

0 commit comments

Comments
 (0)