
Conversation


@TimothySeah TimothySeah commented Sep 9, 2025

Summary

The main changes here are:

  • Train workers report a validation function and validation config.
  • The controller kicks off a validation Ray task and associates its return value with the relevant checkpoint.
  • The main controller step polls workers and validations, finishing only when both are done.

In doing this, we are leveraging Ray Train's single-controller architecture to maintain a global view of training progress, which makes it easy to add features like early stopping in the future.
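
To make this concrete, here is a minimal, hypothetical sketch of the controller side: launch the user's validation function as a Ray task, remember which checkpoint it belongs to, and poll non-blockingly from the controller step. The names run_validation, launch_validation, poll_validations, and pending are illustrative only and are not the actual controller internals:

import ray

@ray.remote
def run_validation(validate_fn, checkpoint, validate_config):
    # Run the user-provided validation function as a Ray task.
    return validate_fn(checkpoint, validate_config)

pending = {}  # ObjectRef -> the checkpoint being validated

def launch_validation(validate_fn, checkpoint, validate_config):
    ref = run_validation.remote(validate_fn, checkpoint, validate_config)
    pending[ref] = checkpoint

def poll_validations():
    # Non-blocking poll: finished validations update the checkpoint's metrics;
    # failures are recorded instead of failing training.
    if not pending:
        return
    done, _ = ray.wait(list(pending), num_returns=len(pending), timeout=0)
    for ref in done:
        checkpoint = pending.pop(ref)
        try:
            metrics = ray.get(ref)
            # Merge `metrics` into the metrics reported with `checkpoint` here.
        except Exception as exc:
            print(f"Validation failed for {checkpoint}: {exc}")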

A few other notes:

  • When validations fail, training moves on, but the controller prints the failure and stores it in the Result. I added a TODO to retry and time out in the future.
  • We never delete pending checkpoints because they could end up being better than existing checkpoints. We may keep failed checkpoints depending on their initially reported metrics + the CheckpointConfig. Note that Result.failed_validations contains all the checkpoints that failed validations, which may or may not have been deleted.
  • I added a TODO to resume validations on CheckpointManager restoration; right now we simply won't rerun interrupted validations.
  • I added a TODO to rate limit validations with a queue; right now we always kick off Ray tasks and rely on Ray Core scheduling to rate limit.

API Examples

You can define a validate_function with map_batches

import os

import torch


# create_model() and the dataset passed in `config` come from the user's training code.
class Predictor:
    def __init__(self, checkpoint):
        self.model = create_model()
        with checkpoint.as_directory() as checkpoint_dir:
            model_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))
            self.model.load_state_dict(model_state_dict)
        self.model.cuda().eval()

    def __call__(self, batch):
        image = torch.as_tensor(batch['image'], dtype=torch.float32, device="cuda")
        label = torch.as_tensor(batch['label'], dtype=torch.float32, device="cuda")
        pred = self.model(image)
        return {'res': (pred.argmax(1) == label).cpu().numpy()}


def validate_with_map_batches(checkpoint, config):
    eval_res = config['dataset'].map_batches(
        Predictor,
        batch_size=128,
        num_gpus=1,
        fn_constructor_kwargs={'checkpoint': checkpoint},
        # Runs on 2 actors because Predictor is a class-based UDF.
        concurrency=2,
    )
    mean = eval_res.mean(['res'])
    return {
        'score': mean,
    }

or TorchTrainer.

import os
import tempfile

import torch
import torchmetrics
from torch.nn import CrossEntropyLoss

import ray.train


def eval_only_train_func(config_dict):
    # Load the checkpoint (create_model() comes from the user's training code)
    model = create_model()
    with config_dict['checkpoint'].as_directory() as checkpoint_dir:
        model_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))
        model.load_state_dict(model_state_dict)
    model.cuda().eval()

    # Set up metrics and data loaders
    criterion = CrossEntropyLoss()
    mean_valid_loss = torchmetrics.MeanMetric().cuda()
    test_data_shard = ray.train.get_dataset_shard("test")
    test_dataloader = test_data_shard.iter_torch_batches(batch_size=256)

    # Compute the metric
    with torch.no_grad():
        for batch in test_dataloader:
            images, labels = batch['image'], batch['label']
            outputs = model(images)
            loss = criterion(outputs, labels)
            mean_valid_loss(loss)

    # Report the metric with a dummy checkpoint so it makes it out of the training function.
    with tempfile.TemporaryDirectory() as temp_dir:
        ray.train.report(
            metrics={'score': mean_valid_loss.compute().item()},
            checkpoint=ray.train.Checkpoint.from_directory(temp_dir),
            checkpoint_upload_mode=CheckpointUploadMode.NO_UPLOAD,
        )

def validate_with_torch_trainer(checkpoint, config):
    trainer = ray.train.torch.TorchTrainer(
        eval_only_train_func,
        train_loop_config={'checkpoint': checkpoint},
        scaling_config=ray.train.ScalingConfig(num_workers=2, use_gpu=True),
        datasets={"test": config['dataset']},
    )
    result = trainer.fit()
    return result.metrics

Note the following about the map_batches and TorchTrainer methods:

  • During the validation forward pass, the map_batches method's __call__ must move the batch to the device itself, while the TorchTrainer method's forward pass does not need to, because iter_torch_batches moves batches to the device automatically.
  • The map_batches method uses Ray Data's metric aggregation methods, whereas the TorchTrainer method uses Torch's.
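
Both validate functions take (checkpoint, config) and return a metrics dict, so you can also sanity-check one outside of training by calling it directly on an existing checkpoint. This is illustrative only; the checkpoint path and test_dataset below are placeholders:

import ray.train

# Hypothetical standalone check of a validate function before wiring it into report().
checkpoint = ray.train.Checkpoint.from_directory("/mnt/cluster_storage/some_checkpoint_dir")
print(validate_with_map_batches(checkpoint, {"dataset": test_dataset}))  # e.g. {'score': ...}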

Either way, you pass the validate function (and its config) to ray.train.report from the training function, as follows:

def train_func():
    ...  # setup: model, data loaders, temp_checkpoint_dir, rank, etc.
    for epoch in epochs:
        if rank == 0:
            metrics = {"loss": loss.item(), "epoch": epoch}
            iteration_checkpoint_dir = os.path.join(
                temp_checkpoint_dir,
                f"epoch_{epoch}_rank_{ray.train.get_context().get_world_rank()}",
            )
            os.makedirs(iteration_checkpoint_dir, exist_ok=True)
            torch.save(
                model.module.state_dict(),
                os.path.join(iteration_checkpoint_dir, "model.pt")
            )
            ray.train.report(
                metrics,
                checkpoint=ray.train.Checkpoint.from_directory(iteration_checkpoint_dir),
                checkpoint_upload_mode=CheckpointUploadMode.ASYNC,
                validate_function=validate_function,
                validate_config={
                    'dataset': test_dataset,
                },
            )
        else:
            # Non-rank-0 workers still call report (with no checkpoint) to stay in sync.
            ray.train.report({}, None)

trainer = ray.train.torch.TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    datasets={"train": train_dataset},
    run_config=ray.train.RunConfig(
        storage_path="/mnt/cluster_storage",
    ),
)
result = trainer.fit()
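
After fit() returns, the validation scores are merged into each checkpoint's metrics (as in the best_checkpoints outputs shown under Testing below), and failed validations are recorded on the Result as described in the notes above. For example (illustrative; failed_validations is accessed defensively since this run has no failures):

# Inspect checkpoints and their merged metrics after training.
for checkpoint, metrics in result.best_checkpoints:
    print(checkpoint.path, metrics.get("loss"), metrics.get("score"))

# Checkpoints whose validation failed (Result.failed_validations, per the notes above).
print(getattr(result, "failed_validations", None))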

Note that both methods pass the test_dataset directly as a global variable - let me know if there is a better way to do this.

Testing

I tried both APIs above in an Anyscale workspace for 2 epochs:

The map_batches API:

  • Took 4m42s
(screenshot of the train run omitted)
  • The validations took 13.836704969406128s and 20.883531093597412s
  • Total report blocking time was negligible (0.007100105285644531s)
  • result.best_checkpoints at the end looks like [(Checkpoint(filesystem=local, path=/mnt/cluster_storage/ray_train_run-2025-09-15_18-31-44/checkpoint_2025-09-15_18-34-05.089176), {'loss': 2.0584681034088135, 'epoch': 0, 'score': 0.2369}), (Checkpoint(filesystem=local, path=/mnt/cluster_storage/ray_train_run-2025-09-15_18-31-44/checkpoint_2025-09-15_18-36-07.816422), {'loss': 1.8418619632720947, 'epoch': 1, 'score': 0.3152})]
  • The additional e2e time incurred by validation, measured as (time trainer.fit() returned) - (time train_func exited), was 23.8961102962s, which is equal to (time taken to set up the last validation) + (time taken to perform the last validation); a rough sketch of this measurement appears at the end of this Testing section.

The TorchTrainer API:

  • Took 4m49s
(screenshot of the train run omitted)
  • Interestingly, according to my print statements, the validations took 22.306870937347412s and 22.350095510482788s, but according to the train runs above, they took 17s and 18s.
  • Total report blocking time was negligible (0.006956338882446289s)
  • result.best_checkpoints at the end looks like [(Checkpoint(filesystem=local, path=/mnt/cluster_storage/06325ef3-fc29-48eb-bc21-f3fca71da238/checkpoint_2025-09-16_13-34-59.088135), {'loss': 1.7960023880004883, 'epoch': 0, 'score': 1.8543205261230469}), (Checkpoint(filesystem=local, path=/mnt/cluster_storage/06325ef3-fc29-48eb-bc21-f3fca71da238/checkpoint_2025-09-16_13-37-05.083466), {'loss': 1.615540862083435, 'epoch': 1, 'score': 1.6425774097442627})]
  • The additional e2e time incurred by validation was 26.8572409153s.

I also tested the TorchTrainer API with autoscaling enabled:

(screenshot of the autoscaling train run omitted)

The results are basically the same as above, but interestingly:

  • The first validation took much longer than the second one because it includes autoscaling time.
  • The e2e time is around the same (in fact, it was slightly shorter) because training continues during autoscaling, and since validation takes less time than training, validation can catch up to training.
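
For reference, the validation overhead above was measured roughly like this. This is only a sketch of the idea, not the exact instrumentation: it assumes the last reported metrics surface on the Result and that driver and worker clocks are close enough for a wall-clock comparison:

import time

def train_func():
    ...  # training loop as in the example above
    # Report the wall-clock time right before the training function exits.
    ray.train.report({"train_func_exit_time": time.time()})

result = trainer.fit()  # trainer as constructed in the example above
overhead_s = time.time() - result.metrics["train_func_exit_time"]
print(f"Extra e2e time from validation: {overhead_s:.1f}s")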

Note

Introduce async checkpoint validation via validate_fn/validate_config in report(), add ValidationManager, and refactor reporting to a TrainingReport with pending-checkpoint handling.

  • API:
    • ray.train.v2.api.train_fn_utils.report now accepts validate_fn and validate_config to run async checkpoint validation.
  • Execution/Controller:
    • New ValidationManager runs validation tasks, polls results, and updates checkpoint metrics.
    • Wire ValidationManager into TrainController and ReportCallbackHandler.
  • Reporting Pipeline:
    • Introduce _TrainingReport and _ValidationSpec to carry checkpoint, metrics, and validation spec end-to-end.
    • Change ReportCallback.after_report signature to receive training_report plus per-worker metrics.
    • WorkerStatus now holds training_report (renamed from training_result).
  • Checkpointing:
    • CheckpointManager supports pending checkpoints and updates metrics post-validation; tie-break equal scores by report index; persist state before deletions.
    • Enhanced _insert_into_sorted_list to accept typed items and optional tie-break map.
  • User Callbacks:
    • UserCallbackHandler.after_report now sources checkpoint from training_report.
  • Tests/Build:
    • Add tests for validation and pending-checkpoint logic; rename and add Bazel targets (test_async_checkpointing_validation*, test_validation_manager).

Written by Cursor Bugbot for commit 6e1ddbc.

@TimothySeah TimothySeah marked this pull request as ready for review September 16, 2025 21:30
@TimothySeah TimothySeah requested review from a team as code owners September 16, 2025 21:30
@ray-gardener ray-gardener bot added the 'train' label (Ray Train Related Issue) Sep 17, 2025

@justinvyu justinvyu left a comment

LFG

@justinvyu justinvyu left a comment

Great! I'll do a more thorough pass on the tests in the next round.

@justinvyu justinvyu left a comment

Should be good after this!

@TimothySeah TimothySeah commented:

A few comments on 747e99b:

I considered the following alternate implementation methods but decided against them for various reasons:

  • Have the base CheckpointManager also update checkpoint_to_report_index: decided against this to minimize Train v1 changes.
  • Store the report index inside the training result: this would require a large refactor and expose unnecessary information to users.

I reworked test_update_checkpoints_with_metrics; it basically covered the same cases as test_pending_checkpoint_management, except:

  • Trying to update a checkpoint that is not pending: extracted into test_pending_checkpoint_management_finalized_checkpoint.
  • Trying to update a checkpoint that is not in checkpoint_results: already covered by test_update_checkpoints_with_metrics_not_in_checkpoint_results.

I'm not worrying about restoring _checkpoint_to_report_index for now because:

  • I already have a TODO to restore canceled validations
  • Because we do not restore canceled validations, new validations will be associated with new reports, which will have report indices. (A rough sketch of the report-index tie-breaking follows.)
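
For illustration, here is a rough, hypothetical sketch of the kind of ordering this enables: checkpoints kept sorted by score, with equal scores broken by a report-index map. This is not the actual _insert_into_sorted_list implementation, and the tie-break direction here is arbitrary; it only shows the idea:

import bisect
from typing import Any, Callable, Dict, List

def insert_sorted_with_tiebreak(
    items: List[Any],
    new_item: Any,
    score_fn: Callable[[Any], float],
    report_index: Dict[Any, int],
) -> None:
    """Insert new_item into items (kept sorted ascending by score),
    breaking equal scores by report index (later reports sort after
    earlier ones in this sketch)."""
    def key(item):
        return (score_fn(item), report_index.get(item, 0))

    keys = [key(item) for item in items]
    items.insert(bisect.bisect_right(keys, key(new_item)), new_item)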

@justinvyu justinvyu left a comment

🚢

@TimothySeah TimothySeah added the 'go' label (add ONLY when ready to merge, run all tests) Sep 27, 2025

@justinvyu justinvyu commented Sep 29, 2025

If you use the TorchTrainer way, you also need to pass the experiment_dir_name to comply with the requirement that each TorchTrainer manages checkpoints from the same experiment directory.

Why is using the same experiment directory a requirement here?

If you reuse the same storage_path+name, then you'll end up "restoring" the previous run: https://docs.ray.io/en/latest/train/user-guides/fault-tolerance.html#job-driver-fault-tolerance
This is probably why you were running into the "checkpoint directory" matching prefix assertion error. We don't need to "restore" the training run while launching the validation run.

The validation run shouldn't actually need to save any files or record checkpoints, so the storage path will be mostly unused.

I think we should recommend something like this:

def eval_only_train_func(config_dict):
    # ...

    # ray.train.report(metrics, checkpoint, NO_UPLOAD)
    # !! The previous usage of ray.train.report w/ NO_UPLOAD is a bit confusing,
    # since you're essentially checkpointing within a validation loop.
    # The only reason why we needed the dummy checkpoint previously was so
    # that we could get these `metrics` out of the training function.
    metrics = {'score': mean_valid_loss.compute().item()}
    return metrics

def validate_with_torch_trainer(checkpoint, config):
    trainer = ray.train.torch.TorchTrainer(
        eval_only_train_func,
        train_loop_config={'checkpoint': checkpoint},
        scaling_config=ray.train.ScalingConfig(num_workers=2, use_gpu=True),
        # !! Just leave the default auto-generated UUID run name.
        run_config=ray.train.RunConfig(storage_path="/mnt/cluster_storage"),
        datasets={"test": config['dataset']},
    )
    result = trainer.fit()
    return result.return_values[0]

Note that result.return_values is a new API, but I think it's generally useful (users have asked for something similar) and avoids the need to attach dummy checkpoints during validation. We can recommend the dummy checkpoint workaround for now, though, while deciding on the return values API.

@TimothySeah TimothySeah commented Sep 29, 2025

Thanks, good catch! I agree that return_values is a better API; as you pointed out, I was only calling report in eval_only_train_func because that API doesn't exist yet. I initially required both TorchTrainers to share the experiment directory so that the eval worker could report the checkpoint being validated, but after thinking about it more, I realized we can just report a dummy checkpoint instead, which removes that requirement. I've also updated the example in the PR description accordingly.

@justinvyu justinvyu merged commit 8564041 into ray-project:master Sep 29, 2025
6 checks passed