Skip to content

Commit

Permalink
Test with AIR trainer
Browse files Browse the repository at this point in the history
Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
  • Loading branch information
justinvyu committed Jan 4, 2023
1 parent 41ce606 commit 84decad
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions python/ray/tune/tests/test_tuner_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,16 @@

import ray
from ray import tune
from ray.air import Checkpoint, CheckpointConfig, FailureConfig, RunConfig, session
from ray.air import (
Checkpoint,
CheckpointConfig,
FailureConfig,
RunConfig,
ScalingConfig,
session,
)
from ray.air._internal.remote_storage import delete_at_uri, download_from_uri
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.tune import Callback, Trainable
from ray.tune.execution.trial_runner import _find_newest_experiment_checkpoint
from ray.tune.experiment import Trial
Expand Down Expand Up @@ -803,7 +811,8 @@ def on_trial_result(self, runner, trial, result):
)


def test_checkpoints_saved_after_resume(tmp_path):
@pytest.mark.parametrize("use_air_trainer", [True, False])
def test_checkpoints_saved_after_resume(tmp_path, use_air_trainer):
"""Checkpoints saved after experiment restore should pick up at the correct
iteration and should not overwrite the checkpoints from the original run.
Old checkpoints should still be deleted if the total number of checkpoints
Expand Down Expand Up @@ -831,19 +840,30 @@ def get_checkpoints(experiment_dir):
fail_marker = tmp_path / "fail_marker"
fail_marker.write_text("", encoding="utf-8")

trainable = (
DataParallelTrainer(
_train_fn_sometimes_failing, scaling_config=ScalingConfig(num_workers=1)
)
if use_air_trainer
else _train_fn_sometimes_failing
)
param_space = {
"failing_hanging": (fail_marker, None),
"num_epochs": 2,
}
if use_air_trainer:
param_space = {"train_loop_config": param_space}

num_to_keep = 4
tuner = Tuner(
_train_fn_sometimes_failing,
trainable,
tune_config=TuneConfig(num_samples=1),
run_config=RunConfig(
name="exp_name",
local_dir=str(tmp_path),
checkpoint_config=CheckpointConfig(num_to_keep=num_to_keep),
),
param_space={
"failing_hanging": (fail_marker, None),
"num_epochs": 2,
},
param_space=param_space,
)
results = tuner.fit()
training_iteration = results[0].metrics["training_iteration"]
Expand Down

0 comments on commit 84decad

Please sign in to comment.