Skip to content

Commit 0c4f9d2

Browse files
authored
[train] Unblock get_all_reported_checkpoints if reporting only metrics (#58870)
When reporting a checkpoint to Ray Train, every worker needs to form a barrier with a `ray.train.report` call. If every worker reports an empty checkpoint, we should notify the condition to unblock `ray.train.get_all_reported_checkpoints` calls. Before this fix, reporting an empty checkpoint and calling `get_all_reported_checkpoints` would result in a hang. --------- Signed-off-by: Timothy Seah <tseah@anyscale.com>
1 parent 5e206d8 commit 0c4f9d2

File tree

4 files changed

+40
-25
lines changed

4 files changed

+40
-25
lines changed

python/ray/train/v2/_internal/execution/checkpoint/checkpoint_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,7 @@ def after_report(
342342
):
343343
if not training_report.checkpoint:
344344
self._current_report_index += 1
345+
self._notify()
345346
return
346347

347348
self.register_checkpoint(

python/ray/train/v2/tests/test_async_checkpointing_validation.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,30 @@ def train_fn():
389389
trainer.fit()
390390

391391

392+
def test_report_get_all_reported_checkpoints():
393+
"""Check that get_all_reported_checkpoints returns checkpoints depending on # report calls."""
394+
395+
def train_fn():
396+
if ray.train.get_context().get_world_rank() == 0:
397+
ray.train.report(metrics={}, checkpoint=None)
398+
with create_dict_checkpoint({}) as checkpoint:
399+
ray.train.report(metrics={}, checkpoint=checkpoint)
400+
assert len(ray.train.get_all_reported_checkpoints()) == 1
401+
with create_dict_checkpoint({}) as checkpoint:
402+
ray.train.report(metrics={}, checkpoint=checkpoint)
403+
else:
404+
ray.train.report(metrics={}, checkpoint=None)
405+
ray.train.report(metrics={}, checkpoint=None)
406+
ray.train.report(metrics={}, checkpoint=None)
407+
assert len(ray.train.get_all_reported_checkpoints()) == 2
408+
409+
trainer = DataParallelTrainer(
410+
train_fn,
411+
scaling_config=ScalingConfig(num_workers=2),
412+
)
413+
trainer.fit()
414+
415+
392416
def test_get_all_reported_checkpoints_all_consistency_modes():
393417
signal_actor = create_remote_signal_actor(ray).remote()
394418

@@ -440,6 +464,18 @@ def validate_fn(checkpoint, config):
440464
trainer.fit()
441465

442466

467+
def test_get_all_reported_checkpoints_empty_reports():
468+
def train_fn():
469+
ray.train.report(metrics={}, checkpoint=None)
470+
assert len(ray.train.get_all_reported_checkpoints()) == 0
471+
472+
trainer = DataParallelTrainer(
473+
train_fn,
474+
scaling_config=ScalingConfig(num_workers=2),
475+
)
476+
trainer.fit()
477+
478+
443479
if __name__ == "__main__":
444480
import sys
445481

python/ray/train/v2/tests/test_data_parallel_trainer.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -141,30 +141,6 @@ def train_fn():
141141
assert tmp_path.joinpath("validate", str(rank)).exists()
142142

143143

144-
def test_report_get_all_reported_checkpoints():
145-
"""Check that get_all_reported_checkpoints returns checkpoints depending on # report calls."""
146-
147-
def train_fn():
148-
if ray.train.get_context().get_world_rank() == 0:
149-
ray.train.report(metrics={}, checkpoint=None)
150-
with create_dict_checkpoint({}) as checkpoint:
151-
ray.train.report(metrics={}, checkpoint=checkpoint)
152-
assert len(ray.train.get_all_reported_checkpoints()) == 1
153-
with create_dict_checkpoint({}) as checkpoint:
154-
ray.train.report(metrics={}, checkpoint=checkpoint)
155-
else:
156-
ray.train.report(metrics={}, checkpoint=None)
157-
ray.train.report(metrics={}, checkpoint=None)
158-
ray.train.report(metrics={}, checkpoint=None)
159-
assert len(ray.train.get_all_reported_checkpoints()) == 2
160-
161-
trainer = DataParallelTrainer(
162-
train_fn,
163-
scaling_config=ScalingConfig(num_workers=2),
164-
)
165-
trainer.fit()
166-
167-
168144
def test_error(tmp_path):
169145
def _error_func_rank_0():
170146
"""An example train_fun that raises an error on rank 0."""

python/ray/train/v2/tests/test_report_handler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def generate_worker_group_poll_status(num_workers, num_ckpt, num_dummy, num_none
6262
(10, 1, 8, 1, 0), # one worker with checkpoint, one worker with None
6363
],
6464
)
65-
def test_report_handler(tmp_path, num_workers, num_ckpt, num_dummy, num_none, expected):
65+
async def test_report_handler(
66+
tmp_path, num_workers, num_ckpt, num_dummy, num_none, expected
67+
):
6668
"""`expected` is the number of times that the
6769
CheckpointManager.register_checkpoint is called.
6870
"""

0 commit comments

Comments (0)