Skip to content

Commit

Permalink
--cfg=hydra for read-only top-level config
Browse files Browse the repository at this point in the history
  • Loading branch information
jieru-hu committed Apr 6, 2021
1 parent 8e078bf commit 836c75c
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 11 deletions.
22 changes: 12 additions & 10 deletions hydra/_internal/hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections import defaultdict
from typing import Any, Callable, DefaultDict, List, Optional, Sequence, Type, Union

from omegaconf import Container, DictConfig, OmegaConf, open_dict
from omegaconf import Container, DictConfig, OmegaConf, open_dict, read_write

from hydra._internal.utils import get_column_widths, run_and_report
from hydra.core.config_loader import ConfigLoader
Expand Down Expand Up @@ -136,13 +136,14 @@ def multirun(
@staticmethod
def get_sanitized_hydra_cfg(src_cfg: DictConfig) -> DictConfig:
cfg = copy.deepcopy(src_cfg)
with open_dict(cfg):
for key in list(cfg.keys()):
if key != "hydra":
del cfg[key]
with open_dict(cfg.hydra):
del cfg.hydra["hydra_help"]
del cfg.hydra["help"]
with read_write(cfg):
with open_dict(cfg):
for key in list(cfg.keys()):
if key != "hydra":
del cfg[key]
with open_dict(cfg.hydra):
del cfg.hydra["hydra_help"]
del cfg.hydra["help"]
return cfg

def _get_cfg(
Expand All @@ -160,8 +161,9 @@ def _get_cfg(
with_log_configuration=with_log_configuration,
)
if cfg_type == "job":
with open_dict(cfg):
del cfg["hydra"]
with read_write(cfg):
with open_dict(cfg):
del cfg["hydra"]
elif cfg_type == "hydra":
cfg = self.get_sanitized_hydra_cfg(cfg)
return cfg
Expand Down
42 changes: 41 additions & 1 deletion tests/test_hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,10 +1227,50 @@ def test_hydra_main_without_config_path(tmpdir: Path) -> None:
@hydra.main()
"""
)

assert_regex_match(
from_line=expected,
to_line=err,
from_name="Expected error",
to_name="Actual error",
)


@mark.parametrize(
"overrides,expected",
[
(
["--cfg", "job"],
dedent(
"""\
baud_rate: 19200
data_bits: 8
stop_bits: 1
"""
),
),
(
["--cfg", "hydra", "-p", "hydra.env"],
dedent(
"""\
# @package hydra.env
{}
"""
),
),
],
)
def test_frozen_primary_config(
tmpdir: Path, overrides: List[str], expected: Any
) -> None:
cmd = [
"examples/patterns/write_protect_config_node/frozen.py",
f"hydra.run.dir={tmpdir}",
]
cmd.extend(overrides)
ret, _err = run_python_script(cmd)
assert_regex_match(
from_line=expected,
to_line=ret,
from_name="Expected output",
to_name="Actual output",
)

0 comments on commit 836c75c

Please sign in to comment.