From bc6b40a2790fc7328cb622c4f393ff247531ddb5 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Tue, 14 Feb 2023 14:35:16 -0800 Subject: [PATCH] [Tune] Allow re-specifying param space in `Tuner.restore` (#32317) This PR adds a `Tuner.restore(param_space=...)` argument. This allows object refs to be updated if used in the original run. This is a follow-up to https://github.com/ray-project/ray/pull/31927 Signed-off-by: Justin Yu --- python/ray/tune/impl/tuner_internal.py | 39 ++++++--- python/ray/tune/tests/test_tune_restore.py | 5 +- python/ray/tune/tests/test_tuner.py | 27 +++++++ python/ray/tune/tests/test_tuner_restore.py | 90 +++++++++++++++++++-- python/ray/tune/tune.py | 16 ++++ python/ray/tune/tuner.py | 12 +++ rllib/algorithms/algorithm_config.py | 3 +- 7 files changed, 174 insertions(+), 18 deletions(-) diff --git a/python/ray/tune/impl/tuner_internal.py b/python/ray/tune/impl/tuner_internal.py index 3ee84e3349c2..f3ce25e59b8c 100644 --- a/python/ray/tune/impl/tuner_internal.py +++ b/python/ray/tune/impl/tuner_internal.py @@ -16,6 +16,7 @@ from ray.air.config import RunConfig, ScalingConfig from ray.tune import Experiment, TuneError, ExperimentAnalysis from ray.tune.execution.experiment_state import _ResumeConfig +from ray.tune.tune import _Config from ray.tune.registry import is_function_trainable from ray.tune.result_grid import ResultGrid from ray.tune.trainable import Trainable @@ -99,20 +100,29 @@ def __init__( "Tuner(..., run_config=RunConfig(...))" ) + self.trainable = trainable + param_space = param_space or {} + if isinstance(param_space, _Config): + param_space = param_space.to_dict() + if not isinstance(param_space, dict): + raise ValueError( + "The `param_space` passed to the Tuner` must be a dict. " + f"Got '{type(param_space)}' instead." + ) + self.param_space = param_space + self._tune_config = tune_config or TuneConfig() self._run_config = run_config or RunConfig() self._missing_params_error_message = None - self._param_space = param_space or {} - self._process_scaling_config() - # Restore from Tuner checkpoint. 
if restore_path: self._restore_from_path_or_uri( path_or_uri=restore_path, resume_config=resume_config, overwrite_trainable=trainable, + overwrite_param_space=param_space, ) return @@ -121,7 +131,6 @@ def __init__( raise TuneError("You need to provide a trainable to tune.") self._is_restored = False - self.trainable = trainable self._resume_config = None self._tuner_kwargs = copy.deepcopy(_tuner_kwargs) or {} @@ -301,6 +310,7 @@ def _restore_from_path_or_uri( path_or_uri: str, resume_config: Optional[_ResumeConfig], overwrite_trainable: Optional[TrainableTypeOrTrainer], + overwrite_param_space: Optional[Dict[str, Any]], ): # Sync down from cloud storage if needed synced, experiment_checkpoint_dir = self._maybe_sync_down_tuner_state( @@ -332,6 +342,8 @@ def _restore_from_path_or_uri( self._is_restored = True self.trainable = trainable + if overwrite_param_space: + self.param_space = overwrite_param_space self._resume_config = resume_config if not synced: @@ -435,6 +447,15 @@ def trainable(self, trainable: TrainableTypeOrTrainer): self._trainable = trainable self._converted_trainable = self._convert_trainable(trainable) + @property + def param_space(self) -> Dict[str, Any]: + return self._param_space + + @param_space.setter + def param_space(self, param_space: Dict[str, Any]): + self._param_space = param_space + self._process_scaling_config() + def _convert_trainable(self, trainable: TrainableTypeOrTrainer) -> TrainableType: """Converts an AIR Trainer to a Tune trainable and saves the converted trainable. If not using an AIR Trainer, this leaves the trainable as is.""" @@ -449,7 +470,7 @@ def _convert_trainable(self, trainable: TrainableTypeOrTrainer) -> TrainableType def fit(self) -> ResultGrid: trainable = self.converted_trainable assert self._experiment_checkpoint_dir - param_space = copy.deepcopy(self._param_space) + param_space = copy.deepcopy(self.param_space) if not self._is_restored: analysis = self._fit_internal(trainable, param_space) else: @@ -552,14 +573,14 @@ def _get_tune_run_arguments(self, trainable: TrainableType) -> Dict[str, Any]: ) def _fit_internal( - self, trainable: TrainableType, param_space: Dict[str, Any] + self, trainable: TrainableType, param_space: Optional[Dict[str, Any]] ) -> ExperimentAnalysis: """Fitting for a fresh Tuner.""" args = { **self._get_tune_run_arguments(trainable), **dict( run_or_experiment=trainable, - config={**param_space}, + config=param_space, num_samples=self._tune_config.num_samples, search_alg=self._tune_config.search_alg, scheduler=self._tune_config.scheduler, @@ -575,7 +596,7 @@ def _fit_internal( return analysis def _fit_resume( - self, trainable: TrainableType, param_space: Dict[str, Any] + self, trainable: TrainableType, param_space: Optional[Dict[str, Any]] ) -> ExperimentAnalysis: """Fitting for a restored Tuner.""" if self._missing_params_error_message: @@ -599,7 +620,7 @@ def _fit_resume( **self._get_tune_run_arguments(trainable), **dict( run_or_experiment=trainable, - config={**param_space}, + config=param_space, resume=resume, search_alg=self._tune_config.search_alg, scheduler=self._tune_config.scheduler, diff --git a/python/ray/tune/tests/test_tune_restore.py b/python/ray/tune/tests/test_tune_restore.py index a582688f1435..82fa8e5347df 100644 --- a/python/ray/tune/tests/test_tune_restore.py +++ b/python/ray/tune/tests/test_tune_restore.py @@ -12,6 +12,7 @@ import time from typing import List import unittest +from unittest import mock import ray from ray import tune @@ -314,10 +315,8 @@ def testResourceUpdateInResume(self): ) 
assert len(analysis.trials) == 27 - # Unfinished trials' resources should be updated. + @mock.patch.dict(os.environ, {"TUNE_MAX_PENDING_TRIALS_PG": "1"}) def testConfigUpdateInResume(self): - os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1" - class FakeDataset: def __init__(self, name): self.name = name diff --git a/python/ray/tune/tests/test_tuner.py b/python/ray/tune/tests/test_tuner.py index 843186693f7a..6f3e51f726c5 100644 --- a/python/ray/tune/tests/test_tuner.py +++ b/python/ray/tune/tests/test_tuner.py @@ -531,6 +531,33 @@ def train_func(config): assert artifact_data == f"{result.config['id']}" +def test_invalid_param_space(shutdown_only): + """Check that Tune raises an error on invalid param_space types.""" + + def trainable(config): + return {"metric": 1} + + with pytest.raises(ValueError): + Tuner(trainable, param_space="not allowed") + + from ray.tune.tune import _Config + + class CustomConfig(_Config): + def to_dict(self) -> dict: + return {"hparam": 1} + + with pytest.raises(ValueError): + Tuner(trainable, param_space="not allowed").fit() + + with pytest.raises(ValueError): + tune.run(trainable, config="not allowed") + + # Dict and custom _Config subclasses are fine + Tuner(trainable, param_space={}).fit() + Tuner(trainable, param_space=CustomConfig()).fit() + tune.run(trainable, config=CustomConfig()) + + if __name__ == "__main__": import sys diff --git a/python/ray/tune/tests/test_tuner_restore.py b/python/ray/tune/tests/test_tuner_restore.py index d46f5eb5668a..c62b03c481bd 100644 --- a/python/ray/tune/tests/test_tuner_restore.py +++ b/python/ray/tune/tests/test_tuner_restore.py @@ -96,7 +96,7 @@ def _train_fn_sometimes_failing(config): class _FailOnStats(Callback): """Fail when at least num_trials exist and num_finished have finished.""" - def __init__(self, num_trials: int, num_finished: int, delay: int = 1): + def __init__(self, num_trials: int, num_finished: int = 0, delay: int = 1): self.num_trials = num_trials self.num_finished = num_finished self.delay = delay @@ -574,7 +574,7 @@ def train_func_1(config): with pytest.raises(ValueError): tuner = Tuner.restore( str(tmpdir / "overwrite_trainable"), - overwrite_trainable="__fake", + trainable="__fake", resume_errored=True, ) @@ -586,7 +586,7 @@ def train_func_2(config): with pytest.raises(ValueError): tuner = Tuner.restore( str(tmpdir / "overwrite_trainable"), - overwrite_trainable=train_func_2, + trainable=train_func_2, resume_errored=True, ) @@ -599,7 +599,7 @@ def train_func_1(config): with caplog.at_level(logging.WARNING, logger="ray.tune.impl.tuner_internal"): tuner = Tuner.restore( str(tmpdir / "overwrite_trainable"), - overwrite_trainable=train_func_1, + trainable=train_func_1, resume_errored=True, ) assert "The trainable will be overwritten" in caplog.text @@ -680,7 +680,7 @@ def create_trainable_with_params(): tuner = Tuner.restore( str(tmp_path / exp_name), resume_errored=True, - overwrite_trainable=create_trainable_with_params(), + trainable=create_trainable_with_params(), ) results = tuner.fit() assert not results.errors @@ -1011,6 +1011,86 @@ def test_tuner_can_restore(tmp_path, upload_dir): assert not Tuner.can_restore(tmp_path / "new_exp") +def testParamSpaceOverwrite(tmp_path, monkeypatch): + """Test that overwriting param space on restore propagates new refs to existing + trials and newly generated trials.""" + + # Limit the number of generated trial configs -- so restore tests + # newly generated trials. 
+ monkeypatch.setenv("TUNE_MAX_PENDING_TRIALS_PG", "1") + + class FakeDataset: + def __init__(self, name): + self.name = name + + def __repr__(self): + return f"" + + def train_fn(config): + raise RuntimeError("Failing!") + + param_space = { + "test": tune.grid_search( + [FakeDataset("1"), FakeDataset("2"), FakeDataset("3")] + ), + "test2": tune.grid_search( + [ + FakeDataset("4"), + FakeDataset("5"), + FakeDataset("6"), + FakeDataset("7"), + ] + ), + } + + tuner = Tuner( + train_fn, + param_space=param_space, + tune_config=TuneConfig(num_samples=1), + run_config=RunConfig( + local_dir=str(tmp_path), + name="param_space_overwrite", + callbacks=[_FailOnStats(num_trials=4, num_finished=2)], + ), + ) + with pytest.raises(RuntimeError): + tuner.fit() + + # Just suppress the error this time with a new trainable + def train_fn(config): + pass + + param_space = { + "test": tune.grid_search( + [FakeDataset("8"), FakeDataset("9"), FakeDataset("10")] + ), + "test2": tune.grid_search( + [ + FakeDataset("11"), + FakeDataset("12"), + FakeDataset("13"), + FakeDataset("14"), + ] + ), + } + + tuner = Tuner.restore( + str(tmp_path / "param_space_overwrite"), + trainable=train_fn, + param_space=param_space, + resume_errored=True, + ) + tuner._local_tuner._run_config.callbacks = None + result_grid = tuner.fit() + assert not result_grid.errors + assert len(result_grid) == 12 + + for r in result_grid: + # Make sure that test and test2 are updated. + assert r.config["test"].name in ["8", "9", "10"] + assert r.config["test2"].name in ["11", "12", "13", "14"] + + if __name__ == "__main__": import sys diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 3bafc7d00b92..5684e1a8e70e 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -1,3 +1,4 @@ +import abc import copy import datetime import logging @@ -61,6 +62,7 @@ from ray.util.annotations import PublicAPI from ray.util.queue import Queue + logger = logging.getLogger(__name__) @@ -173,6 +175,12 @@ def signal_interrupt_tune_run(sig: int, frame): return experiment_interrupted_event +class _Config(abc.ABC): + def to_dict(self) -> dict: + """Converts this configuration to a dict format.""" + raise NotImplementedError + + @PublicAPI def run( run_or_experiment: Union[str, Callable, Type], @@ -477,6 +485,14 @@ class and registered trainables. set_verbosity(verbose) config = config or {} + if isinstance(config, _Config): + config = config.to_dict() + if not isinstance(config, dict): + raise ValueError( + "The `config` passed to `tune.run()` must be a dict. " + f"Got '{type(config)}' instead." + ) + sync_config = sync_config or SyncConfig() sync_config.validate_upload_dir() diff --git a/python/ray/tune/tuner.py b/python/ray/tune/tuner.py index c15ade6cb597..2ca1a253cc3b 100644 --- a/python/ray/tune/tuner.py +++ b/python/ray/tune/tuner.py @@ -172,6 +172,7 @@ def restore( overwrite_trainable: Optional[ Union[str, Callable, Type[Trainable], "BaseTrainer"] ] = None, + param_space: Optional[Dict[str, Any]] = None, ) -> "Tuner": """Restores Tuner after a previously failed run. @@ -202,6 +203,15 @@ def restore( This should be the same trainable that was used to initialize the original Tuner. NOTE: Starting in 2.5, this will be a required parameter. + param_space: The same `param_space` that was passed to + the original Tuner. This can be optionally re-specified due + to the `param_space` potentially containing Ray object + references (tuning over Ray Datasets or tuning over + several `ray.put` object references). 
**Tune expects the + `param_space` to be unmodified**, and the only part that + will be used during restore are the updated object references. + Changing the hyperparameter search space then resuming is NOT + supported by this API. resume_unfinished: If True, will continue to run unfinished trials. resume_errored: If True, will re-schedule errored trials and try to restore from their latest checkpoints. @@ -242,6 +252,7 @@ def restore( restore_path=path, resume_config=resume_config, trainable=trainable, + param_space=param_space, ) return Tuner(_tuner_internal=tuner_internal) else: @@ -251,6 +262,7 @@ def restore( restore_path=path, resume_config=resume_config, trainable=trainable, + param_space=param_space, ) return Tuner(_tuner_internal=tuner_internal) diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index 2fc03b08d735..d38af7bbb2fc 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -61,6 +61,7 @@ ResultDict, SampleBatchType, ) +from ray.tune.tune import _Config from ray.tune.logger import Logger from ray.tune.registry import get_trainable_cls from ray.tune.result import TRIAL_INFO @@ -115,7 +116,7 @@ def _resolve_class_path(module) -> Type: return getattr(module, class_name) -class AlgorithmConfig: +class AlgorithmConfig(_Config): """A RLlib AlgorithmConfig builds an RLlib Algorithm from a given configuration. Example:
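
A minimal usage sketch of the workflow this patch enables; the trainable, paths, and experiment name below are illustrative rather than taken from the patch. Re-specifying `param_space` on restore swaps stale Ray object references (e.g. from `ray.put` or Ray Datasets) for live ones, while the structure of the search space itself must stay unchanged:

    import ray
    from ray import tune
    from ray.air.config import RunConfig
    from ray.tune import Tuner


    def train_fn(config):
        # The object ref placed into the search space is delivered in the
        # trial config and resolved here.
        value = ray.get(config["data"])
        return {"metric": value}


    # Original run: the search space contains Ray object references.
    param_space = {"data": tune.grid_search([ray.put(i) for i in range(3)])}
    tuner = Tuner(
        train_fn,
        param_space=param_space,
        run_config=RunConfig(name="my_experiment", local_dir="/tmp/ray_results"),
    )
    tuner.fit()

    # In a new Ray session the old refs are no longer valid, so pass a
    # structurally identical param_space with fresh refs when restoring;
    # only the updated object references are picked up.
    fresh_param_space = {"data": tune.grid_search([ray.put(i) for i in range(3)])}
    restored = Tuner.restore(
        "/tmp/ray_results/my_experiment",
        trainable=train_fn,
        param_space=fresh_param_space,
        resume_errored=True,
    )
    restored.fit()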
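
The `_Config` base class added to `tune.py` is the hook that lets `Tuner` and `tune.run` accept config objects that can render themselves as a dict, which is why `AlgorithmConfig` now subclasses it. A small sketch of that contract, mirroring the `CustomConfig` helper in `test_invalid_param_space` (the `MyConfig` class and its hyperparameters are illustrative):

    from ray import tune
    from ray.tune.tune import _Config  # internal base class introduced by this patch


    class MyConfig(_Config):
        """Hypothetical config object: the only contract is to_dict() -> dict."""

        def to_dict(self) -> dict:
            return {"lr": tune.grid_search([1e-3, 1e-4])}


    def train_fn(config):
        return {"metric": config["lr"]}


    # Both entry points unwrap the object via to_dict() before validating it:
    tune.Tuner(train_fn, param_space=MyConfig()).fit()
    tune.run(train_fn, config=MyConfig())

    # Anything that is neither a dict nor a _Config (e.g. a plain string)
    # now raises a ValueError.

This is the same hook that allows an RLlib `AlgorithmConfig` to be passed directly as `param_space`, since it already implements `to_dict()` and now subclasses `_Config`.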