[train][checkpoint] Add validate_function and validate_config to ray.train.report #56360
Conversation
justinvyu
left a comment
LFG
justinvyu
left a comment
Great! I'll do a more thorough pass on the tests in the next round.
justinvyu
left a comment
Should be good after this!
A few comments on 747e99b: I considered the following alternate implementation methods but decided against them for various reasons:
* I reworked the …
* I'm not worrying about restoring …
justinvyu
left a comment
🚢
Why is using the same experiment directory a requirement here? If you reuse the same storage_path+name, then you'll end up "restoring" the previous run: https://docs.ray.io/en/latest/train/user-guides/fault-tolerance.html#job-driver-fault-tolerance

The validation run shouldn't actually need to save any files or record checkpoints, so the storage path will be mostly unused. I think we should recommend something like this:

```python
def eval_only_train_func(config_dict):
    # ...
    # ray.train.report(metrics, checkpoint, NO_UPLOAD)
    # !! The previous usage of ray.train.report w/ NO_UPLOAD is a bit confusing,
    # since you're essentially checkpointing within a validation loop.
    # The only reason why we needed the dummy checkpoint previously was so
    # that we could get these `metrics` out of the training function.
    metrics = {'score': mean_valid_loss.compute().item()}
    return metrics

def validate_with_torch_trainer(checkpoint, config):
    trainer = ray.train.torch.TorchTrainer(
        eval_only_train_func,
        train_loop_config={'checkpoint': checkpoint},
        scaling_config=ray.train.ScalingConfig(num_workers=2, use_gpu=True),
        # !! Just leave the default auto-generated UUID run name.
        run_config=ray.train.RunConfig(storage_path="/mnt/cluster_storage"),
        datasets={"test": config['dataset']},
    )
    result = trainer.fit()
    return result.return_values[0]
```

Note that …
Thanks, good catch! I agree that …
…train.report (ray-project#56360)

The main change here is:
* Train workers report validation function + validation config
* Controller kicks off validation Ray task and associates its return value with the relevant checkpoint.
* Main controller step polls workers and validations, only finishing when both are done.

Signed-off-by: Timothy Seah <tseah@anyscale.com>
Summary
The main change here is:
* Train workers report a validation function + validation config.
* The controller kicks off a validation Ray task and associates its return value with the relevant checkpoint.
* The main controller step polls workers and validations, only finishing when both are done.
In doing this, we are leveraging Ray Train's single controller architecture to provide a global understanding of training progress. This makes it easy to do stuff like early stopping in the future.
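The "poll workers and validations until both are done" controller loop can be sketched with stdlib futures. This is a toy model, not Ray's actual controller code; every name in it (`run_controller_step`, the fake worker/validation tasks) is illustrative:

```python
import concurrent.futures

def run_controller_step(worker_futures, validation_futures):
    """Toy single-controller loop: it only finishes once both the training
    workers and all in-flight checkpoint validations are done."""
    pending = set(worker_futures) | set(validation_futures)
    results = []
    while pending:
        done, pending = concurrent.futures.wait(
            pending, return_when=concurrent.futures.FIRST_COMPLETED
        )
        results.extend(f.result() for f in done)
        # Because the controller sees global progress at this point, a policy
        # like early stopping could simply cancel the remaining futures here.
    return results

with concurrent.futures.ThreadPoolExecutor() as pool:
    workers = [pool.submit(lambda i=i: {"loss": 1.0 / (i + 1)}) for i in range(2)]
    validations = [pool.submit(lambda: {"score": 0.5})]
    all_results = run_controller_step(workers, validations)

print(len(all_results))  # 3
```

The point of the sketch is the single wait-set: because training results and validation results flow through one loop, a global decision (early stop, retry, timeout) needs no cross-worker coordination.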
A few other notes:
* … `Result`. I added a TODO to retry and time out in the future.
* `Result.failed_validations` contains all the checkpoints that failed validation, which may or may not have been deleted.
* … `CheckpointManager` restoration; right now we simply won't rerun interrupted validations.

API Examples
You can define a `validate_function` with `map_batches` or `TorchTrainer`. Note the following about the two methods:
* The `map_batches` method's `__call__` function must move the batch to the device, but the `TorchTrainer` method's forward pass does not, because the dataloader does this automatically.
* The `map_batches` method uses Ray Data's metric aggregation methods, whereas the `TorchTrainer` method uses Torch's.

Either way, you need to report with the validate function as follows:
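The original code block was lost in extraction. As a stand-in, here is a runnable sketch of the call shape: `report_with_validation` is a hypothetical stub that mimics the `validate_fn`/`validate_config` signature this PR adds to `ray.train.report`, and it runs validation inline instead of dispatching an async Ray task on the controller:

```python
# Hypothetical stub mirroring the signature described in this PR
# (ray.train.v2.api.train_fn_utils.report); not the real Ray API.
def report_with_validation(metrics, checkpoint=None,
                           validate_fn=None, validate_config=None):
    # The real implementation hands validate_fn/validate_config to the
    # controller, which launches validation asynchronously and attaches the
    # returned metrics to the checkpoint. Here we just run it inline.
    if checkpoint is not None and validate_fn is not None:
        validation_metrics = validate_fn(checkpoint, validate_config or {})
        metrics = {**metrics, **validation_metrics}
    return metrics

def validate_checkpoint(checkpoint, config):
    # Stand-in for a user-provided validation function.
    return {"score": 0.42}

merged = report_with_validation(
    {"loss": 1.8, "epoch": 0},
    checkpoint="checkpoint_000",  # placeholder for a Checkpoint object
    validate_fn=validate_checkpoint,
    validate_config={"dataset": "test"},
)
print(merged)  # {'loss': 1.8, 'epoch': 0, 'score': 0.42}
```

This matches the end-to-end behavior described in the Summary: validation metrics (here `score`) end up merged alongside the training metrics for the checkpoint.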
Note that both methods pass the `test_dataset` directly as a global variable - let me know if there is a better way to do this.

Testing
I tried both APIs above in an Anyscale workspace on 2 epochs.

The `map_batches` API:
* `result.best_checkpoints` at the end looks like `[(Checkpoint(filesystem=local, path=/mnt/cluster_storage/ray_train_run-2025-09-15_18-31-44/checkpoint_2025-09-15_18-34-05.089176), {'loss': 2.0584681034088135, 'epoch': 0, 'score': 0.2369}), (Checkpoint(filesystem=local, path=/mnt/cluster_storage/ray_train_run-2025-09-15_18-31-44/checkpoint_2025-09-15_18-36-07.816422), {'loss': 1.8418619632720947, 'epoch': 1, 'score': 0.3152})]`
* (time `fit` returned) - (time `train_func` exited) was 23.8961102962s, which is equal to (time taken to set up last validation) + (time taken to perform last validation).

The `TorchTrainer` API:
* `result.best_checkpoints` at the end looks like `[(Checkpoint(filesystem=local, path=/mnt/cluster_storage/06325ef3-fc29-48eb-bc21-f3fca71da238/checkpoint_2025-09-16_13-34-59.088135), {'loss': 1.7960023880004883, 'epoch': 0, 'score': 1.8543205261230469}), (Checkpoint(filesystem=local, path=/mnt/cluster_storage/06325ef3-fc29-48eb-bc21-f3fca71da238/checkpoint_2025-09-16_13-37-05.083466), {'loss': 1.615540862083435, 'epoch': 1, 'score': 1.6425774097442627})]`

I also tested the `TorchTrainer` API with autoscaling enabled. The results are basically the same as above, but interestingly:
Note
Introduce async checkpoint validation via validate_fn/validate_config in report(), add ValidationManager, and refactor reporting to a TrainingReport with pending-checkpoint handling.
* `ray.train.v2.api.train_fn_utils.report` now accepts `validate_fn` and `validate_config` to run async checkpoint validation.
* New `ValidationManager` runs validation tasks, polls results, and updates checkpoint metrics.
* Wires `ValidationManager` into `TrainController` and `ReportCallbackHandler`.
* New `_TrainingReport` and `_ValidationSpec` carry checkpoint, metrics, and validation spec end-to-end.
* Updates the `ReportCallback.after_report` signature to receive `training_report` plus per-worker metrics.
* `WorkerStatus` now holds `training_report` (renamed from `training_result`).
* `CheckpointManager` supports pending checkpoints and updates metrics post-validation; ties between equal scores break by report index; state persists before deletions.
* Refactors `_insert_into_sorted_list` to accept typed items and an optional tie-break map.
* `UserCallbackHandler.after_report` now sources the checkpoint from `training_report`.
* New tests (`test_async_checkpointing_validation*`, `test_validation_manager`).

Written by Cursor Bugbot for commit 6e1ddbc.
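The tie-break rule above ("equal scores break by report index") can be illustrated with a minimal stdlib sketch. The function and the `(score, report_idx, checkpoint)` entry layout are my own simplification, not the PR's `_insert_into_sorted_list`:

```python
import bisect

def insert_checkpoint(sorted_ckpts, score, report_idx, ckpt):
    """Keep checkpoints sorted by ascending score; when scores are equal,
    the earlier report index sorts first, making ordering deterministic."""
    # Tuples compare lexicographically, so report_idx acts as the tie-break.
    bisect.insort(sorted_ckpts, (score, report_idx, ckpt))
    return sorted_ckpts

ckpts = []
insert_checkpoint(ckpts, 0.5, 0, "ckpt_a")
insert_checkpoint(ckpts, 0.5, 1, "ckpt_b")  # same score, later report -> sorts after ckpt_a
insert_checkpoint(ckpts, 0.3, 2, "ckpt_c")
print([c for _, _, c in ckpts])  # ['ckpt_c', 'ckpt_a', 'ckpt_b']
```

Without the index, equal-score checkpoints would compare by whatever comes next in the tuple, so inserting the report index explicitly keeps "keep top-k" deletions stable across runs.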