[Train] Fix off-by-one AIR Trainer checkpoint ID indexing on restore (ray-project#31423)

This PR is a follow-up to ray-project#31231, ensuring that checkpoints are saved to the correctly indexed directory upon restore. The "latest checkpoint ID" used to generate the next checkpoint directory (`checkpoint_0000<latest_checkpoint_id>`) is off by one when restoring an AIR trainer, so the first checkpoint saved after a restore would reuse the index of the last checkpoint from the original run.
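
A minimal sketch of the off-by-one, assuming a hypothetical `next_checkpoint_dir` helper that mirrors how the directory name is built from the latest checkpoint ID (zero-padding chosen for illustration):

    # Sketch only: `next_checkpoint_dir` is a stand-in for the real directory
    # naming logic; it just formats the ID into the checkpoint directory name.
    def next_checkpoint_dir(latest_checkpoint_id: int) -> str:
        return f"checkpoint_{latest_checkpoint_id:06d}"

    loaded_id = 3  # ID stored on the checkpoint we restored from

    # Before this fix: the restored trainer reused the loaded ID, so the next
    # checkpoint landed in the same directory as the last pre-restore checkpoint.
    assert next_checkpoint_dir(loaded_id) == "checkpoint_000003"

    # After this fix: the next checkpoint ID is one past the loaded checkpoint's.
    assert next_checkpoint_dir(loaded_id + 1) == "checkpoint_000004"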

Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
Signed-off-by: tmynn <hovhannes.tamoyan@gmail.com>
justinvyu authored and tamohannes committed Jan 25, 2023
1 parent d8b0245 commit 7478482
Showing 2 changed files with 33 additions and 8 deletions.
7 changes: 6 additions & 1 deletion python/ray/train/_internal/checkpoint.py
@@ -214,7 +214,12 @@ def _load_checkpoint(
     ) -> Optional[Union[Dict, Checkpoint]]:
         loaded_checkpoint = super()._load_checkpoint(checkpoint_to_load)
         assert not loaded_checkpoint or isinstance(loaded_checkpoint, Checkpoint)
-        self._latest_checkpoint_id = getattr(loaded_checkpoint, TUNE_CHECKPOINT_ID, 0)
+        # `latest_checkpoint_id` will be the id assigned to the next checkpoint,
+        # which should be one more than the loaded checkpoint's id
+        # If no checkpoint is loaded, initialize this to 0
+        self._latest_checkpoint_id = (
+            getattr(loaded_checkpoint, TUNE_CHECKPOINT_ID, -1) + 1
+        )
         return loaded_checkpoint

     def add_tune_checkpoint_id(self, checkpoint: Checkpoint):
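A quick check of the new default (hedged sketch; a plain attribute name stands in for whatever the real `TUNE_CHECKPOINT_ID` constant resolves to):

    # Hedged sketch of the `getattr(..., -1) + 1` logic above.
    class _RestoredCheckpoint:
        tune_checkpoint_id = 5  # ID tagged onto the checkpoint when it was saved

    restored = _RestoredCheckpoint()
    fresh_start = None  # no checkpoint was loaded

    # Restored run: the next checkpoint gets ID 6, one past the loaded checkpoint.
    assert getattr(restored, "tune_checkpoint_id", -1) + 1 == 6

    # Fresh run: getattr falls back to -1, so the counter initializes to 0.
    assert getattr(fresh_start, "tune_checkpoint_id", -1) + 1 == 0
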
34 changes: 27 additions & 7 deletions python/ray/tune/tests/test_tuner_restore.py
@@ -10,8 +10,16 @@

 import ray
 from ray import tune
-from ray.air import Checkpoint, CheckpointConfig, FailureConfig, RunConfig, session
+from ray.air import (
+    Checkpoint,
+    CheckpointConfig,
+    FailureConfig,
+    RunConfig,
+    ScalingConfig,
+    session,
+)
 from ray.air._internal.remote_storage import delete_at_uri, download_from_uri
+from ray.train.data_parallel_trainer import DataParallelTrainer
 from ray.tune import Callback, Trainable
 from ray.tune.execution.trial_runner import _find_newest_experiment_checkpoint
 from ray.tune.experiment import Trial
@@ -803,7 +811,8 @@ def on_trial_result(self, runner, trial, result):
     )


-def test_checkpoints_saved_after_resume(tmp_path):
+@pytest.mark.parametrize("use_air_trainer", [True, False])
+def test_checkpoints_saved_after_resume(tmp_path, use_air_trainer):
     """Checkpoints saved after experiment restore should pick up at the correct
     iteration and should not overwrite the checkpoints from the original run.
     Old checkpoints should still be deleted if the total number of checkpoints
@@ -831,19 +840,30 @@ def get_checkpoints(experiment_dir):
     fail_marker = tmp_path / "fail_marker"
     fail_marker.write_text("", encoding="utf-8")

+    trainable = (
+        DataParallelTrainer(
+            _train_fn_sometimes_failing, scaling_config=ScalingConfig(num_workers=1)
+        )
+        if use_air_trainer
+        else _train_fn_sometimes_failing
+    )
+    param_space = {
+        "failing_hanging": (fail_marker, None),
+        "num_epochs": 2,
+    }
+    if use_air_trainer:
+        param_space = {"train_loop_config": param_space}
+
     num_to_keep = 4
     tuner = Tuner(
-        _train_fn_sometimes_failing,
+        trainable,
         tune_config=TuneConfig(num_samples=1),
         run_config=RunConfig(
             name="exp_name",
             local_dir=str(tmp_path),
             checkpoint_config=CheckpointConfig(num_to_keep=num_to_keep),
         ),
-        param_space={
-            "failing_hanging": (fail_marker, None),
-            "num_epochs": 2,
-        },
+        param_space=param_space,
     )
     results = tuner.fit()
     training_iteration = results[0].metrics["training_iteration"]
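For the parametrized test above, the one non-obvious wiring detail is how the search space reaches the training function in each variant; a condensed sketch, using names from the diff (the tuple value is simplified here):

    # Hedged sketch: with a plain function trainable, Tune passes the dict
    # directly as the trial config; with DataParallelTrainer, the same dict is
    # nested under "train_loop_config" so the trainer forwards it to the
    # training loop function.
    base_space = {"failing_hanging": (None, None), "num_epochs": 2}

    def build_param_space(use_air_trainer: bool) -> dict:
        return {"train_loop_config": base_space} if use_air_trainer else base_space

    assert build_param_space(False) == base_space
    assert build_param_space(True) == {"train_loop_config": base_space}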
