Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3482 Build component instance from dictionary config #3518

Merged
merged 14 commits into from
Jan 6, 2022
10 changes: 10 additions & 0 deletions docs/source/apps.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ Clara MMARs
:annotation:


Model Package
-------------

.. autoclass:: ConfigParser
:members:

.. autoclass:: ModuleScanner
:members:


`Utilities`
-----------

Expand Down
12 changes: 11 additions & 1 deletion monai/apps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,15 @@
# limitations under the License.

from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset
from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar
from .mmars import (
MODEL_DESC,
ConfigParser,
ModuleScanner,
RemoteMMARKeys,
download_mmar,
get_class,
get_model_spec,
instantiate_class,
load_from_mmar,
)
from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger
2 changes: 2 additions & 0 deletions monai/apps/mmars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .config_parser import ConfigParser, ModuleScanner
from .mmars import download_mmar, get_model_spec, load_from_mmar
from .model_desc import MODEL_DESC, RemoteMMARKeys
from .utils import get_class, instantiate_class
108 changes: 108 additions & 0 deletions monai/apps/mmars/config_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import inspect
import pkgutil
from typing import Dict, Sequence

from monai.apps.mmars.utils import instantiate_class


class ModuleScanner:
"""
Scan all the available classes in the specified packages and modules.
Map the all the class names and the module names in a table.

Args:
pkgs: the expected packages to scan.
modules: the expected modules in the packages to scan.

"""

def __init__(self, pkgs: Sequence[str], modules: Sequence[str]):
self.pkgs = pkgs
self.modules = modules
self._class_table = self._create_classes_table()

def _create_classes_table(self):
class_table = {}
for pkg in self.pkgs:
package = importlib.import_module(pkg)

for _, modname, _ in pkgutil.walk_packages(path=package.__path__, prefix=package.__name__ + "."):
if any(name in modname for name in self.modules):
try:
module = importlib.import_module(modname)
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj) and obj.__module__ == modname:
class_table[name] = modname
except ModuleNotFoundError:
pass
return class_table

def get_module_name(self, class_name):
return self._class_table.get(class_name, None)


class ConfigParser:
"""
Parse dictionary format config and build components.

Args:
pkgs: the expected packages to scan.
modules: the expected modules in the packages to scan.

"""

def __init__(self, pkgs: Sequence[str], modules: Sequence[str]):
self.module_scanner = ModuleScanner(pkgs=pkgs, modules=modules)

def build_component(self, config: Dict) -> object:
"""
Build component instance based on the provided dictonary config.
Supported keys for the config:
- 'name' - class name in the modules of packages.
- 'path' - directly specify the class path, based on PYTHONPATH, ignore 'name' if specified.
- 'args' - arguments to initialize the component instance.
- 'disabled' - if defined `'disabled': true`, will skip the buiding, useful for development or tuning.

Args:
config: dictionary config to define a component.

Raises:
ValueError: must provide `path` or `name` of class to build component.
ValueError: can not find component class.

"""
if not isinstance(config, dict):
raise ValueError("config of component must be a dictionary.")

if config.get("disabled") is True:
# if marked as `disabled`, skip parsing
return None

class_args = config.get("args", {})
class_path = self._get_class_path(config)
return instantiate_class(class_path, **class_args)

def _get_class_path(self, config):
class_path = config.get("path", None)
if class_path is None:
class_name = config.get("name", None)
if class_name is None:
raise ValueError("must provide `path` or `name` of class to build component.")
module_name = self.module_scanner.get_module_name(class_name)
if module_name is None:
raise ValueError(f"can not find component class '{class_name}'.")
class_path = f"{module_name}.{class_name}"

return class_path
61 changes: 61 additions & 0 deletions monai/apps/mmars/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib


def get_class(class_path: str):
"""
Get the class from specified class path.

Args:
class_path (str): full path of the class.

Raises:
ValueError: invalid class_path, missing the module name.
ValueError: class does not exist.
ValueError: module does not exist.

"""
if len(class_path.split(".")) < 2:
raise ValueError(f"invalid class_path: {class_path}, missing the module name.")
module_name, class_name = class_path.rsplit(".", 1)

try:
module_ = importlib.import_module(module_name)

try:
class_ = getattr(module_, class_name)
except AttributeError as e:
raise ValueError(f"class {class_name} does not exist.") from e

except AttributeError as e:
raise ValueError(f"module {module_name} does not exist.") from e

return class_


def instantiate_class(class_path: str, **kwargs):
"""
Method for creating an instance for the specified class.

Args:
class_path: full path of the class.
kwargs: arguments to initialize the class instance.

Raises:
ValueError: class has paramenters error.
"""

try:
return get_class(class_path)(**kwargs)
except TypeError as e:
raise ValueError(f"class {class_path} has parameters error.") from e
60 changes: 60 additions & 0 deletions tests/test_config_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.apps import ConfigParser
from monai.transforms import LoadImaged

TEST_CASE_1 = [
dict(pkgs=["monai"], modules=["transforms"]),
{"name": "LoadImaged", "args": {"keys": ["image"]}},
LoadImaged,
]
# test python `path`
TEST_CASE_2 = [
dict(pkgs=[], modules=[]),
{"path": "monai.transforms.LoadImaged", "args": {"keys": ["image"]}},
LoadImaged,
]
# test `disabled`
TEST_CASE_3 = [
dict(pkgs=["monai"], modules=["transforms"]),
{"name": "LoadImaged", "disabled": True, "args": {"keys": ["image"]}},
None,
]
# test non-monai modules
TEST_CASE_4 = [
dict(pkgs=["torch.optim", "monai"], modules=["adam"]),
{"name": "Adam", "args": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}},
torch.optim.Adam,
]


class TestConfigParser(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_type(self, input_param, test_input, output_type):
configer = ConfigParser(**input_param)
result = configer.build_component(test_input)
if result is not None:
self.assertTrue(isinstance(result, output_type))
if isinstance(result, LoadImaged):
self.assertEqual(result.keys[0], "image")
else:
# test `disabled` works fine
self.assertEqual(result, output_type)


if __name__ == "__main__":
unittest.main()