@@ -83,22 +83,11 @@ def __init__(
        from ray.train.trainer import BaseTrainer

        if isinstance(trainable, BaseTrainer):
-            # If no run config was passed to the Tuner directly,
-            # use the one from the Trainer, if available
-            if not run_config:
-                run_config = trainable.run_config
-            if run_config and trainable.run_config != RunConfig():
-                logger.info(
-                    "A `RunConfig` was passed to both the `Tuner` and the "
-                    f"`{trainable.__class__.__name__}`. The run config passed to "
-                    "the `Tuner` is the one that will be used."
-                )
-            if param_space and "run_config" in param_space:
-                raise ValueError(
-                    "`RunConfig` cannot be tuned as part of the `param_space`! "
-                    "Move the run config to be a parameter of the `Tuner`: "
-                    "Tuner(..., run_config=RunConfig(...))"
-                )
+            run_config = self._choose_run_config(
+                tuner_run_config=run_config,
+                trainer=trainable,
+                param_space=param_space,
+            )

        self.trainable = trainable
        param_space = param_space or {}
@@ -403,6 +392,49 @@ def _maybe_sync_down_tuner_state(self, restore_path: str) -> Tuple[bool, str]:
        download_from_uri(str(restore_uri / _TUNER_PKL), str(tempdir / _TUNER_PKL))
        return True, str(tempdir)

+    def _choose_run_config(
+        self,
+        tuner_run_config: Optional[RunConfig],
+        trainer: "BaseTrainer",
+        param_space: Optional[Dict[str, Any]],
+    ) -> RunConfig:
+        """Chooses which `RunConfig` to use when multiple can be passed in
+        through a Trainer or the Tuner itself.
+
+        Args:
+            tuner_run_config: The run config passed into the Tuner constructor.
+            trainer: The AIR Trainer instance to use with Tune, which may have
+                a RunConfig specified by the user.
+            param_space: The param space passed to the Tuner.
+
+        Raises:
+            ValueError: If the `run_config` is specified as a hyperparameter.
+        """
+        if param_space and "run_config" in param_space:
+            raise ValueError(
+                "`RunConfig` cannot be tuned as part of the `param_space`! "
+                "Move the run config to be a parameter of the `Tuner`: "
+                "Tuner(..., run_config=RunConfig(...))"
+            )
+
+        # Both Tuner RunConfig + Trainer RunConfig --> prefer Tuner RunConfig
+        if tuner_run_config and trainer.run_config != RunConfig():
+            logger.info(
+                "A `RunConfig` was passed to both the `Tuner` and the "
+                f"`{trainer.__class__.__name__}`. The run config passed to "
+                "the `Tuner` is the one that will be used."
+            )
+            return tuner_run_config
+
+        # No Tuner RunConfig --> pass the Trainer config through.
+        # This returns either a user-specified config, or the default RunConfig
+        # if nothing was provided to either the Trainer or the Tuner.
+        if not tuner_run_config:
+            return trainer.run_config
+
+        # Tuner RunConfig + no Trainer RunConfig --> use the Tuner config
+        return tuner_run_config
+
    def _process_scaling_config(self) -> None:
        """Converts ``self._param_space["scaling_config"]`` to a dict.

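For context, here is a minimal sketch of the precedence that `_choose_run_config` implements, seen from the user's side. This example is not part of the commit; it assumes Ray AIR's `Tuner`, `RunConfig`, and `XGBoostTrainer` APIs, and the dataset and parameter values are made up for illustration.

# Hypothetical usage sketch (not part of this commit); assumes Ray AIR's
# Tuner, RunConfig, and XGBoostTrainer APIs.
import ray
from ray import tune
from ray.air import RunConfig, ScalingConfig
from ray.train.xgboost import XGBoostTrainer

train_ds = ray.data.from_items([{"x": i, "y": i % 2} for i in range(32)])

trainer = XGBoostTrainer(
    scaling_config=ScalingConfig(num_workers=1),
    label_column="y",
    params={"objective": "binary:logistic"},
    datasets={"train": train_ds},
    run_config=RunConfig(name="trainer_config"),  # RunConfig on the Trainer
)

tuner = tune.Tuner(
    trainer,
    # RunConfig on the Tuner takes precedence over the Trainer's.
    run_config=RunConfig(name="tuner_config"),
    param_space={"params": {"max_depth": tune.randint(2, 8)}},
)

# The Tuner logs that both configs were provided and runs under
# name="tuner_config". Putting "run_config" inside `param_space`
# raises a ValueError instead.
tuner.fit()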