Skip to content

Commit 879cb6c

Browse files
author
Orbax Authors
committed
Run all correctness benchmarks in Github
PiperOrigin-RevId: 831318632
1 parent d966ddf commit 879cb6c

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

47 files changed

+562
-298
lines changed

.github/workflows/build.yml

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ jobs:
242242
working-directory: checkpoint
243243
strategy:
244244
matrix:
245-
python-version: ["3.10", "3.11", "3.12"]
245+
python-version: ["3.10"]
246246
jax-version: ["0.6.0"]
247247
steps:
248248
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -267,9 +267,30 @@ jobs:
267267
- name: Run benchmarks
268268
env:
269269
GCS_BUCKET_PATH: gs://orbax-benchmarks/benchmark-results/${{ github.run_id }}
270+
TF_FORCE_GPU_ALLOW_GROWTH: true
271+
XLA_PYTHON_CLIENT_PREALLOCATE: false
272+
KERAS_BACKEND: "jax"
270273
run: |
271-
cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
274+
cd orbax/checkpoint/_src/testing/benchmarks
275+
failed_benchmarks=""
276+
benchmark_configs_file="multiprocess_benchmark_configs.txt"
277+
echo "Running benchmarks specified in $benchmark_configs_file"
278+
while IFS= read -r entry || [ -n "$entry" ]; do
279+
if [ -n "$entry" ]; then
280+
echo "Running benchmark for $entry"
281+
if ! python -c "import sys; import jax; jax.distributed.initialize(); print(jax.devices()); from absl import app; import run_benchmarks; sys.argv = ['run_benchmarks.py', '--config_file="$entry"', '--output_directory=$GCS_BUCKET_PATH']; app.run(run_benchmarks.main)"; then
282+
echo "Benchmark $entry failed"
283+
failed_benchmarks="$failed_benchmarks $entry"
284+
fi
285+
fi
286+
done < "$benchmark_configs_file"
272287
cd ../../../../..
288+
if [ -n "$failed_benchmarks" ]; then
289+
echo "The following benchmarks failed:$failed_benchmarks"
290+
exit 1
291+
fi
292+
# cd orbax/checkpoint/_src/testing/benchmarks && python -c "import sys; import jax; jax.distributed.initialize(); print(jax.devices()); from absl import app; import run_benchmarks; sys.argv = ['run_benchmarks.py', '--config_file=configs/pytree_checkpoint_benchmark.yaml', '--output_directory=$GCS_BUCKET_PATH']; app.run(run_benchmarks.main)"
293+
# cd ../../../../..
273294
# python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py
274295
# cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
275296
# The below step just reports the success or failure of tests as a "commit status".

.github/workflows/multiprocess_tests.yml

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
working-directory: checkpoint
2626
strategy:
2727
matrix:
28-
python-version: ["3.10", "3.11", "3.12"]
28+
python-version: ["3.12"]
2929
jax-version: ["0.6.0"]
3030
steps:
3131
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -50,9 +50,27 @@ jobs:
5050
- name: Run benchmarks
5151
env:
5252
GCS_BUCKET_PATH: gs://orbax-benchmarks/benchmark-results/${{ github.run_id }}
53+
TF_FORCE_GPU_ALLOW_GROWTH: true
54+
XLA_PYTHON_CLIENT_PREALLOCATE: false
5355
run: |
54-
cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
56+
cd orbax/checkpoint/_src/testing/benchmarks
57+
failed_benchmarks=""
58+
benchmark_configs_file="multiprocess_benchmark_configs.txt"
59+
echo "Running benchmarks specified in $benchmark_configs_file"
60+
while IFS= read -r entry || [ -n "$entry" ]; do
61+
if [ -n "$entry" ]; then
62+
echo "Running benchmark for $entry"
63+
if ! python run_benchmarks.py --config_file="$entry" --output_directory=$GCS_BUCKET_PATH; then
64+
echo "Benchmark $entry failed"
65+
failed_benchmarks="$failed_benchmarks $entry"
66+
fi
67+
fi
68+
done < "$benchmark_configs_file"
5569
cd ../../../../..
70+
if [ -n "$failed_benchmarks" ]; then
71+
echo "The following benchmarks failed:$failed_benchmarks"
72+
exit 1
73+
fi
5674
# python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py
5775
# cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
5876
# The below step just reports the success or failure of tests as a "commit status".

checkpoint/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1212
- Fix `step_from_checkpoint_name` to allow the passed in checkpoint name to
1313
include an arbitrary `step_prefix` with any character(s) such as underscores.
1414
- Fix CheckpointManager initial directory creation to use `file_options.path_permission_mode`.
15+
- Fix using `jax.eval_shape` with `StandardRestore`.
1516

1617
### Changed
1718

1819
- Validate checkpoints before writing merged OCDBT database using in-memory
1920
state, avoiding additional I/O to re-read metadata.
2021
- Add `support_format` to `utils.to_shape_dtype_struct()`.
2122
- Moved `register_pathways_handlers` to `ocp.pathways.register_type_handlers`.
23+
- Replace usage of `get_json_tpec_read` and delegate functionality to new
24+
function `build_array_read_spec` which constructs and returns an
25+
`ArrayReadSpec`.
2226

2327
## [0.11.28] - 2025-11-06
2428

checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler_test_utils.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,8 +505,13 @@ def handler(self) -> StandardCheckpointHandler:
505505
return StandardCheckpointHandler()
506506

507507
def test_with_random_keys(self):
508+
# TODO(b/393160483) investigate Pathways remote Python support for
509+
# random.keys.
508510
if utils.is_pathways_backend():
509-
self.skipTest('Pathways does not support random keys checkpoint.')
511+
self.skipTest(
512+
'Disabled on Pathways because random keys are not supported by'
513+
' remote Python.'
514+
)
510515

511516
def create_random_keys(seed):
512517
duplicated_sharding = jax.sharding.NamedSharding(
@@ -559,3 +564,38 @@ def create_random_keys(seed):
559564
args=self.restore_args_cls(abstract_tree),
560565
)
561566
test_utils.assert_tree_equal(self, self.pytree, restored)
567+
568+
def test_save_restore_random_keys_with_jax_eval_shape(self):
569+
# TODO(b/393160483) investigate Pathways remote Python support for
570+
# random.keys.
571+
if utils.is_pathways_backend():
572+
self.skipTest(
573+
'Disabled on Pathways because random keys are not supported by'
574+
' remote Python.'
575+
)
576+
577+
mesh = jax.sharding.Mesh(jax.devices(), ('x',))
578+
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
579+
580+
@functools.partial(
581+
jax.jit,
582+
in_shardings=sharding,
583+
out_shardings=sharding,
584+
)
585+
def sharded_create_state_fn(root_key):
586+
return dict(
587+
matrix=jnp.array([[1, 2], [3, 4], [5, 6], [7, 8]]),
588+
rngkey=jax.random.fold_in(root_key, 42),
589+
)
590+
591+
pytree = sharded_create_state_fn(jax.random.key(0))
592+
abstract_pytree = jax.eval_shape(
593+
sharded_create_state_fn, jax.random.key(0)
594+
)
595+
596+
self.handler.save(self.directory, args=self.save_args_cls(pytree))
597+
598+
restored = self.handler.restore(
599+
self.directory, args=self.restore_args_cls(abstract_pytree)
600+
)
601+
test_utils.assert_tree_equal(self, pytree, restored)

checkpoint/orbax/checkpoint/_src/path/atomicity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,13 @@ async def _create_tmp_directory(
161161
def _get_tmp_directory(final_path: epath.Path) -> epath.Path:
162162
# Path may not be completely unique if a preemption occurs. We rely on the
163163
# existing tmp directory being deleted elsewhere.
164-
return epath.Path(final_path.parent) / (final_path.name + TMP_DIR_SUFFIX)
164+
return final_path.parent / (final_path.name + TMP_DIR_SUFFIX)
165165

166166

167167
def _get_final_directory(tmp_path: epath.Path) -> epath.Path:
168168
if (suffix_idx := tmp_path.name.find(TMP_DIR_SUFFIX)) == -1:
169169
raise ValueError(f'Expected {tmp_path} to end with "{TMP_DIR_SUFFIX}".')
170-
return epath.Path(tmp_path.parent) / tmp_path.name[:suffix_idx]
170+
return tmp_path.parent / tmp_path.name[:suffix_idx]
171171

172172

173173
class TemporaryPathBase(atomicity_types.TemporaryPath):

checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -763,12 +763,13 @@ async def _async_deserialize(
763763
await _validate_non_ocdbt_files(infos, metadata_key)
764764
deserialize_ops = []
765765
for info, arg, sharding in zip(infos, args, shardings):
766-
tspec = ts_utils.get_json_tspec_read(
766+
array_read_spec = ts_utils.build_array_read_spec(
767767
info,
768768
use_ocdbt=use_ocdbt,
769769
metadata_key=metadata_key,
770770
raise_array_data_missing_error=info.raise_array_data_missing_error,
771771
)
772+
tspec = array_read_spec.json
772773
tspec = ts_utils.get_cast_tspec_deserialize(tspec, arg)
773774

774775
# set dtype=None to deserialize for random keys
@@ -939,19 +940,6 @@ def __init__(
939940
def has_dispatcher(self) -> bool:
940941
return self._dispatcher is not None
941942

942-
def _get_json_tspec_read(
943-
self,
944-
info: types.ParamInfo,
945-
use_ocdbt: bool,
946-
) -> Dict[str, Any]:
947-
"""Gets Tensorstore spec for reading."""
948-
return ts_utils.get_json_tspec_read(
949-
info,
950-
use_ocdbt=use_ocdbt,
951-
metadata_key=self._metadata_key,
952-
raise_array_data_missing_error=info.raise_array_data_missing_error,
953-
)
954-
955943
def typestr(self) -> str:
956944
return JAX_ARRAY_TYPE_STR
957945

@@ -968,7 +956,13 @@ async def metadata(
968956
for info in infos:
969957
# Use OCDBT flag from the existing checkpoint.
970958
use_ocdbt = info.is_ocdbt_checkpoint
971-
tspec = self._get_json_tspec_read(info, use_ocdbt=use_ocdbt)
959+
array_read_spec = ts_utils.build_array_read_spec(
960+
info,
961+
use_ocdbt=use_ocdbt,
962+
metadata_key=self._metadata_key,
963+
raise_array_data_missing_error=info.raise_array_data_missing_error,
964+
)
965+
tspec = array_read_spec.json
972966
open_ops.append(
973967
ts.open(ts.Spec(tspec), open=True, context=info.ts_context)
974968
)

checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,45 @@ def _maybe_add_cast_to_write_spec(
391391
return array_tspec
392392

393393

394+
class ArrayReadSpec:
395+
"""Full TensorStore spec for reading an array."""
396+
397+
def __init__(
398+
self,
399+
directory: str,
400+
relative_array_filename: str,
401+
use_zarr3: bool,
402+
*,
403+
use_ocdbt: bool,
404+
metadata_key: str | None = None,
405+
raise_array_data_missing_error: bool = True,
406+
):
407+
"""Builds a TensorStore spec for reading an array."""
408+
kvstore_tspec = build_kvstore_tspec(
409+
directory,
410+
name=relative_array_filename,
411+
use_ocdbt=use_ocdbt,
412+
process_id=None,
413+
)
414+
415+
tspec = {
416+
'driver': ZARR_VER3 if use_zarr3 else ZARR_VER2,
417+
'kvstore': kvstore_tspec,
418+
'recheck_cached_data': False,
419+
'recheck_cached_metadata': False,
420+
# Raise error if data is missing.
421+
'fill_missing_data_reads': not raise_array_data_missing_error,
422+
}
423+
if metadata_key is not None:
424+
tspec['metadata_key'] = metadata_key
425+
self._json_spec = tspec
426+
427+
@property
428+
def json(self) -> JsonSpec:
429+
"""Spec to be used to open a TensorStore for reading the array."""
430+
return self._json_spec
431+
432+
394433
class ArrayWriteSpec:
395434
"""Full TensorStore spec for writing an array."""
396435

@@ -677,6 +716,26 @@ def get_json_tspec_write(
677716
return tspec
678717

679718

719+
def build_array_read_spec(
720+
info: types.ParamInfo,
721+
*,
722+
use_ocdbt: bool,
723+
metadata_key: str | None = None,
724+
raise_array_data_missing_error: bool = True,
725+
) -> ArrayReadSpec:
726+
"""Gets ArrayReadSpec for reading."""
727+
if info.name is None or info.parent_dir is None:
728+
raise ValueError('Must provide info.name and info.parent_dir.')
729+
return ArrayReadSpec(
730+
directory=info.parent_dir.as_posix(),
731+
relative_array_filename=info.name,
732+
use_zarr3=info.use_zarr3,
733+
use_ocdbt=use_ocdbt,
734+
metadata_key=metadata_key,
735+
raise_array_data_missing_error=raise_array_data_missing_error,
736+
)
737+
738+
680739
def build_array_write_spec(
681740
info: types.ParamInfo,
682741
arg: types.SaveArgs | None = None,

checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,60 @@ def test_maybe_cloud_storage(self):
613613
self.assertTrue(ts_utils.is_remote_storage(nested_tspec))
614614

615615

616+
class BuildArrayTSpecForReadTest(parameterized.TestCase):
617+
618+
def setUp(self):
619+
super().setUp()
620+
self.directory = self.create_tempdir().full_path
621+
self.param_name = 'params/a'
622+
623+
self.array_read_spec_constructor = functools.partial(
624+
ts_utils.ArrayReadSpec,
625+
directory=self.directory,
626+
relative_array_filename=self.param_name,
627+
)
628+
629+
@parameterized.product(
630+
use_zarr3=(True, False),
631+
use_ocdbt=(True, False),
632+
)
633+
def test_basic(self, use_zarr3: bool, use_ocdbt: bool):
634+
tspec = self.array_read_spec_constructor(
635+
use_zarr3=use_zarr3,
636+
use_ocdbt=use_ocdbt,
637+
)
638+
json_spec = tspec.json
639+
self.assertEqual(json_spec['driver'], 'zarr3' if use_zarr3 else 'zarr')
640+
self.assertEqual(
641+
json_spec['kvstore']['driver'],
642+
'ocdbt' if use_ocdbt else ts_utils.DEFAULT_DRIVER,
643+
)
644+
self.assertFalse(json_spec['recheck_cached_data'])
645+
self.assertFalse(json_spec['recheck_cached_metadata'])
646+
self.assertFalse(json_spec['fill_missing_data_reads'])
647+
self.assertNotIn('metadata_key', json_spec)
648+
649+
def test_metadata_key(self):
650+
tspec = self.array_read_spec_constructor(
651+
use_zarr3=False,
652+
use_ocdbt=False,
653+
metadata_key='custom_metadata',
654+
)
655+
self.assertEqual(tspec.json['metadata_key'], 'custom_metadata')
656+
657+
@parameterized.parameters(True, False)
658+
def test_fill_missing_data_reads(self, raise_array_data_missing_error):
659+
tspec = self.array_read_spec_constructor(
660+
use_zarr3=False,
661+
use_ocdbt=False,
662+
raise_array_data_missing_error=raise_array_data_missing_error,
663+
)
664+
self.assertEqual(
665+
tspec.json['fill_missing_data_reads'],
666+
not raise_array_data_missing_error,
667+
)
668+
669+
616670
class GetTsContextTest(parameterized.TestCase):
617671

618672
@parameterized.product(

0 commit comments

Comments
 (0)