
Conversation


@TimothySeah TimothySeah commented Jul 12, 2025

Summary

This PR adds a ray.train.get_all_reported_checkpoints method that allows users to get all the checkpoints they have reported from within their training function.

This is different from Result in two ways:

  • It is called from the training function on the training worker instead of from the driver
  • It can be called while training is still in progress

Implementation Notes

The main idea is to use a worker-side counter and controller-side counter as follows:

  • Train worker: ray.train.report increments a num_reported_checkpoints counter and puts the training result into its queue
  • Train controller: polls the training results from all workers, registers the checkpoint, increments num_reported_checkpoints, and then creates an asyncio task to notify the asyncio Condition. This works because asyncio Ray actors should always have an event loop.
  • Train worker: get_all_reported_results uses an asyncio.Condition to wait until the controller-side num_reported_checkpoints counter matches the worker-side one before returning the checkpoints. This ensures that all pending reports have finished. The worker has access to the controller actor through init_train_context (see the sketch below).
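
The bullets above boil down to a counter-plus-condition handshake. Below is a minimal, self-contained sketch of that pattern; the class and method names are simplified stand-ins for illustration, not the actual Ray Train internals.

```python
# Illustrative sketch of the counter + asyncio.Condition handshake described
# above. Names are simplified stand-ins, not Ray Train internals.
import asyncio
from typing import Any, List


class ControllerState:
    """Controller-side bookkeeping for reported checkpoints."""

    def __init__(self) -> None:
        self.num_reported_checkpoints = 0
        self.reported: List[Any] = []
        self._condition = asyncio.Condition()

    async def register_checkpoint(self, training_result: Any) -> None:
        # Called from the controller's polling loop after it drains a worker queue.
        self.reported.append(training_result)
        self.num_reported_checkpoints += 1
        async with self._condition:
            self._condition.notify_all()

    async def get_all_reported_checkpoints(self, worker_side_count: int) -> List[Any]:
        # A worker calls this (via the controller actor handle) with its local
        # counter; block until the controller has registered at least that many.
        async with self._condition:
            await self._condition.wait_for(
                lambda: self.num_reported_checkpoints >= worker_side_count
            )
        return list(self.reported)
```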

get_checkpoint should be unaffected because it uses the local checkpoint; we can consider changing it to use the "centrally committed" checkpoint in the future.

Testing

I ran the ray train pytorch example and called ray.train.get_all_reported_checkpoints at the end of each epoch, roughly as in the sketch below. The results are as expected; here are a few examples:
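
This is a sketch of the call pattern used; the model-training details are placeholders, while report, Checkpoint.from_directory, and the new get_all_reported_checkpoints are the actual APIs involved.

```python
# Sketch of the per-epoch reporting loop; the "training" is a stand-in.
import os
import tempfile

import ray.train


def train_func(config):
    for epoch in range(config["num_epochs"]):
        loss = 1.0 / (epoch + 1)  # stand-in for the real training loss
        with tempfile.TemporaryDirectory() as tmpdir:
            # Stand-in for saving real model state (e.g. torch.save(...)).
            with open(os.path.join(tmpdir, "model.txt"), "w") as f:
                f.write(str(loss))
            ray.train.report(
                {"loss": loss, "epoch": epoch},
                checkpoint=ray.train.Checkpoint.from_directory(tmpdir),
            )
        # New in this PR: every checkpoint reported so far by this run.
        print(f"epoch {epoch}:", ray.train.get_all_reported_checkpoints())
```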

epoch 4: [TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-34-52.538994), metrics={'loss': 0.24510294198989868, 'epoch': 0}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-35-07.511694), metrics={'loss': 0.23799467086791992, 'epoch': 1}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-35-24.355974), metrics={'loss': 0.39628422260284424, 'epoch': 2}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-35-40.273211), metrics={'loss': 0.15193207561969757, 'epoch': 3}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-35-56.178119), metrics={'loss': 0.17416314780712128, 'epoch': 4})]

epoch 9: [TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-34-52.538994), metrics={'loss': 0.24510294198989868, 'epoch': 0}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-35-07.511694), metrics={'loss': 0.23799467086791992, 'epoch': 1}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-35-24.355974), metrics={'loss': 0.39628422260284424, 'epoch': 2}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-35-40.273211), metrics={'loss': 0.15193207561969757, 'epoch': 3}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-35-56.178119), metrics={'loss': 0.17416314780712128, 'epoch': 4}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-36-12.547310), metrics={'loss': 0.2924661934375763, 'epoch': 5}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-36-28.538090), metrics={'loss': 0.18640762567520142, 'epoch': 6}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-36-44.583228), metrics={'loss': 0.12567029893398285, 'epoch': 7}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-37-00.540405), metrics={'loss': 0.1620682030916214, 'epoch': 8}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-37-17.129973), metrics={'loss': 0.07022886723279953, 'epoch': 9})]

I also modified all the Ray Train v2 unit tests that call ray.train.report:

  • test_persistence also verifies that get_all_reported_checkpoints works on resumption
  • test_data_parallel_trainer verifies that get_all_reported_checkpoints stalls until all workers report.

I also verified that get_all_reported_checkpoints produced similar output when called from Tune + Train.

I tried to test that get_all_reported_checkpoints finishes even during a graceful abort, but was unable to create such a scenario because get_all_reported_checkpoints returns very quickly and each report forms a barrier.

Signed-off-by: Timothy Seah <tseah@anyscale.com>

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @TimothySeah, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the checkpointing mechanism within Ray Train by introducing a consistent way to track the total number of reported checkpoints across the system. It ensures that both the central CheckpointManager and individual training workers maintain this count, improving visibility and coordination. Furthermore, it lays foundational work for more robust worker-manager communication and standardized checkpoint reporting through a new API definition.

Highlights

  • Checkpoint Counting: The CheckpointManager now tracks the total number of reported checkpoints via a new _num_reported_checkpoints attribute. This count is incremented each time a checkpoint is registered and is persisted as part of the manager's state.
  • Worker Context Enrichment: The TrainContext on each worker now includes num_reported_checkpoints (the current count from the manager) and, if the manager is an actor, a handle to the controller_actor. This allows workers to be aware of the overall checkpointing progress and potentially communicate back to the manager.
  • New API Definition: A new dataclass, ValidatedCheckpoint, has been introduced in ray/train/v2/api/reported_result.py. This class is intended to standardize the representation of user-reported checkpoints along with their associated metrics, paving the way for future refactors (see the sketch after this list).
  • Test Coverage: New and updated unit tests have been added to test_checkpoint_manager.py to verify the correct tracking, persistence, and propagation of the num_reported_checkpoints and controller_actor information.
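
For illustration, such a dataclass would look roughly like the following; this is a sketch based on the description above, and the exact fields in the merged code may differ.

```python
# Rough sketch of the ValidatedCheckpoint dataclass described above; field
# names follow the summary, not necessarily the merged code.
from dataclasses import dataclass
from typing import Any, Dict

from ray.train import Checkpoint


@dataclass
class ValidatedCheckpoint:
    """A user-reported checkpoint paired with the metrics reported alongside it."""

    checkpoint: Checkpoint
    metrics: Dict[str, Any]
```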


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a mechanism for both the CheckpointManager and the Worker to count reported checkpoints, which helps in synchronizing state, especially on recovery. It also adds a way for workers to get a handle to the controller actor when the CheckpointManager is run as an actor.

The changes are logical and well-tested. I have a couple of suggestions:

  • There's a new, unused file that should probably be removed to avoid dead code.
  • A test assertion can be made more direct and robust.

Overall, these are good changes that improve the checkpointing mechanism.

@TimothySeah TimothySeah added the go label (add ONLY when ready to merge, run all tests) Jul 12, 2025
@TimothySeah TimothySeah requested a review from xinyuangui2 July 14, 2025 22:15
@xinyuangui2 xinyuangui2 removed their request for review July 15, 2025 00:28
@github-actions

This pull request has been automatically marked as stale because it has not had
any activity for 14 days. It will be closed in another 14 days if no further activity occurs.
Thank you for your contributions.

You can always ask for help on our discussion forum or Ray's public slack channel.

If you'd like to keep this open, just leave any comment, and the stale label will be removed.

@github-actions github-actions bot added the stale label (The issue is stale. It will be closed within 7 days unless there are further conversation) Jul 29, 2025
@github-actions github-actions bot added the unstale label (A PR that has been marked unstale. It will not get marked stale again if this label is on it.) and removed the stale label Aug 2, 2025
@TimothySeah TimothySeah changed the title [train][checkpoint] CheckpointManager and Worker both count checkpoints [train][checkpoint] Add ray.train.get_all_reported_checkpoints method Aug 5, 2025
@TimothySeah TimothySeah marked this pull request as ready for review August 5, 2025 02:13
@TimothySeah TimothySeah requested a review from a team as a code owner August 5, 2025 02:13

@justinvyu justinvyu left a comment


Do you think it makes sense to split out CheckpointManager to be an actor, and pass that handle to the workers rather than the controller actor handle? Similar to the DatasetManager that I'm adding. The controller has a few "heads" that each have a narrower responsibility, which workers can also interface with.

Then, the controller can avoid adding these checkpoint specific methods.

For passing checkpoints around, the workers could just send the (metrics, checkpoint) to the CheckpointManager actor directly, rather than wait for controller polling. (This could also possibly solve the slow checkpoint queueing problem.) The controller polling loop is then exclusively used for health checks and the failure policy / scaling policy logic.

One unknown is where the callbacks should be triggered (ex: after_report) -- on the controller or on the checkpoint manager?

This might be too large of a refactor, but I think it could be worth it to modularize the checkpointing code so that we can work only with this CheckpointManager as we're starting to make more changes here.

@TimothySeah
Contributor Author

Do you think it makes sense to split out CheckpointManager to be an actor, and pass that handle to the workers rather than the controller actor handle? […]

Agreed that CheckpointManager and "actor per head" pattern makes sense. Another way to think about it is that each actor has a "lock" on its state so shoehorning everything into the controller effectively "locks" the entire state. Do you think this is a blocker or can I just file a bug for this? I think it's a bit difficult to do now and doing it later shouldn't create that much throwaway work since all the state is localized to the CheckpointManager already.
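
To make the "actor per head" idea discussed above concrete, a hypothetical shape could look like the following; this is purely illustrative and does not reflect the merged implementation.

```python
# Hypothetical sketch of a standalone CheckpointManager actor that workers
# report to directly, as proposed above. Not the merged implementation.
import ray


@ray.remote
class CheckpointManagerActor:
    def __init__(self) -> None:
        self._reported = []

    def report(self, metrics: dict, checkpoint) -> None:
        # Workers push (metrics, checkpoint) here directly instead of relying
        # on the controller's polling loop.
        self._reported.append((metrics, checkpoint))

    def get_all_reported_checkpoints(self) -> list:
        return list(self._reported)


# Workers would receive this handle instead of the controller handle:
# manager = CheckpointManagerActor.remote()
# ray.get(manager.report.remote({"loss": 0.1}, None))
```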

from dataclasses import dataclass
from typing import Any, Dict

from ray.train import Checkpoint
Collaborator


this is a circular dependency now? how does this pass the tests?

Contributor


I think to avoid circular dependency we can import directly from ray.train._checkpoint.

Contributor


Nvm maybe not, let's follow up on this in a separate PR.

@matthewdeng matthewdeng merged commit ea27046 into ray-project:master Sep 2, 2025
5 checks passed
gangsf pushed a commit to gangsf/ray that referenced this pull request Sep 2, 2025
sampan-s-nayak pushed a commit to sampan-s-nayak/ray that referenced this pull request Sep 8, 2025
jugalshah291 pushed a commit to jugalshah291/ray_fork that referenced this pull request Sep 11, 2025
… to enable this (ray-project#55556)

In the past, we used `RUN_CONTROLLER_AS_ACTOR_ENV_VAR` to toggle whether
to run the controller as a separate actor (we want this in most cases)
or on the current actor (we wanted this in Tune so we can propagate
`ray.train.report` from Train to Tune using the `TuneReportCallback`).

However, in order to implement `get_all_reported_checkpoints`
(ray-project#54555), we need to pass the
Train Controller actor to all the Train Worker actors. This method
wouldn't work when using Train from Tune because the Train Controller
actor handle would be the Tune Trainable actor handle which does not
have the async `get_all_reported_checkpoints` method.

This PR gets rid of `RUN_CONTROLLER_AS_ACTOR_ENV_VAR` once and for all
by making all communication between Train and Tune happen through a
lightweight `ray.util.Queue` actor instead of forcing Train and Tune to
happen on the same process.

---------

Signed-off-by: Timothy Seah <tseah@anyscale.com>
Co-authored-by: Timothy Seah <tseah@anyscale.com>
Signed-off-by: jugalshah291 <shah.jugal291@gmail.com>
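
As a rough illustration of the queue-based handoff that the commit message above describes: ray.util.queue.Queue is the real primitive, while the payload shape and consumer loop below are assumptions.

```python
# Sketch of Train -> Tune communication through a lightweight queue actor.
# The payload shape and the consumer loop are illustrative assumptions.
import ray
from ray.util.queue import Empty, Queue

ray.init(ignore_reinit_error=True)
results_queue = Queue()  # backed by a small dedicated Ray actor

# Train side: each ray.train.report pushes its (metrics, checkpoint) payload.
results_queue.put({"metrics": {"loss": 0.1, "epoch": 0}, "checkpoint": None})

# Tune side: drain the queue instead of sharing a process with the Train controller.
while True:
    try:
        result = results_queue.get(timeout=1)
    except Empty:
        break
    print("propagating to Tune:", result)
```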
jugalshah291 pushed a commit to jugalshah291/ray_fork that referenced this pull request Sep 11, 2025
wyhong3103 pushed a commit to wyhong3103/ray that referenced this pull request Sep 12, 2025
dstrodtman pushed a commit to dstrodtman/ray that referenced this pull request Oct 6, 2025
dstrodtman pushed a commit that referenced this pull request Oct 6, 2025
snorkelopstesting1-a11y pushed a commit to snorkel-marlin-repos/ray-project_ray_pr_54555_f064cdeb-33ee-47e2-9c09-aeabc03b7f31 that referenced this pull request Oct 11, 2025
snorkelopstesting1-a11y added a commit to snorkel-marlin-repos/ray-project_ray_pr_54555_f064cdeb-33ee-47e2-9c09-aeabc03b7f31 that referenced this pull request Oct 11, 2025
…rted_checkpoints method

Merged from original PR #54555
Original: ray-project/ray#54555
snorkelopstesting1-a11y pushed a commit to snorkel-marlin-repos/ray-project_ray_pr_54555_01bc5849-a3ec-44d6-95e5-5f306bbf4838 that referenced this pull request Oct 11, 2025
snorkelopstesting1-a11y added a commit to snorkel-marlin-repos/ray-project_ray_pr_54555_01bc5849-a3ec-44d6-95e5-5f306bbf4838 that referenced this pull request Oct 11, 2025
snorkelopsstgtesting1-spec pushed a commit to snorkel-marlin-repos/ray-project_ray_pr_54555_caf8b278-b360-4cd5-a457-cf20fc5e2426 that referenced this pull request Oct 22, 2025
snorkelopstesting4-web added a commit to snorkel-marlin-repos/ray-project_ray_pr_54555_caf8b278-b360-4cd5-a457-cf20fc5e2426 that referenced this pull request Oct 22, 2025
landscapepainter pushed a commit to landscapepainter/ray that referenced this pull request Nov 17, 2025
landscapepainter pushed a commit to landscapepainter/ray that referenced this pull request Nov 17, 2025

Labels

go (add ONLY when ready to merge, run all tests), train (Ray Train Related Issue), unstale (A PR that has been marked unstale. It will not get marked stale again if this label is on it.)
