warn or raise ValueError on duplicated key in json/yaml config #6252

Merged (12 commits) on Mar 29, 2023
5 changes: 3 additions & 2 deletions monai/bundle/config_parser.py
@@ -23,6 +23,7 @@
from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY
from monai.config import PathLike
from monai.utils import ensure_tuple, look_up_option, optional_import
from monai.utils.misc import CheckKeyDuplicatesYamlLoader, check_key_duplicates

if TYPE_CHECKING:
import yaml
@@ -400,9 +401,9 @@ def load_config_file(cls, filepath: PathLike, **kwargs: Any) -> dict:
raise ValueError(f'unknown file input: "{filepath}"')
with open(_filepath) as f:
if _filepath.lower().endswith(cls.suffixes[0]):
return json.load(f, **kwargs) # type: ignore[no-any-return]
return json.load(f, object_pairs_hook=check_key_duplicates, **kwargs) # type: ignore[no-any-return]
if _filepath.lower().endswith(cls.suffixes[1:]):
return yaml.safe_load(f, **kwargs) # type: ignore[no-any-return]
return yaml.load(f, CheckKeyDuplicatesYamlLoader, **kwargs) # type: ignore[no-any-return]
raise ValueError(f"only support JSON or YAML config file so far, got name {_filepath}.")

@classmethod
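With this change, `ConfigParser.load_config_file` reports duplicated keys instead of silently dropping earlier values. A minimal sketch of the resulting behaviour, assuming a standard MONAI install with pyyaml (the file name and values here are illustrative):

import os
import warnings

from monai.bundle import ConfigParser

# write a JSON config containing a duplicated key
with open("config.json", "w") as f:
    f.write('{"net": {"lr": 0.1, "lr": 0.01}}')

parser = ConfigParser()

# default mode: a warning is issued and the last value wins
with warnings.catch_warnings(record=True) as w:
    parser.read_config("config.json")
print(w[-1].message)                        # Duplicate key: `lr`
print(parser.get_parsed_content("net#lr"))  # 0.01

# strict mode: the same duplicate raises ValueError
os.environ["MONAI_FAIL_ON_DUPLICATE_CONFIG"] = "1"
parser.read_config("config.json")           # raises ValueError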
48 changes: 46 additions & 2 deletions monai/utils/misc.py
@@ -24,13 +24,18 @@
from collections.abc import Callable, Iterable, Sequence
from distutils.util import strtobool
from pathlib import Path
from typing import Any, TypeVar, cast, overload
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload

import numpy as np
import torch

from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike
from monai.utils.module import version_leq
from monai.utils.module import optional_import, version_leq

if TYPE_CHECKING:
from yaml import SafeLoader
else:
SafeLoader, _ = optional_import("yaml", name="SafeLoader", as_type="base")

__all__ = [
"zip_with",
@@ -679,3 +684,42 @@ def pprint_edges(val: Any, n_lines: int = 20) -> str:
hidden_n = len(val_str) - n_lines * 2
val_str = val_str[:n_lines] + [f"\n ... omitted {hidden_n} line(s)\n\n"] + val_str[-n_lines:]
return "".join(val_str)


def check_key_duplicates(ordered_pairs: Sequence[tuple[Any, Any]]) -> dict[Any, Any]:
"""
Checks if there is a duplicated key in the sequence of `ordered_pairs`.
If there is - it will log a warning or raise ValueError
(if configured by environmental var `MONAI_FAIL_ON_DUPLICATE_CONFIG==1`)

Otherwise, it returns the dict made from this sequence.

Satisfies a format for an `object_pairs_hook` in `json.load`

Args:
ordered_pairs: sequence of (key, value)
"""
keys = set()
for k, _ in ordered_pairs:
if k in keys:
if os.environ.get("MONAI_FAIL_ON_DUPLICATE_CONFIG", "0") == "1":
raise ValueError(f"Duplicate key: `{k}`")
else:
warnings.warn(f"Duplicate key: `{k}`")
else:
keys.add(k)
return dict(ordered_pairs)
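As a quick illustration of the hook contract (not part of the diff; the literal is illustrative), the function plugs directly into `json.loads`:

import json

from monai.utils.misc import check_key_duplicates

# the duplicated "a" triggers a warning (or ValueError when MONAI_FAIL_ON_DUPLICATE_CONFIG=1);
# as with plain json.loads, the resulting dict keeps the last value
cfg = json.loads('{"a": 1, "a": 2}', object_pairs_hook=check_key_duplicates)
assert cfg == {"a": 2}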


class CheckKeyDuplicatesYamlLoader(SafeLoader):
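"""`yaml.SafeLoader` subclass that warns on duplicated mapping keys, or raises ValueError when `MONAI_FAIL_ON_DUPLICATE_CONFIG==1`."""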
def construct_mapping(self, node, deep=False):
mapping = set()
for key_node, _ in node.value:
key = self.construct_object(key_node, deep=deep)
if key in mapping:
if os.environ.get("MONAI_FAIL_ON_DUPLICATE_CONFIG", "0") == "1":
raise ValueError(f"Duplicate key: `{key}`")
else:
warnings.warn(f"Duplicate key: `{key}`")
mapping.add(key)
return super().construct_mapping(node, deep)
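A matching sketch for the YAML side (again illustrative, assuming pyyaml is installed):

import yaml

from monai.utils.misc import CheckKeyDuplicatesYamlLoader

# the duplicated "b" is reported while the document still loads;
# PyYAML keeps the last occurrence of a duplicated key
data = yaml.load("a: 1\nb: 2\nb: 3\n", CheckKeyDuplicatesYamlLoader)
assert data == {"a": 1, "b": 3}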
47 changes: 46 additions & 1 deletion tests/test_config_parser.py
@@ -14,7 +14,9 @@
import os
import tempfile
import unittest
from unittest import skipUnless
import warnings
from pathlib import Path
from unittest import mock, skipUnless

import numpy as np
from parameterized import parameterized
@@ -27,6 +29,7 @@
from tests.utils import TimedCall

_, has_tv = optional_import("torchvision", "0.8.0", min_version)
_, has_yaml = optional_import("yaml")


@TimedCall(seconds=100, force_quit=True)
@@ -109,6 +112,18 @@ def __call__(self, a, b):

TEST_CASE_5 = [{"training": {"A": 1, "A_B": 2}, "total": "$@training#A + @training#A_B + 1"}, 4]

TEST_CASE_DUPLICATED_KEY_JSON = ["""{"key": {"unique": 1, "duplicate": 0, "duplicate": 4 } }""", "json", 1, [0, 4]]

TEST_CASE_DUPLICATED_KEY_YAML = [
"""key:
unique: 1
duplicate: 0
duplicate: 4""",
"yaml",
1,
[0, 4],
]


class TestConfigParser(unittest.TestCase):
def test_config_content(self):
@@ -303,6 +318,36 @@ def test_substring_reference(self, config, expected):
parser = ConfigParser(config=config)
self.assertEqual(parser.get_parsed_content("total"), expected)

@parameterized.expand([TEST_CASE_DUPLICATED_KEY_JSON, TEST_CASE_DUPLICATED_KEY_YAML])
@mock.patch.dict(os.environ, {"MONAI_FAIL_ON_DUPLICATE_CONFIG": "1"})
@skipUnless(has_yaml, "Requires pyyaml")
def test_parse_json_raise(self, config_string, extension, _, __):
with tempfile.TemporaryDirectory() as tempdir:
config_path = Path(tempdir) / f"config.{extension}"
config_path.write_text(config_string)
parser = ConfigParser()

with self.assertRaises(ValueError) as context:
parser.read_config(config_path)

self.assertTrue("Duplicate key: `duplicate`" in str(context.exception))

@parameterized.expand([TEST_CASE_DUPLICATED_KEY_JSON, TEST_CASE_DUPLICATED_KEY_YAML])
@skipUnless(has_yaml, "Requires pyyaml")
def test_parse_json_warn(self, config_string, extension, expected_unique_val, expected_duplicate_vals):
with tempfile.TemporaryDirectory() as tempdir:
config_path = Path(tempdir) / f"config.{extension}"
config_path.write_text(config_string)
parser = ConfigParser()

with warnings.catch_warnings(record=True) as w:
parser.read_config(config_path)
self.assertEqual(len(w), 1)
self.assertTrue("Duplicate key: `duplicate`" in str(w[-1].message))

self.assertEqual(parser.get_parsed_content("key#unique"), expected_unique_val)
self.assertIn(parser.get_parsed_content("key#duplicate"), expected_duplicate_vals)


if __name__ == "__main__":
unittest.main()