
Commit 25ad8f9

[Tune] Log the choice of RunConfig only when passed into both trainer + tuner (#33454)
Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
1 parent 636c510 commit 25ad8f9

File tree

2 files changed: +74 -23 lines


python/ray/train/tests/test_tune.py

Lines changed: 26 additions & 7 deletions
@@ -23,7 +23,7 @@
 from ray.tune.impl.tuner_internal import _TUNER_PKL


-@pytest.fixture
+@pytest.fixture(scope="module")
 def ray_start_4_cpus():
     address_info = ray.init(num_cpus=4)
     yield address_info
@@ -272,22 +272,41 @@ def train_func(config):
     assert not results.errors


+@pytest.mark.parametrize("in_trainer", [True, False])
+@pytest.mark.parametrize("in_tuner", [True, False])
 def test_run_config_in_trainer_and_tuner(
-    ray_start_4_cpus, tmp_path, propagate_logs, caplog
+    propagate_logs, tmp_path, caplog, in_trainer, in_tuner
 ):
+    trainer_run_config = (
+        RunConfig(name="trainer", local_dir=str(tmp_path)) if in_trainer else None
+    )
+    tuner_run_config = (
+        RunConfig(name="tuner", local_dir=str(tmp_path)) if in_tuner else None
+    )
     trainer = DataParallelTrainer(
         lambda config: None,
         backend_config=TestConfig(),
         scaling_config=ScalingConfig(num_workers=1),
-        run_config=RunConfig(name="ignored", local_dir="ignored"),
+        run_config=trainer_run_config,
     )
     with caplog.at_level(logging.INFO, logger="ray.tune.impl.tuner_internal"):
-        Tuner(trainer, run_config=RunConfig(name="used", local_dir=str(tmp_path)))
-    assert list((tmp_path / "used").glob(_TUNER_PKL))
-    assert (
+        tuner = Tuner(trainer, run_config=tuner_run_config)
+
+    both_msg = (
         "`RunConfig` was passed to both the `Tuner` and the `DataParallelTrainer`"
-        in caplog.text
     )
+    if in_trainer and in_tuner:
+        assert list((tmp_path / "tuner").glob(_TUNER_PKL))
+        assert both_msg in caplog.text
+    elif in_trainer and not in_tuner:
+        assert list((tmp_path / "trainer").glob(_TUNER_PKL))
+        assert both_msg not in caplog.text
+    elif not in_trainer and in_tuner:
+        assert list((tmp_path / "tuner").glob(_TUNER_PKL))
+        assert both_msg not in caplog.text
+    else:
+        assert tuner._local_tuner.get_run_config() == RunConfig()
+        assert both_msg not in caplog.text


 def test_run_config_in_param_space():
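
Note: the two `pytest.mark.parametrize` markers above expand the rewritten test into four (in_trainer, in_tuner) combinations. As a sketch, the new test can be selected on its own from a Ray checkout (the invocation below is illustrative and assumes pytest is installed):

import pytest

# Run only the parametrized test added in this commit; "-v" reports each
# of the four (in_trainer, in_tuner) cases as a separate test.
pytest.main(
    [
        "python/ray/train/tests/test_tune.py",
        "-v",
        "-k",
        "test_run_config_in_trainer_and_tuner",
    ]
)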

python/ray/tune/impl/tuner_internal.py

Lines changed: 48 additions & 16 deletions
@@ -83,22 +83,11 @@ def __init__(
         from ray.train.trainer import BaseTrainer

         if isinstance(trainable, BaseTrainer):
-            # If no run config was passed to the Tuner directly,
-            # use the one from the Trainer, if available
-            if not run_config:
-                run_config = trainable.run_config
-            if run_config and trainable.run_config != RunConfig():
-                logger.info(
-                    "A `RunConfig` was passed to both the `Tuner` and the "
-                    f"`{trainable.__class__.__name__}`. The run config passed to "
-                    "the `Tuner` is the one that will be used."
-                )
-            if param_space and "run_config" in param_space:
-                raise ValueError(
-                    "`RunConfig` cannot be tuned as part of the `param_space`! "
-                    "Move the run config to be a parameter of the `Tuner`: "
-                    "Tuner(..., run_config=RunConfig(...))"
-                )
+            run_config = self._choose_run_config(
+                tuner_run_config=run_config,
+                trainer=trainable,
+                param_space=param_space,
+            )

         self.trainable = trainable
         param_space = param_space or {}
@@ -403,6 +392,49 @@ def _maybe_sync_down_tuner_state(self, restore_path: str) -> Tuple[bool, str]:
         download_from_uri(str(restore_uri / _TUNER_PKL), str(tempdir / _TUNER_PKL))
         return True, str(tempdir)

+    def _choose_run_config(
+        self,
+        tuner_run_config: Optional[RunConfig],
+        trainer: "BaseTrainer",
+        param_space: Optional[Dict[str, Any]],
+    ) -> RunConfig:
+        """Chooses which `RunConfig` to use when multiple can be passed in
+        through a Trainer or the Tuner itself.
+
+        Args:
+            tuner_run_config: The run config passed into the Tuner constructor.
+            trainer: The AIR Trainer instance to use with Tune, which may have
+                a RunConfig specified by the user.
+            param_space: The param space passed to the Tuner.
+
+        Raises:
+            ValueError: if the `run_config` is specified as a hyperparameter.
+        """
+        if param_space and "run_config" in param_space:
+            raise ValueError(
+                "`RunConfig` cannot be tuned as part of the `param_space`! "
+                "Move the run config to be a parameter of the `Tuner`: "
+                "Tuner(..., run_config=RunConfig(...))"
+            )
+
+        # Both Tuner RunConfig + Trainer RunConfig --> prefer the Tuner RunConfig
+        if tuner_run_config and trainer.run_config != RunConfig():
+            logger.info(
+                "A `RunConfig` was passed to both the `Tuner` and the "
+                f"`{trainer.__class__.__name__}`. The run config passed to "
+                "the `Tuner` is the one that will be used."
+            )
+            return tuner_run_config
+
+        # No Tuner RunConfig --> pass the Trainer config through.
+        # This returns either a user-specified config, or the default RunConfig
+        # if nothing was provided to either the Trainer or the Tuner.
+        if not tuner_run_config:
+            return trainer.run_config
+
+        # Tuner RunConfig + no Trainer RunConfig --> use the Tuner config
+        return tuner_run_config
+
     def _process_scaling_config(self) -> None:
         """Converts ``self._param_space["scaling_config"]`` to a dict.
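
Taken together, `_choose_run_config` gives the Tuner's `RunConfig` precedence whenever one is passed, falls back to the Trainer's otherwise, and emits the info log only when both were provided. A minimal sketch of the resulting behavior (`make_trainer` is an illustrative helper, not part of this diff; assumes the Ray AIR APIs as of this commit):

from ray.air import RunConfig, ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.tune import Tuner


def make_trainer(run_config=None):
    # Illustrative no-op trainer; not part of this commit.
    return DataParallelTrainer(
        lambda config: None,
        scaling_config=ScalingConfig(num_workers=1),
        run_config=run_config,
    )


# Both configs set: the Tuner's RunConfig wins, and constructing the Tuner
# logs "A `RunConfig` was passed to both the `Tuner` and the
# `DataParallelTrainer`. ..." at INFO level.
Tuner(make_trainer(RunConfig(name="ignored")), run_config=RunConfig(name="used"))

# Only one side set: that config is used, and (after this commit) no log
# message is emitted for either of these constructions.
Tuner(make_trainer(RunConfig(name="trainer_only")))
Tuner(make_trainer(), run_config=RunConfig(name="tuner_only"))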
