[Tune] Allow re-specifying param space in Tuner.restore (ray-project#32317)

This PR adds a `Tuner.restore(param_space=...)` argument. This allows object references used in the original run's `param_space` to be updated when restoring.
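
A minimal usage sketch (hypothetical experiment name, directory, and data, not taken from this diff; the `trainable=` and `param_space=` keyword arguments follow the restore API exercised in the tests below):

    import ray
    from ray import tune
    from ray.air import RunConfig
    from ray.tune import Tuner

    def train_fn(config):
        # The param space carries an object reference; resolve it inside the trainable.
        data = ray.get(config["data"])
        return {"num_rows": len(data)}

    # Original run: the param space contains a Ray object reference.
    tuner = Tuner(
        train_fn,
        param_space={"data": ray.put(list(range(100))), "lr": tune.grid_search([1e-3, 1e-4])},
        run_config=RunConfig(name="my_experiment", local_dir="~/ray_results"),
    )
    tuner.fit()

    # In a new session the old reference is stale, so pass the same param space
    # with a fresh reference when restoring; only the updated refs are picked up.
    tuner = Tuner.restore(
        "~/ray_results/my_experiment",
        trainable=train_fn,
        param_space={"data": ray.put(list(range(100))), "lr": tune.grid_search([1e-3, 1e-4])},
        resume_errored=True,
    )
    tuner.fit()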

This is a follow-up to ray-project#31927

Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
justinvyu authored and edoakes committed Mar 22, 2023
1 parent 33d40c1 commit cd4f620
Showing 7 changed files with 174 additions and 18 deletions.
39 changes: 30 additions & 9 deletions python/ray/tune/impl/tuner_internal.py
@@ -16,6 +16,7 @@
from ray.air.config import RunConfig, ScalingConfig
from ray.tune import Experiment, TuneError, ExperimentAnalysis
from ray.tune.execution.experiment_state import _ResumeConfig
from ray.tune.tune import _Config
from ray.tune.registry import is_function_trainable
from ray.tune.result_grid import ResultGrid
from ray.tune.trainable import Trainable
@@ -99,20 +100,29 @@ def __init__(
"Tuner(..., run_config=RunConfig(...))"
)

self.trainable = trainable
param_space = param_space or {}
if isinstance(param_space, _Config):
param_space = param_space.to_dict()
if not isinstance(param_space, dict):
raise ValueError(
"The `param_space` passed to the Tuner` must be a dict. "
f"Got '{type(param_space)}' instead."
)
self.param_space = param_space

self._tune_config = tune_config or TuneConfig()
self._run_config = run_config or RunConfig()

self._missing_params_error_message = None

self._param_space = param_space or {}
self._process_scaling_config()

# Restore from Tuner checkpoint.
if restore_path:
self._restore_from_path_or_uri(
path_or_uri=restore_path,
resume_config=resume_config,
overwrite_trainable=trainable,
overwrite_param_space=param_space,
)
return

@@ -121,7 +131,6 @@
raise TuneError("You need to provide a trainable to tune.")

self._is_restored = False
self.trainable = trainable
self._resume_config = None

self._tuner_kwargs = copy.deepcopy(_tuner_kwargs) or {}
@@ -301,6 +310,7 @@ def _restore_from_path_or_uri(
path_or_uri: str,
resume_config: Optional[_ResumeConfig],
overwrite_trainable: Optional[TrainableTypeOrTrainer],
overwrite_param_space: Optional[Dict[str, Any]],
):
# Sync down from cloud storage if needed
synced, experiment_checkpoint_dir = self._maybe_sync_down_tuner_state(
@@ -332,6 +342,8 @@

self._is_restored = True
self.trainable = trainable
if overwrite_param_space:
self.param_space = overwrite_param_space
self._resume_config = resume_config

if not synced:
@@ -435,6 +447,15 @@ def trainable(self, trainable: TrainableTypeOrTrainer):
self._trainable = trainable
self._converted_trainable = self._convert_trainable(trainable)

@property
def param_space(self) -> Dict[str, Any]:
return self._param_space

@param_space.setter
def param_space(self, param_space: Dict[str, Any]):
self._param_space = param_space
self._process_scaling_config()

def _convert_trainable(self, trainable: TrainableTypeOrTrainer) -> TrainableType:
"""Converts an AIR Trainer to a Tune trainable and saves the converted
trainable. If not using an AIR Trainer, this leaves the trainable as is."""
@@ -449,7 +470,7 @@ def _convert_trainable(self, trainable: TrainableTypeOrTrainer) -> TrainableType
def fit(self) -> ResultGrid:
trainable = self.converted_trainable
assert self._experiment_checkpoint_dir
param_space = copy.deepcopy(self._param_space)
param_space = copy.deepcopy(self.param_space)
if not self._is_restored:
analysis = self._fit_internal(trainable, param_space)
else:
@@ -552,14 +573,14 @@ def _get_tune_run_arguments(self, trainable: TrainableType) -> Dict[str, Any]:
)

def _fit_internal(
self, trainable: TrainableType, param_space: Dict[str, Any]
self, trainable: TrainableType, param_space: Optional[Dict[str, Any]]
) -> ExperimentAnalysis:
"""Fitting for a fresh Tuner."""
args = {
**self._get_tune_run_arguments(trainable),
**dict(
run_or_experiment=trainable,
config={**param_space},
config=param_space,
num_samples=self._tune_config.num_samples,
search_alg=self._tune_config.search_alg,
scheduler=self._tune_config.scheduler,
@@ -575,7 +596,7 @@ def _fit_resume(
return analysis

def _fit_resume(
self, trainable: TrainableType, param_space: Dict[str, Any]
self, trainable: TrainableType, param_space: Optional[Dict[str, Any]]
) -> ExperimentAnalysis:
"""Fitting for a restored Tuner."""
if self._missing_params_error_message:
@@ -599,7 +620,7 @@
**self._get_tune_run_arguments(trainable),
**dict(
run_or_experiment=trainable,
config={**param_space},
config=param_space,
resume=resume,
search_alg=self._tune_config.search_alg,
scheduler=self._tune_config.scheduler,
5 changes: 2 additions & 3 deletions python/ray/tune/tests/test_tune_restore.py
@@ -12,6 +12,7 @@
import time
from typing import List
import unittest
from unittest import mock

import ray
from ray import tune
@@ -314,10 +315,8 @@ def testResourceUpdateInResume(self):
)
assert len(analysis.trials) == 27

# Unfinished trials' resources should be updated.
@mock.patch.dict(os.environ, {"TUNE_MAX_PENDING_TRIALS_PG": "1"})
def testConfigUpdateInResume(self):
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"

class FakeDataset:
def __init__(self, name):
self.name = name
27 changes: 27 additions & 0 deletions python/ray/tune/tests/test_tuner.py
@@ -531,6 +531,33 @@ def train_func(config):
assert artifact_data == f"{result.config['id']}"


def test_invalid_param_space(shutdown_only):
"""Check that Tune raises an error on invalid param_space types."""

def trainable(config):
return {"metric": 1}

with pytest.raises(ValueError):
Tuner(trainable, param_space="not allowed")

from ray.tune.tune import _Config

class CustomConfig(_Config):
def to_dict(self) -> dict:
return {"hparam": 1}

with pytest.raises(ValueError):
Tuner(trainable, param_space="not allowed").fit()

with pytest.raises(ValueError):
tune.run(trainable, config="not allowed")

# Dict and custom _Config subclasses are fine
Tuner(trainable, param_space={}).fit()
Tuner(trainable, param_space=CustomConfig()).fit()
tune.run(trainable, config=CustomConfig())


if __name__ == "__main__":
import sys

90 changes: 85 additions & 5 deletions python/ray/tune/tests/test_tuner_restore.py
@@ -96,7 +96,7 @@ def _train_fn_sometimes_failing(config):
class _FailOnStats(Callback):
"""Fail when at least num_trials exist and num_finished have finished."""

def __init__(self, num_trials: int, num_finished: int, delay: int = 1):
def __init__(self, num_trials: int, num_finished: int = 0, delay: int = 1):
self.num_trials = num_trials
self.num_finished = num_finished
self.delay = delay
@@ -574,7 +574,7 @@ def train_func_1(config):
with pytest.raises(ValueError):
tuner = Tuner.restore(
str(tmpdir / "overwrite_trainable"),
overwrite_trainable="__fake",
trainable="__fake",
resume_errored=True,
)

@@ -586,7 +586,7 @@
with pytest.raises(ValueError):
tuner = Tuner.restore(
str(tmpdir / "overwrite_trainable"),
overwrite_trainable=train_func_2,
trainable=train_func_2,
resume_errored=True,
)

@@ -599,7 +599,7 @@ def train_func_1(config):
with caplog.at_level(logging.WARNING, logger="ray.tune.impl.tuner_internal"):
tuner = Tuner.restore(
str(tmpdir / "overwrite_trainable"),
overwrite_trainable=train_func_1,
trainable=train_func_1,
resume_errored=True,
)
assert "The trainable will be overwritten" in caplog.text
@@ -680,7 +680,7 @@ def create_trainable_with_params():
tuner = Tuner.restore(
str(tmp_path / exp_name),
resume_errored=True,
overwrite_trainable=create_trainable_with_params(),
trainable=create_trainable_with_params(),
)
results = tuner.fit()
assert not results.errors
@@ -1011,6 +1011,86 @@ def test_tuner_can_restore(tmp_path, upload_dir):
assert not Tuner.can_restore(tmp_path / "new_exp")


def testParamSpaceOverwrite(tmp_path, monkeypatch):
"""Test that overwriting param space on restore propagates new refs to existing
trials and newly generated trials."""

# Limit the number of generated trial configs -- so restore tests
# newly generated trials.
monkeypatch.setenv("TUNE_MAX_PENDING_TRIALS_PG", "1")

class FakeDataset:
def __init__(self, name):
self.name = name

def __repr__(self):
return f"<FakeDataset {self.name}>"

def train_fn(config):
raise RuntimeError("Failing!")

param_space = {
"test": tune.grid_search(
[FakeDataset("1"), FakeDataset("2"), FakeDataset("3")]
),
"test2": tune.grid_search(
[
FakeDataset("4"),
FakeDataset("5"),
FakeDataset("6"),
FakeDataset("7"),
]
),
}

tuner = Tuner(
train_fn,
param_space=param_space,
tune_config=TuneConfig(num_samples=1),
run_config=RunConfig(
local_dir=str(tmp_path),
name="param_space_overwrite",
callbacks=[_FailOnStats(num_trials=4, num_finished=2)],
),
)
with pytest.raises(RuntimeError):
tuner.fit()

# Just suppress the error this time with a new trainable
def train_fn(config):
pass

param_space = {
"test": tune.grid_search(
[FakeDataset("8"), FakeDataset("9"), FakeDataset("10")]
),
"test2": tune.grid_search(
[
FakeDataset("11"),
FakeDataset("12"),
FakeDataset("13"),
FakeDataset("14"),
]
),
}

tuner = Tuner.restore(
str(tmp_path / "param_space_overwrite"),
trainable=train_fn,
param_space=param_space,
resume_errored=True,
)
tuner._local_tuner._run_config.callbacks = None
result_grid = tuner.fit()
assert not result_grid.errors
assert len(result_grid) == 12

for r in result_grid:
# Make sure that test and test2 are updated.
assert r.config["test"].name in ["8", "9", "10"]
assert r.config["test2"].name in ["11", "12", "13", "14"]


if __name__ == "__main__":
import sys

16 changes: 16 additions & 0 deletions python/ray/tune/tune.py
@@ -1,3 +1,4 @@
import abc
import copy
import datetime
import logging
@@ -61,6 +62,7 @@
from ray.util.annotations import PublicAPI
from ray.util.queue import Queue


logger = logging.getLogger(__name__)


@@ -173,6 +175,12 @@ def signal_interrupt_tune_run(sig: int, frame):
return experiment_interrupted_event


class _Config(abc.ABC):
def to_dict(self) -> dict:
"""Converts this configuration to a dict format."""
raise NotImplementedError


@PublicAPI
def run(
run_or_experiment: Union[str, Callable, Type],
@@ -477,6 +485,14 @@ class and registered trainables.
set_verbosity(verbose)

config = config or {}
if isinstance(config, _Config):
config = config.to_dict()
if not isinstance(config, dict):
raise ValueError(
"The `config` passed to `tune.run()` must be a dict. "
f"Got '{type(config)}' instead."
)

sync_config = sync_config or SyncConfig()
sync_config.validate_upload_dir()

12 changes: 12 additions & 0 deletions python/ray/tune/tuner.py
@@ -172,6 +172,7 @@ def restore(
overwrite_trainable: Optional[
Union[str, Callable, Type[Trainable], "BaseTrainer"]
] = None,
param_space: Optional[Dict[str, Any]] = None,
) -> "Tuner":
"""Restores Tuner after a previously failed run.
@@ -202,6 +203,15 @@ def restore(
This should be the same trainable that was used to initialize
the original Tuner.
NOTE: Starting in 2.5, this will be a required parameter.
param_space: The same `param_space` that was passed to
the original Tuner. This can be optionally re-specified due
to the `param_space` potentially containing Ray object
references (tuning over Ray Datasets or tuning over
several `ray.put` object references). **Tune expects the
`param_space` to be unmodified**, and the only parts that
will be used during restore are the updated object references.
Changing the hyperparameter search space then resuming is NOT
supported by this API.
resume_unfinished: If True, will continue to run unfinished trials.
resume_errored: If True, will re-schedule errored trials and try to
restore from their latest checkpoints.
@@ -242,6 +252,7 @@
restore_path=path,
resume_config=resume_config,
trainable=trainable,
param_space=param_space,
)
return Tuner(_tuner_internal=tuner_internal)
else:
@@ -251,6 +262,7 @@
restore_path=path,
resume_config=resume_config,
trainable=trainable,
param_space=param_space,
)
return Tuner(_tuner_internal=tuner_internal)
