Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

warn or raise ValueError on duplicated key in json/yaml config #6252

Merged
merged 12 commits into from
Mar 29, 2023
29 changes: 28 additions & 1 deletion monai/bundle/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from __future__ import annotations

import json
import logging
import os
import re
from collections.abc import Sequence
from copy import deepcopy
Expand All @@ -33,6 +35,8 @@

_default_globals = {"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"}

# module-level logger; `ConfigParser.check_key_duplicates` uses it to warn about duplicated config keys
logger = logging.getLogger(__name__)


class ConfigParser:
"""
Expand Down Expand Up @@ -383,6 +387,29 @@ def _do_parse(self, config: Any, id: str = "") -> None:
else:
self.ref_resolver.add_item(ConfigItem(config=item_conf, id=id))

@staticmethod
def check_key_duplicates(ordered_pairs: Sequence[tuple[Any, Any]]) -> dict[Any, Any]:
    """
    Build a dict from `ordered_pairs`, reporting any key that occurs more than once.

    Each repeated occurrence of a key is logged as a warning. If the environment
    variable `MONAI_FAIL_ON_DUPLICATE_CONFIG` is set to `"1"`, the first repeated
    key raises a `ValueError` instead of being logged.

    Args:
        ordered_pairs: sequence of (key, value) pairs, in the order they were read.

    Returns:
        `dict(ordered_pairs)`; when a key is repeated, the value of its last
        occurrence is the one kept.

    Raises:
        ValueError: if a key is repeated and `MONAI_FAIL_ON_DUPLICATE_CONFIG` is `"1"`.
    """
    keys = set()
    for k, _ in ordered_pairs:
        if k not in keys:
            keys.add(k)
            continue
        # the environment is consulted only once a duplicate has actually been found
        if os.environ.get("MONAI_FAIL_ON_DUPLICATE_CONFIG", "0") == "1":
            raise ValueError(f"Duplicate key: `{k}`")
        logger.warning(f"Duplicate key: `{k}`")
    return dict(ordered_pairs)

@classmethod
def load_config_file(cls, filepath: PathLike, **kwargs: Any) -> dict:
"""
Expand All @@ -400,7 +427,7 @@ def load_config_file(cls, filepath: PathLike, **kwargs: Any) -> dict:
raise ValueError(f'unknown file input: "{filepath}"')
with open(_filepath) as f:
if _filepath.lower().endswith(cls.suffixes[0]):
return json.load(f, **kwargs) # type: ignore[no-any-return]
return json.load(f, object_pairs_hook=cls.check_key_duplicates, **kwargs) # type: ignore[no-any-return]
if _filepath.lower().endswith(cls.suffixes[1:]):
return yaml.safe_load(f, **kwargs) # type: ignore[no-any-return]
raise ValueError(f"only support JSON or YAML config file so far, got name {_filepath}.")
Expand Down
33 changes: 32 additions & 1 deletion tests/test_config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@

from __future__ import annotations

import logging
import os
import tempfile
import unittest
from unittest import skipUnless
from pathlib import Path
from unittest import mock, skipUnless

import numpy as np
from parameterized import parameterized
Expand Down Expand Up @@ -109,6 +111,8 @@ def __call__(self, a, b):

# [config dict, expected value]; presumably the value expected from
# `get_parsed_content("total")` in `test_substring_reference` - confirm against its decorator
TEST_CASE_5 = [{"training": {"A": 1, "A_B": 2}, "total": "$@training#A + @training#A_B + 1"}, 4]

# [JSON text in which "duplicate" appears twice, expected value of `key#unique`,
#  values accepted for `key#duplicate`] - consumed by the `test_parse_json_*` tests
TEST_CASE_DUPLICATED_KEY = ["""{"key": {"unique": 1, "duplicate": 0, "duplicate": 4 } }""", 1, [0, 4]]


class TestConfigParser(unittest.TestCase):
def test_config_content(self):
Expand Down Expand Up @@ -303,6 +307,33 @@ def test_substring_reference(self, config, expected):
parser = ConfigParser(config=config)
self.assertEqual(parser.get_parsed_content("total"), expected)

@parameterized.expand([TEST_CASE_DUPLICATED_KEY])
@mock.patch.dict(os.environ, {"MONAI_FAIL_ON_DUPLICATE_CONFIG": "1"})
def test_parse_json_raise(self, config_string, _, __):
    """
    Reading a JSON config with a repeated key raises `ValueError` when
    `MONAI_FAIL_ON_DUPLICATE_CONFIG` is patched to `"1"`.

    Args:
        config_string: JSON text written to a temporary `config.json`.
        _: unused (expected unique value from the shared test case).
        __: unused (accepted duplicate values from the shared test case).
    """
    with tempfile.TemporaryDirectory() as tempdir:
        config_path = Path(tempdir) / "config.json"
        config_path.write_text(config_string)
        parser = ConfigParser()

        with self.assertRaises(ValueError) as context:
            parser.read_config(config_path)

        # assertIn reports both operands on failure, unlike assertTrue(x in y)
        self.assertIn("Duplicate key: `duplicate`", str(context.exception))

@parameterized.expand([TEST_CASE_DUPLICATED_KEY])
def test_parse_json_warn(self, config_string, expected_unique_val, expected_duplicate_vals):
    """
    Reading a JSON config with a repeated key logs a warning and still
    produces a usable config.

    Args:
        config_string: JSON text written to a temporary `config.json`.
        expected_unique_val: value expected for `key#unique`.
        expected_duplicate_vals: values accepted for `key#duplicate`.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        config_path = Path(tempdir) / "config.json"
        config_path.write_text(config_string)
        parser = ConfigParser()

        with self.assertLogs(level=logging.WARNING) as log:
            parser.read_config(config_path)
        # assertIn reports both operands on failure, unlike assertTrue(x in y)
        self.assertIn("Duplicate key: `duplicate`", " ".join(log.output))

        self.assertEqual(parser.get_parsed_content("key#unique"), expected_unique_val)
        self.assertIn(parser.get_parsed_content("key#duplicate"), expected_duplicate_vals)


# run the test suite when this file is executed directly
if __name__ == "__main__":
    unittest.main()