diff --git a/dvc/config_schema.py b/dvc/config_schema.py index 9165a8e5d8c..206a8b65247 100644 --- a/dvc/config_schema.py +++ b/dvc/config_schema.py @@ -346,6 +346,7 @@ def __call__(self, data): Exclusive("config_dir", "config_source"): str, Exclusive("config_module", "config_source"): str, "config_name": str, + "plugins_path": str, }, "studio": { "token": str, diff --git a/dvc/repo/experiments/queue/base.py b/dvc/repo/experiments/queue/base.py index 9520652bd92..c77dab5072a 100644 --- a/dvc/repo/experiments/queue/base.py +++ b/dvc/repo/experiments/queue/base.py @@ -486,11 +486,15 @@ def _update_params(self, params: Dict[str, List[str]]): else: config_dir = None config_name = hydra_config.get("config_name", "config") + plugins_path = os.path.join( + self.repo.root_dir, hydra_config.get("plugins_path", "") + ) compose_and_dump( path, config_dir, config_module, config_name, + plugins_path, overrides, ) else: diff --git a/dvc/utils/hydra.py b/dvc/utils/hydra.py index a6a4b936bed..721b0bb1746 100644 --- a/dvc/utils/hydra.py +++ b/dvc/utils/hydra.py @@ -13,11 +13,24 @@ logger = logger.getChild(__name__) +def load_hydra_plugins(plugins_path: str): + import sys + + from hydra.core.plugins import Plugins + + sys.path.append(plugins_path) + try: + Plugins.instance() + finally: + sys.path.remove(plugins_path) + + def compose_and_dump( output_file: "StrPath", config_dir: Optional[str], config_module: Optional[str], config_name: str, + plugins_path: str, overrides: List[str], ) -> None: """Compose Hydra config and dumpt it to `output_file`. @@ -30,6 +43,7 @@ def compose_and_dump( Ignored if `config_dir` is not `None`. config_name: Name of the config file containing defaults, without the .yaml extension. + plugins_path: Path to auto discover Hydra plugins. overrides: List of `Hydra Override`_ patterns. .. _Hydra Override: @@ -47,6 +61,7 @@ def compose_and_dump( initialize_config_dir if config_dir else initialize_config_module ) + load_hydra_plugins(plugins_path) with initialize_config( # type: ignore[attr-defined] config_source, version_base=None ): diff --git a/tests/func/utils/test_hydra.py b/tests/func/utils/test_hydra.py index 6a448e2a934..b2c947c25bb 100644 --- a/tests/func/utils/test_hydra.py +++ b/tests/func/utils/test_hydra.py @@ -176,7 +176,9 @@ def test_compose_and_dump_overrides(tmp_dir, suffix, overrides, expected): output_file = tmp_dir / f"params.{suffix}" config_dir = hydra_setup(tmp_dir, "conf", "config") config_module = None - compose_and_dump(output_file, config_dir, config_module, config_name, overrides) + compose_and_dump( + output_file, config_dir, config_module, config_name, str(tmp_dir), overrides + ) assert output_file.parse() == expected @@ -229,7 +231,9 @@ def test_compose_and_dump_dir_module( ) with error_context: - compose_and_dump(output_file, config_dir, config_module, config_name, []) + compose_and_dump( + output_file, config_dir, config_module, config_name, str(tmp_dir), [] + ) assert output_file.parse() == config_content @@ -241,7 +245,7 @@ def test_compose_and_dump_yaml_handles_string(tmp_dir): config.parent.mkdir() config.write_text("foo: 'no'\n") output_file = tmp_dir / "params.yaml" - compose_and_dump(output_file, str(config.parent), None, "config", []) + compose_and_dump(output_file, str(config.parent), None, "config", str(tmp_dir), []) assert output_file.read_text() == "foo: 'no'\n" @@ -253,12 +257,38 @@ def test_compose_and_dump_resolves_interpolation(tmp_dir): config.parent.mkdir() config.dump({"data": {"root": "path/to/root", "raw": "${.root}/raw"}}) output_file = tmp_dir / "params.yaml" - compose_and_dump(output_file, str(config.parent), None, "config", []) + compose_and_dump(output_file, str(config.parent), None, "config", str(tmp_dir), []) assert output_file.parse() == { "data": {"root": "path/to/root", "raw": "path/to/root/raw"} } +def test_compose_and_dump_plugins(tmp_dir): + """Ensure Hydra plugins are loaded.""" + from hydra.core.plugins import Plugins + + from dvc.utils.hydra import compose_and_dump + + # clear cached plugins + Plugins._instances.pop(Plugins, None) + + config = tmp_dir / "conf" / "config.yaml" + config.parent.mkdir() + config.write_text("foo: '${plus_10:1}'\n") + + plugins = tmp_dir / "hydra_plugins" + plugins.mkdir() + (plugins / "resolver.py").write_text( + """\ +from omegaconf import OmegaConf +OmegaConf.register_new_resolver('plus_10', lambda x: x + 10)""" + ) + + output_file = tmp_dir / "params.yaml" + compose_and_dump(output_file, str(config.parent), None, "config", str(tmp_dir), []) + assert output_file.read_text() == "foo: 11\n" + + @pytest.mark.parametrize( "overrides, expected", [