[Tune] Update trainable remote_checkpoint_dir upon actor reuse (ray-project#32420)

This PR fixes trainable actor reuse so that a reused actor updates the remote trial directory it writes checkpoints to.

Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
Signed-off-by: elliottower <elliot@elliottower.com>
justinvyu authored and elliottower committed Apr 22, 2023
1 parent fb7cc30 commit 28a3c7e
Showing 4 changed files with 85 additions and 7 deletions.
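
To make the behavior being fixed concrete before reading the diff: the bug shows up in runs where `reuse_actors=True` lets Tune swap a new trial onto an already-running actor while checkpoints are synced to a per-trial remote directory. Below is a minimal, illustrative sketch of such a run (not code from this PR); the trainable, the config key, and the `file:///tmp/tune_upload_dir` target are invented for the example, and the API shown is Ray Tune as of this commit.

import os

from ray import tune
from ray.tune import Trainable


class ReusableTrainable(Trainable):
    def setup(self, config):
        self.trial_tag = config["id"]

    def step(self):
        return {"id": self.trial_tag}

    def save_checkpoint(self, checkpoint_dir):
        # Written locally, then synced to <upload_dir>/<experiment>/<trial dir>.
        with open(os.path.join(checkpoint_dir, "state.txt"), "w") as f:
            f.write(str(self.trial_tag))
        return checkpoint_dir

    def load_checkpoint(self, checkpoint_dir):
        with open(os.path.join(checkpoint_dir, "state.txt")) as f:
            self.trial_tag = int(f.read())

    def reset_config(self, new_config):
        # Required so Tune may swap a new trial onto this cached actor.
        self.trial_tag = new_config["id"]
        return True


analysis = tune.run(
    ReusableTrainable,
    config={"id": tune.grid_search([0, 1, 2, 3])},
    reuse_actors=True,              # trials get swapped onto cached actors
    checkpoint_freq=1,              # checkpoint (and sync) every iteration
    stop={"training_iteration": 2},
    sync_config=tune.SyncConfig(upload_dir="file:///tmp/tune_upload_dir"),
)

Before this change, an actor that previously ran trial 0 could keep uploading checkpoints to trial 0's remote directory after another trial was swapped onto it; with the fix, each trial's checkpoints land under its own remote trial directory.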
python/ray/air/_internal/remote_storage.py (2 additions, 0 deletions)
@@ -1,6 +1,7 @@
 import fnmatch
 import os
 import urllib.parse
+from pathlib import Path
 from pkg_resources import packaging
 from typing import List, Optional, Tuple

@@ -281,6 +282,7 @@ def _should_exclude(candidate: str) -> bool:
         full_source_path = os.path.normpath(os.path.join(local_path, candidate))
         full_target_path = os.path.normpath(os.path.join(bucket_path, candidate))

+        _ensure_directory(str(Path(full_target_path).parent))
         _pyarrow_fs_copy_files(
             full_source_path, full_target_path, destination_filesystem=fs
         )
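
The `_ensure_directory` call above guards the file copy: when the destination is addressed through a `pyarrow` filesystem, copying a single file into a directory that does not exist yet can fail, so the target's parent directory is created first. A standalone sketch of the same pattern using only public `pyarrow.fs` APIs (Ray's internal `_ensure_directory` and `_pyarrow_fs_copy_files` helpers are not reproduced here, and the paths are illustrative):

import os
from pathlib import Path

import pyarrow.fs

fs = pyarrow.fs.LocalFileSystem()

source = "/tmp/sync_demo/src/checkpoint_000001/state.txt"
target = "/tmp/sync_demo/dst/trial_0/checkpoint_000001/state.txt"

# Create a small source file so the sketch runs end to end.
os.makedirs(os.path.dirname(source), exist_ok=True)
Path(source).write_text("demo")

# Make sure the parent directory exists before copying the file,
# mirroring the _ensure_directory(...) call added in this hunk.
fs.create_dir(str(Path(target).parent), recursive=True)
pyarrow.fs.copy_files(source, target, destination_filesystem=fs)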
python/ray/tune/execution/ray_trial_executor.py (5 additions, 1 deletion)
@@ -799,7 +799,11 @@ def reset_trial(
             with warn_if_slow("reset"):
                 try:
                     reset_val = ray.get(
-                        trainable.reset.remote(extra_config, logger_creator),
+                        trainable.reset.remote(
+                            extra_config,
+                            logger_creator=logger_creator,
+                            remote_checkpoint_dir=trial.remote_checkpoint_dir,
+                        ),
                         timeout=DEFAULT_GET_TIMEOUT,
                     )
                 except GetTimeoutError:
python/ray/tune/tests/test_actor_reuse.py (76 additions, 5 deletions)
@@ -7,6 +7,7 @@
 from ray import tune, logger
 from ray.tune import Trainable, run_experiments, register_trainable
 from ray.tune.error import TuneError
+from ray.tune.result_grid import ResultGrid
 from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
 from ray.tune.tune import _check_mixin

@@ -20,6 +21,13 @@ def ray_start_1_cpu():
     os.environ.pop("TUNE_STATE_REFRESH_PERIOD", None)


+@pytest.fixture
+def ray_start_2_cpus():
+    address_info = ray.init(num_cpus=2)
+    yield address_info
+    ray.shutdown()
+
+
 @pytest.fixture
 def ray_start_4_cpus_extra():
     address_info = ray.init(num_cpus=4, resources={"extra": 4})
@@ -85,16 +93,17 @@ def default_resource_request(cls, config):
         return None


-def _run_trials_with_frequent_pauses(trainable, reuse=False):
+def _run_trials_with_frequent_pauses(trainable, reuse=False, **kwargs):
     analysis = tune.run(
         trainable,
         num_samples=1,
         config={"id": tune.grid_search([0, 1, 2, 3])},
         reuse_actors=reuse,
         scheduler=FrequentPausesScheduler(),
         verbose=0,
+        **kwargs,
     )
-    return analysis.trials
+    return analysis


 def test_trial_reuse_disabled(ray_start_1_cpu):
@@ -104,7 +113,8 @@ def test_trial_reuse_disabled(ray_start_1_cpu):

     We assert the `num_resets` of each trainable class to be 0 (no reuse).
     """
-    trials = _run_trials_with_frequent_pauses(MyResettableClass, reuse=False)
+    analysis = _run_trials_with_frequent_pauses(MyResettableClass, reuse=False)
+    trials = analysis.trials
     assert [t.last_result["id"] for t in trials] == [0, 1, 2, 3]
     assert [t.last_result["iter"] for t in trials] == [2, 2, 2, 2]
     assert [t.last_result["num_resets"] for t in trials] == [0, 0, 0, 0]
@@ -117,7 +127,8 @@ def test_trial_reuse_disabled_per_default(ray_start_1_cpu):

     We assert the `num_resets` of each trainable class to be 0 (no reuse).
     """
-    trials = _run_trials_with_frequent_pauses(MyResettableClass, reuse=None)
+    analysis = _run_trials_with_frequent_pauses(MyResettableClass, reuse=None)
+    trials = analysis.trials
     assert [t.last_result["id"] for t in trials] == [0, 1, 2, 3]
     assert [t.last_result["iter"] for t in trials] == [2, 2, 2, 2]
     assert [t.last_result["num_resets"] for t in trials] == [0, 0, 0, 0]
@@ -136,7 +147,8 @@ def test_trial_reuse_enabled(ray_start_1_cpu):
     - After each iteration, trials are paused and actors cached for reuse
     - Thus, the first trial finishes after 4 resets, the second after 5, etc.
     """
-    trials = _run_trials_with_frequent_pauses(MyResettableClass, reuse=True)
+    analysis = _run_trials_with_frequent_pauses(MyResettableClass, reuse=True)
+    trials = analysis.trials
     assert [t.last_result["id"] for t in trials] == [0, 1, 2, 3]
     assert [t.last_result["iter"] for t in trials] == [2, 2, 2, 2]
     assert [t.last_result["num_resets"] for t in trials] == [4, 5, 6, 7]
@@ -398,5 +410,64 @@ class MyTrainable(Trainable):
     assert _check_mixin(mlflow_mixin(MyTrainable))


+def test_remote_trial_dir_with_reuse_actors(ray_start_2_cpus, tmp_path):
+    """Check that the trainable has its remote directory set to the right
+    location, when new trials get swapped in on actor reuse.
+    Each trial runs for 2 iterations, with checkpoint_freq=1, so each remote
+    trial dir should have 2 checkpoints.
+    """
+    tmp_target = str(tmp_path / "upload_dir")
+    exp_name = "remote_trial_dir_update_on_actor_reuse"
+
+    def get_remote_trial_dir(trial_id: int):
+        return os.path.join(tmp_target, exp_name, str(trial_id))
+
+    class _MyResettableClass(MyResettableClass):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self._should_raise = False
+
+        def load_checkpoint(self, *args, **kwargs):
+            super().load_checkpoint(*args, **kwargs)
+
+            # Make sure that `remote_checkpoint_dir` gets updated correctly
+            trial_id = self.config.get("id")
+            remote_trial_dir = get_remote_trial_dir(trial_id)
+            if self.remote_checkpoint_dir != "file://" + remote_trial_dir:
+                # Delay raising the exception, since raising here would cause
+                # an unhandled exception that doesn't fail the test.
+                self._should_raise = True
+
+        def step(self):
+            if self._should_raise:
+                raise RuntimeError(
+                    f"Failing! {self.remote_checkpoint_dir} not updated properly "
+                    f"for trial {self.config.get('id')}"
+                )
+            return super().step()
+
+    analysis = _run_trials_with_frequent_pauses(
+        _MyResettableClass,
+        reuse=True,
+        max_concurrent_trials=2,
+        local_dir=str(tmp_path),
+        name=exp_name,
+        sync_config=tune.SyncConfig(upload_dir=f"file://{tmp_target}"),
+        trial_dirname_creator=lambda t: str(t.config.get("id")),
+        checkpoint_freq=1,
+    )
+    result_grid = ResultGrid(analysis)
+    assert not result_grid.errors
+
+    # Check that each remote trial dir has 2 checkpoints.
+    for result in result_grid:
+        trial_id = result.config["id"]
+        remote_dir = get_remote_trial_dir(trial_id)
+        num_checkpoints = len(
+            [file for file in os.listdir(remote_dir) if file.startswith("checkpoint_")]
+        )
+        assert num_checkpoints == 2
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(["-v", __file__]))
python/ray/tune/trainable/trainable.py (2 additions, 1 deletion)
@@ -872,7 +872,7 @@ def export_model(
         export_dir = export_dir or self.logdir
         return self._export_model(export_formats, export_dir)

-    def reset(self, new_config, logger_creator=None):
+    def reset(self, new_config, logger_creator=None, remote_checkpoint_dir=None):
         """Resets trial for use with new config.

         Subclasses should override reset_config() to actually
@@ -914,6 +914,7 @@ def reset(self, new_config, logger_creator=None):
         self._time_since_restore = 0.0
         self._timesteps_since_restore = 0
         self._iterations_since_restore = 0
+        self.remote_checkpoint_dir = remote_checkpoint_dir
         self._restored = False

         return True
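
From the trainable author's side, the practical effect of the added assignment is that `self.remote_checkpoint_dir` now tracks the trial currently occupying a reused actor, which is what the new test above asserts. A small sketch (not part of this diff) of a class-based trainable that opts into reuse and surfaces that attribute; the metric name and config key are illustrative.

from ray.tune import Trainable


class RemoteDirAwareTrainable(Trainable):
    def setup(self, config):
        self.trial_tag = config.get("id", 0)

    def step(self):
        # After this change, a reused actor reports the remote checkpoint
        # directory of the trial it is currently running, not a stale one.
        return {"remote_checkpoint_dir": str(self.remote_checkpoint_dir)}

    def reset_config(self, new_config):
        # Returning True tells Tune the in-place reset succeeded, which is
        # what makes reuse_actors=True usable with class-based trainables.
        self.trial_tag = new_config.get("id", 0)
        return True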
