From 52019e4b20d66ef683c086fc695d25183f34bdfc Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 14 Dec 2021 17:28:25 +0800 Subject: [PATCH 1/7] [DLMED] add ConfigParser Signed-off-by: Nic Ma --- docs/source/apps.rst | 10 ++++ monai/apps/mmars/__init__.py | 2 + monai/apps/mmars/config_parser.py | 97 +++++++++++++++++++++++++++++++ monai/apps/mmars/utils.py | 61 +++++++++++++++++++ 4 files changed, 170 insertions(+) create mode 100644 monai/apps/mmars/config_parser.py create mode 100644 monai/apps/mmars/utils.py diff --git a/docs/source/apps.rst b/docs/source/apps.rst index f4f7aff2d2..f6c6ecb283 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -29,6 +29,16 @@ Clara MMARs :annotation: +Model Package +------------- + +.. autoclass:: ConfigParser + :members: + +.. autoclass:: ModuleScanner + :members: + + `Utilities` ----------- diff --git a/monai/apps/mmars/__init__.py b/monai/apps/mmars/__init__.py index 396be2e87d..dd93c39d00 100644 --- a/monai/apps/mmars/__init__.py +++ b/monai/apps/mmars/__init__.py @@ -9,5 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .config_parser import ConfigParser, ModuleScanner from .mmars import download_mmar, get_model_spec, load_from_mmar from .model_desc import MODEL_DESC, RemoteMMARKeys +from .utils import get_class, instantiate_class diff --git a/monai/apps/mmars/config_parser.py b/monai/apps/mmars/config_parser.py new file mode 100644 index 0000000000..6bc20e7e8d --- /dev/null +++ b/monai/apps/mmars/config_parser.py @@ -0,0 +1,97 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +import pkgutil +from typing import Dict, Sequence + +from monai.apps.mmars.utils import instantiate_class + + +class ModuleScanner: + """ + Scan all the available classes in the specified packages and modules. + Map the all the class names and the module names in a table. + + Args: + pkgs: the expected packages to scan. + modules: the expected modules in the packages to scan. + + """ + + def __init__(self, pkgs: Sequence[str], modules: Sequence[str]): + self.pkgs = pkgs + self.modules = modules + self._class_table = self._create_classes_table() + + def _create_classes_table(self): + class_table = {} + for pkg in self.pkgs: + package = __import__(pkg) + + for _, modname, _ in pkgutil.walk_packages(path=package.__path__, prefix=package.__name__ + "."): + if modname.startswith(pkg): + if any(name in modname for name in self.modules): + try: + module = importlib.import_module(modname) + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and obj.__module__ == modname: + class_table[name] = modname + except ModuleNotFoundError: + pass + return class_table + + def get_module_name(self, class_name): + return self._class_table.get(class_name, None) + + +class ConfigParser: + """ + Parse dictionary format config and build components. + + Args: + pkgs: the expected packages to scan. + modules: the expected modules in the packages to scan. + + Raises: + ValueError: must provide `path` or `name` of class to build component. + ValueError: can not find component class. + + """ + + def __init__(self, pkgs: Sequence[str], modules: Sequence[str]): + self.module_scanner = ModuleScanner(pkgs=pkgs, modules=modules) + + def build_component(self, config: Dict) -> object: + if not isinstance(config, dict): + raise ValueError("config of component must be a dictionary.") + + if config.get("disabled") is True: + # if marked as `disabled`, skip parsing + return None + + class_args = config.get("args", {}) + class_path = self._get_class_path(config) + return instantiate_class(class_path, **class_args) + + def _get_class_path(self, config): + class_path = config.get("path", None) + if class_path is None: + class_name = config.get("name", None) + if class_name is None: + raise ValueError("must provide `path` or `name` of class to build component.") + module_name = self.module_scanner.get_module_name(class_name) + if module_name is None: + raise ValueError(f"can not find component class '{class_name}'.") + class_path = f"{module_name}.{class_name}" + + return class_path diff --git a/monai/apps/mmars/utils.py b/monai/apps/mmars/utils.py new file mode 100644 index 0000000000..f6c156c9bd --- /dev/null +++ b/monai/apps/mmars/utils.py @@ -0,0 +1,61 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib + + +def get_class(class_path: str): + """ + Get the class from specified class path. + + Args: + class_path (str): full path of the class. + + Raises: + ValueError: invalid class_path, missing the module name. + ValueError: class does not exist. + ValueError: module does not exist. + + """ + if len(class_path.split(".")) < 2: + raise ValueError(f"invalid class_path: {class_path}, missing the module name.") + module_name, class_name = class_path.rsplit(".", 1) + + try: + module_ = importlib.import_module(module_name) + + try: + class_ = getattr(module_, class_name) + except AttributeError as e: + raise ValueError(f"class {class_name} does not exist.") from e + + except AttributeError as e: + raise ValueError(f"module {module_name} does not exist.") from e + + return class_ + + +def instantiate_class(class_path: str, **kwargs): + """ + Method for creating an instance for the specified class. + + Args: + class_path: full path of the class. + kwargs: arguments to initialize the class instance. + + Raises: + ValueError: class has paramenters error. + """ + + try: + return get_class(class_path)(**kwargs) + except TypeError as e: + raise ValueError(f"class {class_path} has parameters error.") from e From 431caf0a3482b0db308eca225729fce2925d1aac Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 14 Dec 2021 18:07:44 +0800 Subject: [PATCH 2/7] [DLMED] add more doc-string Signed-off-by: Nic Ma --- monai/apps/mmars/config_parser.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/monai/apps/mmars/config_parser.py b/monai/apps/mmars/config_parser.py index 6bc20e7e8d..8ede397959 100644 --- a/monai/apps/mmars/config_parser.py +++ b/monai/apps/mmars/config_parser.py @@ -62,16 +62,28 @@ class ConfigParser: pkgs: the expected packages to scan. modules: the expected modules in the packages to scan. - Raises: - ValueError: must provide `path` or `name` of class to build component. - ValueError: can not find component class. - """ def __init__(self, pkgs: Sequence[str], modules: Sequence[str]): self.module_scanner = ModuleScanner(pkgs=pkgs, modules=modules) def build_component(self, config: Dict) -> object: + """ + Build component instance based on the provided dictonary config. + Supported keys for the config: + - 'name' - class name in the modules of packages. + - 'path' - directly specify the class path, based on PYTHONPATH, ignore 'name' if specified. + - 'args' - arguments to initialize the component instance. + - 'disabled' - if defined `'disabled': true`, will skip the buiding, useful for development or tuning. + + Args: + config: dictionary config to define a component. + + Raises: + ValueError: must provide `path` or `name` of class to build component. + ValueError: can not find component class. + + """ if not isinstance(config, dict): raise ValueError("config of component must be a dictionary.") From f55f70262b51cd6d5e3126cf9175f49492eeb5e0 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 20 Dec 2021 14:32:01 +0800 Subject: [PATCH 3/7] [DLMED] add unit tests Signed-off-by: Nic Ma --- monai/apps/__init__.py | 12 +++++++- tests/test_config_parser.py | 59 +++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) create mode 100644 tests/test_config_parser.py diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index 1df6d74f9d..241abac497 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -10,5 +10,15 @@ # limitations under the License. from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset -from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar +from .mmars import ( + MODEL_DESC, + ConfigParser, + ModuleScanner, + RemoteMMARKeys, + download_mmar, + get_class, + get_model_spec, + instantiate_class, + load_from_mmar, +) from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py new file mode 100644 index 0000000000..4b170bd436 --- /dev/null +++ b/tests/test_config_parser.py @@ -0,0 +1,59 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.apps import ConfigParser +from monai.transforms import LoadImaged + +TEST_CASES = [ + # test MONAI components + [ + dict(pkgs=["torch", "monai"], modules=["transforms"]), + {"name": "LoadImaged", "args": {"keys": ["image"]}}, + LoadImaged, + ], + # test non-monai modules + [ + dict(pkgs=["torch", "monai"], modules=["optim"]), + {"name": "Adam", "args": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, + torch.optim.Adam, + ], + # test python `path` + [dict(pkgs=[], modules=[]), {"path": "monai.transforms.LoadImaged", "args": {"keys": ["image"]}}, LoadImaged], + # test `disabled` + [ + dict(pkgs=["torch", "monai"], modules=["transforms"]), + {"name": "LoadImaged", "disabled": True, "args": {"keys": ["image"]}}, + None, + ], +] + + +class TestConfigParser(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_type(self, input_param, test_input, output_type): + configer = ConfigParser(**input_param) + result = configer.build_component(test_input) + if result is not None: + self.assertTrue(isinstance(result, output_type)) + if isinstance(result, LoadImaged): + self.assertEqual(result.keys[0], "image") + else: + # test `disabled` works fine + self.assertEqual(result, output_type) + + +if __name__ == "__main__": + unittest.main() From dc2602cf2003eb295db308b049293c3ed59e6fd1 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 20 Dec 2021 23:22:37 +0800 Subject: [PATCH 4/7] [DLMED] fix CI error Signed-off-by: Nic Ma --- tests/test_config_parser.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index 4b170bd436..70e248dae4 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -20,13 +20,13 @@ TEST_CASES = [ # test MONAI components [ - dict(pkgs=["torch", "monai"], modules=["transforms"]), + dict(pkgs=["torch.nn", "monai"], modules=["transforms"]), {"name": "LoadImaged", "args": {"keys": ["image"]}}, LoadImaged, ], # test non-monai modules [ - dict(pkgs=["torch", "monai"], modules=["optim"]), + dict(pkgs=["torch.optim", "monai"], modules=["adam"]), {"name": "Adam", "args": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, torch.optim.Adam, ], @@ -34,7 +34,7 @@ [dict(pkgs=[], modules=[]), {"path": "monai.transforms.LoadImaged", "args": {"keys": ["image"]}}, LoadImaged], # test `disabled` [ - dict(pkgs=["torch", "monai"], modules=["transforms"]), + dict(pkgs=["torch.utils", "monai"], modules=["transforms"]), {"name": "LoadImaged", "disabled": True, "args": {"keys": ["image"]}}, None, ], From 2d2c4b4593d378b1b69fb934fda2a62e55814f0e Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 20 Dec 2021 23:52:59 +0800 Subject: [PATCH 5/7] [DLMED] fix test error Signed-off-by: Nic Ma --- tests/test_config_parser.py | 56 +++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 24 deletions(-) diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index 70e248dae4..0c217ff764 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -16,44 +16,52 @@ from monai.apps import ConfigParser from monai.transforms import LoadImaged +from tests.utils import skip_if_windows -TEST_CASES = [ - # test MONAI components - [ - dict(pkgs=["torch.nn", "monai"], modules=["transforms"]), - {"name": "LoadImaged", "args": {"keys": ["image"]}}, - LoadImaged, - ], - # test non-monai modules - [ - dict(pkgs=["torch.optim", "monai"], modules=["adam"]), - {"name": "Adam", "args": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, - torch.optim.Adam, - ], - # test python `path` - [dict(pkgs=[], modules=[]), {"path": "monai.transforms.LoadImaged", "args": {"keys": ["image"]}}, LoadImaged], - # test `disabled` - [ - dict(pkgs=["torch.utils", "monai"], modules=["transforms"]), - {"name": "LoadImaged", "disabled": True, "args": {"keys": ["image"]}}, - None, - ], +TEST_CASE_1 = [ + dict(pkgs=["monai"], modules=["transforms"]), + {"name": "LoadImaged", "args": {"keys": ["image"]}}, + LoadImaged, +] +# test python `path` +TEST_CASE_2 = [ + dict(pkgs=[], modules=[]), + {"path": "monai.transforms.LoadImaged", "args": {"keys": ["image"]}}, + LoadImaged, +] +# test `disabled` +TEST_CASE_3 = [ + dict(pkgs=["monai"], modules=["transforms"]), + {"name": "LoadImaged", "disabled": True, "args": {"keys": ["image"]}}, + None, +] +# test non-monai modules +TEST_CASE_4 = [ + dict(pkgs=["torch.optim", "monai"], modules=["adam"]), + {"name": "Adam", "args": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, + torch.optim.Adam, ] class TestConfigParser(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_type(self, input_param, test_input, output_type): configer = ConfigParser(**input_param) result = configer.build_component(test_input) if result is not None: self.assertTrue(isinstance(result, output_type)) - if isinstance(result, LoadImaged): - self.assertEqual(result.keys[0], "image") + self.assertEqual(result.keys[0], "image") else: # test `disabled` works fine self.assertEqual(result, output_type) + @skip_if_windows + @parameterized.expand([TEST_CASE_4]) + def test_non_monai(self, input_param, test_input, output_type): + configer = ConfigParser(**input_param) + result = configer.build_component(test_input) + self.assertTrue(isinstance(result, output_type)) + if __name__ == "__main__": unittest.main() From 1a9bb2d6785ef741f9e5361267be24bbd542267a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 21 Dec 2021 00:11:01 +0800 Subject: [PATCH 6/7] [DLMED] skip for windows Signed-off-by: Nic Ma --- tests/test_config_parser.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index 0c217ff764..9108876b2d 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -55,9 +55,11 @@ def test_type(self, input_param, test_input, output_type): # test `disabled` works fine self.assertEqual(result, output_type) - @skip_if_windows + +@skip_if_windows +class TestConfigParserExternal(unittest.TestCase): @parameterized.expand([TEST_CASE_4]) - def test_non_monai(self, input_param, test_input, output_type): + def test_type(self, input_param, test_input, output_type): configer = ConfigParser(**input_param) result = configer.build_component(test_input) self.assertTrue(isinstance(result, output_type)) From b925f27c8e6a25d8d086e67105ff2f2b83fb611d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 21 Dec 2021 07:42:49 +0800 Subject: [PATCH 7/7] [DLMED] fix windows test Signed-off-by: Nic Ma --- monai/apps/mmars/config_parser.py | 19 +++++++++---------- tests/test_config_parser.py | 15 +++------------ 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/monai/apps/mmars/config_parser.py b/monai/apps/mmars/config_parser.py index 8ede397959..ef69dd837d 100644 --- a/monai/apps/mmars/config_parser.py +++ b/monai/apps/mmars/config_parser.py @@ -36,18 +36,17 @@ def __init__(self, pkgs: Sequence[str], modules: Sequence[str]): def _create_classes_table(self): class_table = {} for pkg in self.pkgs: - package = __import__(pkg) + package = importlib.import_module(pkg) for _, modname, _ in pkgutil.walk_packages(path=package.__path__, prefix=package.__name__ + "."): - if modname.startswith(pkg): - if any(name in modname for name in self.modules): - try: - module = importlib.import_module(modname) - for name, obj in inspect.getmembers(module): - if inspect.isclass(obj) and obj.__module__ == modname: - class_table[name] = modname - except ModuleNotFoundError: - pass + if any(name in modname for name in self.modules): + try: + module = importlib.import_module(modname) + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and obj.__module__ == modname: + class_table[name] = modname + except ModuleNotFoundError: + pass return class_table def get_module_name(self, class_name): diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index 9108876b2d..c8d041d3b9 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -16,7 +16,6 @@ from monai.apps import ConfigParser from monai.transforms import LoadImaged -from tests.utils import skip_if_windows TEST_CASE_1 = [ dict(pkgs=["monai"], modules=["transforms"]), @@ -44,26 +43,18 @@ class TestConfigParser(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_type(self, input_param, test_input, output_type): configer = ConfigParser(**input_param) result = configer.build_component(test_input) if result is not None: self.assertTrue(isinstance(result, output_type)) - self.assertEqual(result.keys[0], "image") + if isinstance(result, LoadImaged): + self.assertEqual(result.keys[0], "image") else: # test `disabled` works fine self.assertEqual(result, output_type) -@skip_if_windows -class TestConfigParserExternal(unittest.TestCase): - @parameterized.expand([TEST_CASE_4]) - def test_type(self, input_param, test_input, output_type): - configer = ConfigParser(**input_param) - result = configer.build_component(test_input) - self.assertTrue(isinstance(result, output_type)) - - if __name__ == "__main__": unittest.main()