Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tune] Enable experiment restore from moved cloud uri #31669

Merged
merged 26 commits into from
Jan 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
97181c9
Remove duplicate local_dir + rename to be the same as trial runner
justinvyu Jan 13, 2023
9f38a34
Add Trial.from_json_state alternate constructor
justinvyu Jan 13, 2023
af2b189
Only save around sync config + experiment dir name to construct remot…
justinvyu Jan 13, 2023
d225599
Update upload dir and experiment name upon URI restore
justinvyu Jan 13, 2023
e0267a8
Refactor relative checkpoint dir logic to be in local_dir setter instead
justinvyu Jan 13, 2023
87c5452
Refactor trial loading logic to use Trial.from_json_state
justinvyu Jan 13, 2023
e202dc0
Switch ExperimentAnalysis trial loading to from_json_state as well
justinvyu Jan 13, 2023
dde5d87
Add restore test from moved URI
justinvyu Jan 13, 2023
f6299a3
Remove load trial utilities
justinvyu Jan 13, 2023
7f810bd
Fix sync_config=None case
justinvyu Jan 14, 2023
936174c
Fix upload dir parsing to include query string
justinvyu Jan 14, 2023
92729aa
Merge branch 'master' of https://github.com/ray-project/ray into rest…
justinvyu Jan 18, 2023
20104dc
Sort trials in results by trial id
justinvyu Jan 18, 2023
bce61b4
Strict input type for from_json_state
justinvyu Jan 18, 2023
3ca887d
Clear memory filesys fixture for tests that use it
justinvyu Jan 18, 2023
12068a8
Remove useless check (trial logdir is loaded from the checkpoint)
justinvyu Jan 18, 2023
687ea5e
Fix tests related to from_json_state
justinvyu Jan 18, 2023
c65999f
Update TrialRunner docstring, remove remote_checkpoint_dir kwargs fro…
justinvyu Jan 18, 2023
f16e513
Simplify mocks in test_api
justinvyu Jan 18, 2023
3e5bb5b
Merge branch 'master' of https://github.com/ray-project/ray into rest…
justinvyu Jan 18, 2023
69a7221
Remove unnecessary import
justinvyu Jan 18, 2023
fdeaa19
Remove really old backwards compatibility test
justinvyu Jan 18, 2023
b54c30c
Remove unnecessary import
justinvyu Jan 18, 2023
b814b72
Merge branch 'master' of https://github.com/ray-project/ray into rest…
justinvyu Jan 24, 2023
2132cfb
Remove unused decode fn
justinvyu Jan 24, 2023
3257b9a
Clarify why the checkpoint gets moved
justinvyu Jan 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 9 additions & 19 deletions python/ray/tune/analysis/experiment_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@
TRAINING_ITERATION,
)
from ray.tune.experiment import Trial
from ray.tune.execution.trial_runner import (
_find_newest_experiment_checkpoint,
_load_trial_from_checkpoint,
)
from ray.tune.execution.trial_runner import _find_newest_experiment_checkpoint
from ray.tune.trainable.util import TrainableUtil
from ray.tune.utils.util import unflattened_lookup

Expand Down Expand Up @@ -143,13 +140,10 @@ def _load_checkpoints_from_latest(self, latest_checkpoint: List[str]) -> None:

if "checkpoints" not in experiment_state:
raise TuneError("Experiment state invalid; no checkpoints found.")

self._checkpoints_and_paths += [
(_decode_checkpoint_from_experiment_state(cp), Path(path).parent)
justinvyu marked this conversation as resolved.
Show resolved Hide resolved
for cp in experiment_state["checkpoints"]
(cp, Path(path).parent) for cp in experiment_state["checkpoints"]
]
self._checkpoints_and_paths = sorted(
self._checkpoints_and_paths, key=lambda tup: tup[0]["trial_id"]
)

def _get_latest_checkpoint(self, experiment_checkpoint_path: Path) -> List[str]:
# Case 1: Dir specified, find latest checkpoint.
Expand Down Expand Up @@ -798,12 +792,10 @@ def _get_trial_paths(self) -> List[str]:
"out of sync, as checkpointing is periodic."
)
self.trials = []
_trial_paths = []
for checkpoint, path in self._checkpoints_and_paths:
for trial_json_state, path in self._checkpoints_and_paths:
try:
trial = _load_trial_from_checkpoint(
checkpoint, stub=True, new_local_dir=str(path)
)
trial = Trial.from_json_state(trial_json_state, stub=True)
trial.local_dir = str(path)
except Exception:
logger.warning(
f"Could not load trials from experiment checkpoint. "
Expand All @@ -814,7 +806,9 @@ def _get_trial_paths(self) -> List[str]:
)
continue
self.trials.append(trial)
_trial_paths.append(str(trial.logdir))

self.trials.sort(key=lambda trial: trial.trial_id)
_trial_paths = [str(trial.logdir) for trial in self.trials]

if not _trial_paths:
raise TuneError("No trials found.")
Expand Down Expand Up @@ -882,7 +876,3 @@ def make_stub_if_needed(trial: Trial) -> Trial:

state["trials"] = [make_stub_if_needed(t) for t in state["trials"]]
return state


def _decode_checkpoint_from_experiment_state(cp: Union[str, dict]) -> dict:
return json.loads(cp, cls=TuneFunctionDecoder) if isinstance(cp, str) else cp
163 changes: 64 additions & 99 deletions python/ray/tune/execution/trial_runner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, DefaultDict, List, Mapping, Optional, Union, Tuple, Set
from typing import DefaultDict, List, Optional, Union, Tuple, Set

import click
from datetime import datetime
Expand Down Expand Up @@ -67,65 +67,6 @@ def _find_newest_experiment_checkpoint(ckpt_dir) -> Optional[str]:
return max(full_paths)


def _load_trial_from_checkpoint(
trial_cp: dict, stub: bool = False, new_local_dir: Optional[str] = None
) -> Trial:
"""Create a Trial from the state stored in the experiment checkpoint.

Args:
trial_cp: Trial state from the experiment checkpoint, which is loaded
from the trial's `Trial.get_json_state`.
stub: Whether or not to validate the trainable name when creating the Trial.
Used for testing purposes for creating mocks.
new_local_dir: If set, this `local_dir` will overwrite what's saved in the
`trial_cp` state. Used in the case that the trial directory has moved.
The Trial `logdir` and the persistent trial checkpoints will have their
paths updated relative to this new directory.

Returns:
new_trial: New trial with state loaded from experiment checkpoint
"""
new_trial = Trial(
trial_cp["trainable_name"],
stub=stub,
_setup_default_resource=False,
)
if new_local_dir:
trial_cp["local_dir"] = new_local_dir
new_trial.__setstate__(trial_cp)
new_trial.refresh_default_resource_request()
return new_trial


def _load_trials_from_experiment_checkpoint(
experiment_checkpoint: Mapping[str, Any],
stub: bool = False,
new_local_dir: Optional[str] = None,
) -> List[Trial]:
"""Create trial objects from experiment checkpoint.

Given an experiment checkpoint (TrialRunner state dict), return
list of trials. See `_ExperimentCheckpointManager.checkpoint` for
what's saved in the TrialRunner state dict.
"""
checkpoints = [
json.loads(cp, cls=TuneFunctionDecoder) if isinstance(cp, str) else cp
for cp in experiment_checkpoint["checkpoints"]
]

trials = []
for trial_cp in checkpoints:
trials.append(
_load_trial_from_checkpoint(
trial_cp,
stub=stub,
new_local_dir=new_local_dir,
)
)

return trials


@dataclass
class _ResumeConfig:
resume_unfinished: bool = True
Expand Down Expand Up @@ -154,17 +95,16 @@ class _ExperimentCheckpointManager:

def __init__(
self,
checkpoint_dir: str,
local_checkpoint_dir: str,
checkpoint_period: Union[int, float, str],
start_time: float,
session_str: str,
syncer: Syncer,
sync_trial_checkpoints: bool,
local_dir: str,
remote_dir: str,
remote_checkpoint_dir: str,
sync_every_n_trial_checkpoints: Optional[int] = None,
):
self._checkpoint_dir = checkpoint_dir
self._local_checkpoint_dir = local_checkpoint_dir
self._auto_checkpoint_enabled = checkpoint_period == "auto"
if self._auto_checkpoint_enabled:
self._checkpoint_period = 10.0 # Initial value
Expand All @@ -176,8 +116,7 @@ def __init__(

self._syncer = syncer
self._sync_trial_checkpoints = sync_trial_checkpoints
self._local_dir = local_dir
self._remote_dir = remote_dir
self._remote_checkpoint_dir = remote_checkpoint_dir

self._last_checkpoint_time = 0.0
self._last_sync_time = 0.0
Expand Down Expand Up @@ -225,7 +164,7 @@ def checkpoint(
Args:
force: Forces a checkpoint despite checkpoint_period.
"""
if not self._checkpoint_dir:
if not self._local_checkpoint_dir:
return

force = force or self._should_force_cloud_sync
Expand All @@ -243,12 +182,14 @@ def _serialize_and_write():
"timestamp": self._last_checkpoint_time,
},
}
tmp_file_name = os.path.join(self._checkpoint_dir, ".tmp_checkpoint")
tmp_file_name = os.path.join(self._local_checkpoint_dir, ".tmp_checkpoint")
with open(tmp_file_name, "w") as f:
json.dump(runner_state, f, indent=2, cls=TuneFunctionEncoder)

os.replace(tmp_file_name, checkpoint_file)
search_alg.save_to_dir(self._checkpoint_dir, session_str=self._session_str)
search_alg.save_to_dir(
self._local_checkpoint_dir, session_str=self._session_str
)

checkpoint_time_start = time.monotonic()
with out_of_band_serialize_dataset():
Expand All @@ -274,14 +215,14 @@ def _serialize_and_write():
"`sync_timeout` in `SyncConfig`."
)
synced = self._syncer.sync_up(
local_dir=self._local_dir,
remote_dir=self._remote_dir,
local_dir=self._local_checkpoint_dir,
remote_dir=self._remote_checkpoint_dir,
exclude=exclude,
)
else:
synced = self._syncer.sync_up_if_needed(
local_dir=self._local_dir,
remote_dir=self._remote_dir,
local_dir=self._local_checkpoint_dir,
remote_dir=self._remote_checkpoint_dir,
exclude=exclude,
)

Expand Down Expand Up @@ -320,7 +261,7 @@ def _serialize_and_write():
)

self._last_checkpoint_time = time.time()
return self._checkpoint_dir
return self._local_checkpoint_dir


@DeveloperAPI
Expand Down Expand Up @@ -350,14 +291,15 @@ class TrialRunner:
search_alg: SearchAlgorithm for generating
Trial objects.
scheduler: Defaults to FIFOScheduler.
local_checkpoint_dir: Path where
global checkpoints are stored and restored from.
remote_checkpoint_dir: Remote path where
global checkpoints are stored and restored from. Used
if `resume` == REMOTE.
sync_config: See `tune.py:run`.
stopper: Custom class for stopping whole experiments. See
``Stopper``.
local_checkpoint_dir: Path where global experiment state checkpoints
are saved and restored from.
sync_config: See :class:`~ray.tune.syncer.SyncConfig`.
Within sync config, the `upload_dir` specifies cloud storage, and
experiment state checkpoints will be synced to the `remote_checkpoint_dir`:
`{sync_config.upload_dir}/{experiment_name}`.
experiment_dir_name: Experiment directory name.
See :class:`~ray.tune.experiment.Experiment`.
stopper: Custom class for stopping whole experiments. See ``Stopper``.
resume: see `tune.py:run`.
server_port: Port number for launching TuneServer.
fail_fast: Finishes as soon as a trial fails if True.
Expand Down Expand Up @@ -388,8 +330,8 @@ def __init__(
search_alg: Optional[SearchAlgorithm] = None,
scheduler: Optional[TrialScheduler] = None,
local_checkpoint_dir: Optional[str] = None,
remote_checkpoint_dir: Optional[str] = None,
sync_config: Optional[SyncConfig] = None,
experiment_dir_name: Optional[str] = None,
stopper: Optional[Stopper] = None,
resume: Union[str, bool] = False,
server_port: Optional[int] = None,
Expand Down Expand Up @@ -436,11 +378,11 @@ def __init__(
# Manual override
self._max_pending_trials = int(max_pending_trials)

sync_config = sync_config or SyncConfig()
self._sync_config = sync_config or SyncConfig()

self.trial_executor.setup(
max_pending_trials=self._max_pending_trials,
trainable_kwargs={"sync_timeout": sync_config.sync_timeout},
trainable_kwargs={"sync_timeout": self._sync_config.sync_timeout},
)

self._metric = metric
Expand Down Expand Up @@ -485,9 +427,9 @@ def __init__(
if self._local_checkpoint_dir:
os.makedirs(self._local_checkpoint_dir, exist_ok=True)

self._remote_checkpoint_dir = remote_checkpoint_dir
self._experiment_dir_name = experiment_dir_name

self._syncer = get_node_to_storage_syncer(sync_config)
self._syncer = get_node_to_storage_syncer(self._sync_config)
self._stopper = stopper or NoopStopper()
self._resumed = False

Expand Down Expand Up @@ -562,14 +504,13 @@ def end_experiment_callbacks(self) -> None:

def _create_checkpoint_manager(self, sync_trial_checkpoints: bool = True):
return _ExperimentCheckpointManager(
checkpoint_dir=self._local_checkpoint_dir,
local_checkpoint_dir=self._local_checkpoint_dir,
checkpoint_period=self._checkpoint_period,
start_time=self._start_time,
session_str=self._session_str,
syncer=self._syncer,
sync_trial_checkpoints=sync_trial_checkpoints,
local_dir=self._local_checkpoint_dir,
remote_dir=self._remote_checkpoint_dir,
remote_checkpoint_dir=self._remote_checkpoint_dir,
sync_every_n_trial_checkpoints=self._trial_checkpoint_config.num_to_keep,
)

Expand All @@ -585,6 +526,12 @@ def search_alg(self):
def scheduler_alg(self):
return self._scheduler_alg

@property
def _remote_checkpoint_dir(self):
if self._sync_config.upload_dir and self._experiment_dir_name:
return os.path.join(self._sync_config.upload_dir, self._experiment_dir_name)
return None

def _validate_resume(
self, resume_type: Union[str, bool], driver_sync_trial_checkpoints=True
) -> Tuple[bool, Optional[_ResumeConfig]]:
Expand Down Expand Up @@ -845,19 +792,34 @@ def resume(
)
)

trial_runner_data = runner_state["runner_data"]
# Don't overwrite the current `_local_checkpoint_dir`
# The current directory could be different from the checkpointed
# directory, if the experiment directory has changed.
trial_runner_data.pop("_local_checkpoint_dir", None)
# 1. Restore trial runner state
self.__setstate__(runner_state["runner_data"])

self.__setstate__(trial_runner_data)
# 2. Restore search algorithm state
if self._search_alg.has_checkpoint(self._local_checkpoint_dir):
self._search_alg.restore_from_dir(self._local_checkpoint_dir)

trials = _load_trials_from_experiment_checkpoint(
runner_state, new_local_dir=self._local_checkpoint_dir
)
# 3. Load trial table from experiment checkpoint
trials = []
for trial_json_state in runner_state["checkpoints"]:
trial = Trial.from_json_state(trial_json_state)

# The following properties may be updated on restoration
# Ex: moved local/cloud experiment directory
trial.local_dir = self._local_checkpoint_dir
trial.sync_config = self._sync_config
trial.experiment_dir_name = self._experiment_dir_name

# Avoid creating logdir in client mode for returned trial results,
# since the dir might not be creatable locally.
# TODO(ekl) this is kind of a hack.
if not ray.util.client.ray.is_connected():
trial.init_logdir() # Create logdir if it does not exist
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved this out of Trial.__setstate__. Are trials serialized/deserialized in any situation that would require this to stay within __setstate__? My understanding is that trial objects don't get shipped around and stay on the Tune driver. Only serialized/deserialized on experiment checkpoint and restore, which is handled here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason for moving this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're correct @justinvyu and I think your solution is much cleaner. For setstate/getstate, it's ok to just restore exactly the state that was saved. If properties are overwritten, that should happen in the function that issues the restore. No need for magic here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@krfricke doesn't this make trials effectively unserializable in the general case unless the trainable is registered? I don't think that's an issue, but perhaps something to consider


trial.refresh_default_resource_request()
trials.append(trial)

# 4. Set trial statuses according to the resume configuration
for trial in sorted(trials, key=lambda t: t.last_update_time, reverse=True):
trial_to_add = trial
if trial.status == Trial.ERROR:
Expand Down Expand Up @@ -1623,6 +1585,9 @@ def __getstate__(self):
"_syncer",
"_callbacks",
"_checkpoint_manager",
"_local_checkpoint_dir",
"_sync_config",
"_experiment_dir_name",
]:
del state[k]
state["launch_web_server"] = bool(self._server)
Expand Down
Loading