Skip to content

Commit

Permalink
hydra: Use OmegaConf.to_yaml for dumping .yaml output. (#8587)
Browse files Browse the repository at this point in the history
* hydra: Use `OmegaConf.to_yaml` for dumping `.yaml` output.

Combining `OmegaConf.to_object` with our yaml DUMPER (ruaml.yaml) was not correctly handling strings like 'no'.

`OmegaConf.to_yaml` uses a custom string representer that correctly handles those cases.

Fixes #8583

* Update dvc/utils/hydra.py

Co-authored-by: Dave Berenbaum <dave@iterative.ai>

Co-authored-by: Dave Berenbaum <dave@iterative.ai>
  • Loading branch information
daavoo and Dave Berenbaum authored Nov 19, 2022
1 parent dae500d commit b0a9241
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
8 changes: 6 additions & 2 deletions dvc/utils/hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,12 @@ def compose_and_dump(
with initialize_config_dir(config_dir, version_base=None):
cfg = compose(config_name=config_name, overrides=overrides)

dumper = DUMPERS[Path(output_file).suffix.lower()]
dumper(output_file, OmegaConf.to_object(cfg))
suffix = Path(output_file).suffix.lower()
if suffix not in [".yml", ".yaml"]:
dumper = DUMPERS[suffix]
dumper(output_file, OmegaConf.to_object(cfg))
else:
Path(output_file).write_text(OmegaConf.to_yaml(cfg), encoding="utf-8")


def apply_overrides(path: "StrPath", overrides: List[str]) -> None:
Expand Down
15 changes: 14 additions & 1 deletion tests/func/utils/test_hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,25 @@ def test_compose_and_dump(tmp_dir, suffix, overrides, expected):
from dvc.utils.hydra import compose_and_dump

config_name = "config"
config_dir = hydra_setup(tmp_dir, "conf", "config")
output_file = tmp_dir / f"params.{suffix}"
config_dir = hydra_setup(tmp_dir, "conf", "config")
compose_and_dump(output_file, config_dir, config_name, overrides)
assert output_file.parse() == expected


@pytest.mark.skipif(sys.version_info >= (3, 11), reason="unsupported on 3.11")
def test_compose_and_dump_yaml_handles_string(tmp_dir):
"""Regression test for 8583"""
from dvc.utils.hydra import compose_and_dump

config = tmp_dir / "conf" / "config.yaml"
config.parent.mkdir()
config.write_text("foo: 'no'\n")
output_file = tmp_dir / "params.yaml"
compose_and_dump(output_file, str(config.parent), "config", [])
assert output_file.read_text() == "foo: 'no'\n"


@pytest.mark.parametrize(
"overrides, expected",
[
Expand Down

0 comments on commit b0a9241

Please sign in to comment.