[Train] Fix off-by-one AIR Trainer checkpoint ID indexing on restore #31423

Merged
python/ray/train/_internal/checkpoint.py (6 additions, 1 deletion)
@@ -214,7 +214,12 @@ def _load_checkpoint(
     ) -> Optional[Union[Dict, Checkpoint]]:
         loaded_checkpoint = super()._load_checkpoint(checkpoint_to_load)
         assert not loaded_checkpoint or isinstance(loaded_checkpoint, Checkpoint)
-        self._latest_checkpoint_id = getattr(loaded_checkpoint, TUNE_CHECKPOINT_ID, 0)
+        # `latest_checkpoint_id` will be the id assigned to the next checkpoint,
+        # which should be one more than the loaded checkpoint's id
+        # If no checkpoint is loaded, initialize this to 0
+        self._latest_checkpoint_id = (
+            getattr(loaded_checkpoint, TUNE_CHECKPOINT_ID, -1) + 1
+        )
         return loaded_checkpoint

     def add_tune_checkpoint_id(self, checkpoint: Checkpoint):
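For context, the hunk above changes what `_latest_checkpoint_id` is set to on restore: per the new comment, it must be the ID the next checkpoint will receive, i.e. one more than the restored checkpoint's ID, falling back to 0 when nothing was restored. Below is a minimal standalone sketch of that indexing rule; the class and method names are illustrative only, not Ray's actual checkpoint manager.

# Illustrative sketch of the indexing rule fixed above; not Ray code.
from typing import Optional


class CheckpointCounter:
    def __init__(self) -> None:
        # ID that the next saved checkpoint will be assigned.
        self.latest_checkpoint_id = 0

    def load(self, restored_id: Optional[int]) -> None:
        # Pre-fix behavior defaulted to the restored ID itself (or 0), so the
        # next save reused that ID and overwrote the restored checkpoint.
        # Post-fix: the next ID is one past the restored ID; -1 + 1 == 0
        # covers the case where no checkpoint was restored.
        self.latest_checkpoint_id = (
            restored_id if restored_id is not None else -1
        ) + 1

    def save(self) -> int:
        assigned = self.latest_checkpoint_id
        self.latest_checkpoint_id += 1
        return assigned


counter = CheckpointCounter()
counter.load(restored_id=5)
assert counter.save() == 6  # resumes after checkpoint 5 instead of clobbering it

counter = CheckpointCounter()
counter.load(restored_id=None)
assert counter.save() == 0  # a fresh run still starts numbering at 0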
python/ray/tune/tests/test_tuner_restore.py (27 additions, 7 deletions)
@@ -10,8 +10,16 @@
 
 import ray
 from ray import tune
-from ray.air import Checkpoint, CheckpointConfig, FailureConfig, RunConfig, session
+from ray.air import (
+    Checkpoint,
+    CheckpointConfig,
+    FailureConfig,
+    RunConfig,
+    ScalingConfig,
+    session,
+)
 from ray.air._internal.remote_storage import delete_at_uri, download_from_uri
+from ray.train.data_parallel_trainer import DataParallelTrainer
 from ray.tune import Callback, Trainable
 from ray.tune.execution.trial_runner import _find_newest_experiment_checkpoint
 from ray.tune.experiment import Trial
@@ -803,7 +811,8 @@ def on_trial_result(self, runner, trial, result):
     )
 
 
-def test_checkpoints_saved_after_resume(tmp_path):
+@pytest.mark.parametrize("use_air_trainer", [True, False])
+def test_checkpoints_saved_after_resume(tmp_path, use_air_trainer):
     """Checkpoints saved after experiment restore should pick up at the correct
     iteration and should not overwrite the checkpoints from the original run.
     Old checkpoints should still be deleted if the total number of checkpoints
@@ -831,19 +840,30 @@ def get_checkpoints(experiment_dir):
     fail_marker = tmp_path / "fail_marker"
     fail_marker.write_text("", encoding="utf-8")
 
+    trainable = (
+        DataParallelTrainer(
+            _train_fn_sometimes_failing, scaling_config=ScalingConfig(num_workers=1)
+        )
+        if use_air_trainer
+        else _train_fn_sometimes_failing
+    )
+    param_space = {
+        "failing_hanging": (fail_marker, None),
+        "num_epochs": 2,
+    }
+    if use_air_trainer:
+        param_space = {"train_loop_config": param_space}
+
     num_to_keep = 4
     tuner = Tuner(
-        _train_fn_sometimes_failing,
+        trainable,
         tune_config=TuneConfig(num_samples=1),
         run_config=RunConfig(
            name="exp_name",
            local_dir=str(tmp_path),
            checkpoint_config=CheckpointConfig(num_to_keep=num_to_keep),
         ),
-        param_space={
-            "failing_hanging": (fail_marker, None),
-            "num_epochs": 2,
-        },
+        param_space=param_space,
     )
     results = tuner.fit()
     training_iteration = results[0].metrics["training_iteration"]
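A note on the parametrization in the test above: when `use_air_trainer` is true, the same failing train function is wrapped in a `DataParallelTrainer`, and because AIR Trainers receive the training function's config through `train_loop_config`, the Tune search space is nested one level under that key. The sketch below isolates just that nesting decision; it is standalone and illustrative, not part of the PR.

# Illustrative only: mirrors how the test builds param_space for the two cases.
import pytest


@pytest.mark.parametrize("use_air_trainer", [True, False])
def test_param_space_nesting(use_air_trainer):
    param_space = {"failing_hanging": None, "num_epochs": 2}
    if use_air_trainer:
        # AIR Trainers take the train function's config under "train_loop_config",
        # so the search space is nested one level deeper.
        param_space = {"train_loop_config": param_space}
    expected_keys = (
        {"train_loop_config"} if use_air_trainer else {"failing_hanging", "num_epochs"}
    )
    assert set(param_space) == expected_keys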