diff --git a/docs/source/api.rst b/docs/source/api.rst index 0596a25514..c2b19adeb2 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -7,6 +7,7 @@ API Reference :maxdepth: 1 apps + bundle transforms losses networks diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 4b1cdc6f43..239ae9eb17 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -19,8 +19,8 @@ Applications :members: -Clara MMARs ------------ +`Clara MMARs` +------------- .. autofunction:: download_mmar .. autofunction:: load_from_mmar @@ -29,25 +29,6 @@ Clara MMARs :annotation: -Model Manifest --------------- - -.. autoclass:: ComponentLocator - :members: - -.. autoclass:: ConfigComponent - :members: - -.. autoclass:: ConfigExpression - :members: - -.. autoclass:: ConfigItem - :members: - -.. autoclass:: ReferenceResolver - :members: - - `Utilities` ----------- diff --git a/docs/source/bundle.rst b/docs/source/bundle.rst new file mode 100644 index 0000000000..03d4e07d17 --- /dev/null +++ b/docs/source/bundle.rst @@ -0,0 +1,34 @@ +:github_url: https://github.com/Project-MONAI/MONAI + +.. _bundle: + +Model Bundle +============ +.. currentmodule:: monai.bundle + +`Config Item` +------------- +.. autoclass:: Instantiable + :members: + +.. autoclass:: ComponentLocator + :members: + +.. autoclass:: ConfigComponent + :members: + +.. autoclass:: ConfigExpression + :members: + +.. autoclass:: ConfigItem + :members: + +`Reference Resolver` +-------------------- +.. autoclass:: ReferenceResolver + :members: + +`Config Parser` +--------------- +.. autoclass:: ConfigParser + :members: diff --git a/docs/source/conf.py b/docs/source/conf.py index 8af3fe8b75..db0ca11be3 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -40,6 +40,7 @@ "engines", "data", "apps", + "bundle", "config", "handlers", "losses", diff --git a/monai/README.md b/monai/README.md index a224996f38..2c30531bf3 100644 --- a/monai/README.md +++ b/monai/README.md @@ -2,6 +2,8 @@ * **apps**: high level medical domain specific deep learning applications. +* **bundle**: components to build the portable self-descriptive model bundle. + * **config**: for system configuration and diagnostic output. * **csrc**: for C++/CUDA extensions. diff --git a/monai/__init__.py b/monai/__init__.py index 68a232b46d..a823a3e1e2 100644 --- a/monai/__init__.py +++ b/monai/__init__.py @@ -49,6 +49,7 @@ __all__ = [ "apps", + "bundle", "config", "data", "engines", diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index 0f233bc3ef..893f7877d2 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -10,6 +10,5 @@ # limitations under the License. from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset -from .manifest import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, ReferenceResolver from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/manifest/__init__.py b/monai/bundle/__init__.py similarity index 89% rename from monai/apps/manifest/__init__.py rename to monai/bundle/__init__.py index 79c4376d5c..68e2d543bb 100644 --- a/monai/apps/manifest/__init__.py +++ b/monai/bundle/__init__.py @@ -9,5 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem +from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable +from .config_parser import ConfigParser from .reference_resolver import ReferenceResolver diff --git a/monai/apps/manifest/config_item.py b/monai/bundle/config_item.py similarity index 95% rename from monai/apps/manifest/config_item.py rename to monai/bundle/config_item.py index 075d00b961..44cdd3c634 100644 --- a/monai/apps/manifest/config_item.py +++ b/monai/bundle/config_item.py @@ -10,6 +10,7 @@ # limitations under the License. import inspect +import os import sys import warnings from abc import ABC, abstractmethod @@ -128,7 +129,7 @@ def __init__(self, config: Any, id: str = "") -> None: self.config = config self.id = id - def get_id(self) -> Optional[str]: + def get_id(self) -> str: """ Get the ID name of current config item, useful to identify config items during parsing. @@ -153,6 +154,9 @@ def get_config(self): """ return self.config + def __repr__(self) -> str: + return str(self.config) + class ConfigComponent(ConfigItem, Instantiable): """ @@ -187,7 +191,7 @@ class ConfigComponent(ConfigItem, Instantiable): locator: a ``ComponentLocator`` to convert a module name string into the actual python module. if `None`, a ``ComponentLocator(excludes=excludes)`` will be used. excludes: if ``locator`` is None, create a new ``ComponentLocator`` with ``excludes``. - See also: :py:class:`monai.apps.manifest.ComponentLocator`. + See also: :py:class:`monai.bundle.ComponentLocator`. """ @@ -291,7 +295,7 @@ class ConfigExpression(ConfigItem): .. code-block:: python import monai - from monai.apps.manifest import ConfigExpression + from monai.bundle import ConfigExpression config = "$monai.__version__" expression = ConfigExpression(config, id="test", globals={"monai": monai}) @@ -304,6 +308,9 @@ class ConfigExpression(ConfigItem): """ + prefix = "$" + run_eval = False if os.environ.get("MONAI_EVAL_EXPR", "1") == "0" else True + def __init__(self, config: Any, id: str = "", globals: Optional[Dict] = None) -> None: super().__init__(config=config, id=id) self.globals = globals @@ -320,10 +327,12 @@ def evaluate(self, locals: Optional[Dict] = None): value = self.get_config() if not ConfigExpression.is_expression(value): return None - return eval(value[1:], self.globals, locals) + if not self.run_eval: + return f"{value[len(self.prefix) :]}" + return eval(value[len(self.prefix) :], self.globals, locals) - @staticmethod - def is_expression(config: Union[Dict, List, str]) -> bool: + @classmethod + def is_expression(cls, config: Union[Dict, List, str]) -> bool: """ Check whether the config is an executable expression string. Currently, a string starts with ``"$"`` character is interpreted as an expression. @@ -332,4 +341,4 @@ def is_expression(config: Union[Dict, List, str]) -> bool: config: input config content to check. """ - return isinstance(config, str) and config.startswith("$") + return isinstance(config, str) and config.startswith(cls.prefix) diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py new file mode 100644 index 0000000000..5ebcfd03b4 --- /dev/null +++ b/monai/bundle/config_parser.py @@ -0,0 +1,224 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +from copy import deepcopy +from typing import Any, Dict, Optional, Sequence, Union + +from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem +from monai.bundle.reference_resolver import ReferenceResolver + + +class ConfigParser: + """ + The primary configuration parser. It traverses a structured config (in the form of nested Python dict or list), + creates ``ConfigItem``, and assign unique IDs according to the structures. + + This class provides convenient access to the set of ``ConfigItem`` of the config by ID. + A typical workflow of config parsing is as follows: + + - Initialize ``ConfigParser`` with the ``config`` source. + - Call ``get_parsed_content()`` to get expected component with `id`. + + .. code-block:: python + + from monai.apps import ConfigParser + + config = { + "my_dims": 2, + "dims_1": "$@my_dims + 1", + "my_xform": {"": "LoadImage"}, + "my_net": {"": "BasicUNet", + "": {"spatial_dims": "@dims_1", "in_channels": 1, "out_channels": 4}}, + "trainer": {"": "SupervisedTrainer", + "": {"network": "@my_net", "preprocessing": "@my_xform"}} + } + # in the example $@my_dims + 1 is an expression, which adds 1 to the value of @my_dims + parser = ConfigParser(config) + + # get/set configuration content, the set method should happen before calling parse() + print(parser["my_net"][""]["in_channels"]) # original input channels 1 + parser["my_net"][""]["in_channels"] = 4 # change input channels to 4 + print(parser["my_net"][""]["in_channels"]) + + # instantiate the network component + parser.parse(True) + net = parser.get_parsed_content("my_net", instantiate=True) + print(net) + + # also support to get the configuration content of parsed `ConfigItem` + trainer = parser.get_parsed_content("trainer", instantiate=False) + print(trainer) + + Args: + config: input config source to parse. + excludes: when importing modules to instantiate components, + excluding components from modules specified in ``excludes``. + globals: pre-import packages as global variables to ``ConfigExpression``, + so that expressions, for example, ``"$monai.data.list_data_collate"`` can use ``monai`` modules. + The current supported globals and alias names are + ``{"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"}``. + These are MONAI's minimal dependencies. + + See also: + + - :py:class:`monai.apps.ConfigItem` + + """ + + def __init__( + self, + config: Any, + excludes: Optional[Union[Sequence[str], str]] = None, + globals: Optional[Dict[str, Any]] = None, + ): + self.config = None + self.globals: Dict[str, Any] = {} + globals = {"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"} if globals is None else globals + if globals is not None: + for k, v in globals.items(): + self.globals[k] = importlib.import_module(v) if isinstance(v, str) else v + + self.locator = ComponentLocator(excludes=excludes) + self.ref_resolver = ReferenceResolver() + self.set(config=config) + + def __repr__(self): + return f"{self.config}" + + def __getitem__(self, id: Union[str, int]): + """ + Get the config by id. + + Args: + id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to + go one level further into the nested structures. + Use digits indexing from "0" for list or other strings for dict. + For example: ``"xform#5"``, ``"net##channels"``. ``""`` indicates the entire ``self.config``. + + """ + if id == "": + return self.config + config = self.config + for k in str(id).split(self.ref_resolver.sep): + if not isinstance(config, (dict, list)): + raise ValueError(f"config must be dict or list for key `{k}`, but got {type(config)}: {config}.") + indexing = k if isinstance(config, dict) else int(k) + config = config[indexing] + return config + + def __setitem__(self, id: Union[str, int], config: Any): + """ + Set config by ``id``. Note that this method should be used before ``parse()`` or ``get_parsed_content()`` + to ensure the updates are included in the parsed content. + + Args: + id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to + go one level further into the nested structures. + Use digits indexing from "0" for list or other strings for dict. + For example: ``"xform#5"``, ``"net##channels"``. ``""`` indicates the entire ``self.config``. + config: config to set at location ``id``. + + """ + if id == "": + self.config = config + self.ref_resolver.reset() + return + keys = str(id).split(self.ref_resolver.sep) + # get the last parent level config item and replace it + last_id = self.ref_resolver.sep.join(keys[:-1]) + conf_ = self[last_id] + indexing = keys[-1] if isinstance(conf_, dict) else int(keys[-1]) + conf_[indexing] = config + self.ref_resolver.reset() + return + + def get(self, id: str = "", default: Optional[Any] = None): + """ + Get the config by id. + + Args: + id: id to specify the expected position. See also :py:meth:`__getitem__`. + default: default value to return if the specified ``id`` is invalid. + + """ + try: + return self[id] + except KeyError: + return default + + def set(self, config: Any, id: str = ""): + """ + Set config by ``id``. See also :py:meth:`__setitem__`. + + """ + self[id] = config + + def _do_parse(self, config, id: str = ""): + """ + Recursively parse the nested data in config source, add every item as `ConfigItem` to the resolver. + + Args: + config: config source to parse. + id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to + go one level further into the nested structures. + Use digits indexing from "0" for list or other strings for dict. + For example: ``"xform#5"``, ``"net##channels"``. ``""`` indicates the entire ``self.config``. + + """ + if isinstance(config, (dict, list)): + subs = enumerate(config) if isinstance(config, list) else config.items() + for k, v in subs: + sub_id = f"{id}{self.ref_resolver.sep}{k}" if id != "" else k + self._do_parse(config=v, id=sub_id) + + # copy every config item to make them independent and add them to the resolver + item_conf = deepcopy(config) + if ConfigComponent.is_instantiable(item_conf): + self.ref_resolver.add_item(ConfigComponent(config=item_conf, id=id, locator=self.locator)) + elif ConfigExpression.is_expression(item_conf): + self.ref_resolver.add_item(ConfigExpression(config=item_conf, id=id, globals=self.globals)) + else: + self.ref_resolver.add_item(ConfigItem(config=item_conf, id=id)) + + def parse(self, reset: bool = True): + """ + Recursively parse the config source, add every item as ``ConfigItem`` to the resolver. + + Args: + reset: whether to reset the ``reference_resolver`` before parsing. Defaults to `True`. + + """ + if reset: + self.ref_resolver.reset() + self._do_parse(config=self.config) + + def get_parsed_content(self, id: str = "", **kwargs): + """ + Get the parsed result of ``ConfigItem`` with the specified ``id``. + + - If the item is ``ConfigComponent`` and ``instantiate=True``, the result is the instance. + - If the item is ``ConfigExpression`` and ``eval_expr=True``, the result is the evaluated output. + - Else, the result is the configuration content of `ConfigItem`. + + Args: + id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to + go one level further into the nested structures. + Use digits indexing from "0" for list or other strings for dict. + For example: ``"xform#5"``, ``"net##channels"``. ``""`` indicates the entire ``self.config``. + kwargs: additional keyword arguments to be passed to ``_resolve_one_item``. + Currently support ``reset`` (for parse), ``instantiate`` and ``eval_expr``. All defaulting to True. + + """ + if not self.ref_resolver.is_resolved(): + # not parsed the config source yet, parse it + self.parse(kwargs.get("reset", True)) + return self.ref_resolver.get_resolved_content(id=id, **kwargs) diff --git a/monai/apps/manifest/reference_resolver.py b/monai/bundle/reference_resolver.py similarity index 67% rename from monai/apps/manifest/reference_resolver.py rename to monai/bundle/reference_resolver.py index 32d089370a..45d897af05 100644 --- a/monai/apps/manifest/reference_resolver.py +++ b/monai/bundle/reference_resolver.py @@ -10,10 +10,10 @@ # limitations under the License. import re -import warnings from typing import Any, Dict, Optional, Sequence, Set -from monai.apps.manifest.config_item import ConfigComponent, ConfigExpression, ConfigItem +from monai.bundle.config_item import ConfigComponent, ConfigExpression, ConfigItem +from monai.utils import look_up_option class ReferenceResolver: @@ -42,11 +42,28 @@ class ReferenceResolver: """ + _vars = "__local_refs" + sep = "#" # separator for key indexing + ref = "@" # reference prefix + # match a reference string, e.g. "@id#key", "@id#key#0", "@##key" + id_matcher = re.compile(rf"{ref}(?:(?:<\w*>)|(?:\w*))(?:(?:{sep}<\w*>)|(?:{sep}\w*))*") + def __init__(self, items: Optional[Sequence[ConfigItem]] = None): # save the items in a dictionary with the `ConfigItem.id` as key - self.items = {} if items is None else {i.get_id(): i for i in items} + self.items: Dict[str, Any] = {} if items is None else {i.get_id(): i for i in items} self.resolved_content: Dict[str, Any] = {} + def reset(self): + """ + Clear all the added `ConfigItem` and all the resolved content. + + """ + self.items = {} + self.resolved_content = {} + + def is_resolved(self) -> bool: + return bool(self.resolved_content) + def add_item(self, item: ConfigItem): """ Add a ``ConfigItem`` to the resolver. @@ -56,14 +73,11 @@ def add_item(self, item: ConfigItem): """ id = item.get_id() - if id == "": - raise ValueError("id should not be empty when resolving reference.") if id in self.items: - warnings.warn(f"id '{id}' is already added.") return self.items[id] = item - def get_item(self, id: str, resolve: bool = False): + def get_item(self, id: str, resolve: bool = False, **kwargs): """ Get the ``ConfigItem`` by id. @@ -73,13 +87,14 @@ def get_item(self, id: str, resolve: bool = False): Args: id: id of the expected config item. resolve: whether to resolve the item if it is not resolved, default to False. - + kwargs: keyword arguments to pass to ``_resolve_one_item()``. + Currently support ``instantiate`` and ``eval_expr``. Both are defaulting to True. """ if resolve and id not in self.resolved_content: - self._resolve_one_item(id=id) + self._resolve_one_item(id=id, **kwargs) return self.items.get(id) - def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None): + def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, **kwargs): """ Resolve one ``ConfigItem`` of ``id``, cache the resolved result in ``resolved_content``. If it has unresolved references, recursively resolve the referring items first. @@ -89,9 +104,14 @@ def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None): waiting_list: set of ids pending to be resolved. It's used to detect circular references such as: `{"name": "A", "dep": "@B"}` and `{"name": "B", "dep": "@A"}`. + kwargs: keyword arguments to pass to ``_resolve_one_item()``. + Currently support ``instantiate`` and ``eval_expr``. Both are defaulting to True. """ - item = self.items[id] # if invalid id name, raise KeyError + try: + item = look_up_option(id, self.items, print_all_options=False) + except ValueError as err: + raise KeyError(f"id='{id}' is not found in the config resolver.") from err item_config = item.get_config() if waiting_list is None: @@ -99,46 +119,51 @@ def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None): waiting_list.add(id) ref_ids = self.find_refs_in_config(config=item_config, id=id) - - # if current item has reference already in the waiting list, that's circular references for d in ref_ids: + # if current item has reference already in the waiting list, that's circular references if d in waiting_list: - raise ValueError(f"detected circular references for id='{d}' in the config content.") - - # # check whether the component has any unresolved references - for d in ref_ids: + raise ValueError(f"detected circular references '{d}' for id='{id}' in the config content.") + # check whether the component has any unresolved references if d not in self.resolved_content: # this referring item is not resolved - if d not in self.items: - raise ValueError(f"the referring item `{d}` is not defined in config.") + try: + look_up_option(d, self.items, print_all_options=False) + except ValueError as err: + raise ValueError(f"the referring item `@{d}` is not defined in the config content.") from err # recursively resolve the reference first - self._resolve_one_item(id=d, waiting_list=waiting_list) + self._resolve_one_item(id=d, waiting_list=waiting_list, **kwargs) + waiting_list.discard(d) # all references are resolved, then try to resolve current config item new_config = self.update_config_with_refs(config=item_config, id=id, refs=self.resolved_content) item.update_config(config=new_config) # save the resolved result into `resolved_content` to recursively resolve others if isinstance(item, ConfigComponent): - self.resolved_content[id] = item.instantiate() + self.resolved_content[id] = item.instantiate() if kwargs.get("instantiate", True) else item elif isinstance(item, ConfigExpression): - self.resolved_content[id] = item.evaluate(locals={"refs": self.resolved_content}) + run_eval = kwargs.get("eval_expr", True) + self.resolved_content[id] = ( + item.evaluate(locals={f"{self._vars}": self.resolved_content}) if run_eval else item + ) else: self.resolved_content[id] = new_config - def get_resolved_content(self, id: str): + def get_resolved_content(self, id: str, **kwargs): """ Get the resolved ``ConfigItem`` by id. If there are unresolved references, try to resolve them first. Args: id: id name of the expected item. + kwargs: additional keyword arguments to be passed to ``_resolve_one_item``. + Currently support ``instantiate`` and ``eval_expr``. Both are defaulting to True. """ if id not in self.resolved_content: - self._resolve_one_item(id=id) + self._resolve_one_item(id=id, **kwargs) return self.resolved_content[id] - @staticmethod - def match_refs_pattern(value: str) -> Set[str]: + @classmethod + def match_refs_pattern(cls, value: str) -> Set[str]: """ Match regular expression for the input string to find the references. The reference string starts with ``"@"``, like: ``"@XXX#YYY#ZZZ"``. @@ -149,15 +174,16 @@ def match_refs_pattern(value: str) -> Set[str]: """ refs: Set[str] = set() # regular expression pattern to match "@XXX" or "@XXX#YYY" - result = re.compile(r"@\w*[\#\w]*").findall(value) + result = cls.id_matcher.findall(value) + value_is_expr = ConfigExpression.is_expression(value) for item in result: - if ConfigExpression.is_expression(value) or value == item: + if value_is_expr or value == item: # only check when string starts with "$" or the whole content is "@XXX" - refs.add(item[1:]) + refs.add(item[len(cls.ref) :]) return refs - @staticmethod - def update_refs_pattern(value: str, refs: Dict) -> str: + @classmethod + def update_refs_pattern(cls, value: str, refs: Dict) -> str: """ Match regular expression for the input string to update content with the references. The reference part starts with ``"@"``, like: ``"@XXX#YYY#ZZZ"``. @@ -169,21 +195,22 @@ def update_refs_pattern(value: str, refs: Dict) -> str: """ # regular expression pattern to match "@XXX" or "@XXX#YYY" - result = re.compile(r"@\w*[\#\w]*").findall(value) + result = cls.id_matcher.findall(value) + value_is_expr = ConfigExpression.is_expression(value) for item in result: - ref_id = item[1:] + ref_id = item[len(cls.ref) :] # remove the ref prefix "@" if ref_id not in refs: raise KeyError(f"can not find expected ID '{ref_id}' in the references.") - if ConfigExpression.is_expression(value): + if value_is_expr: # replace with local code, will be used in the `evaluate` logic with `locals={"refs": ...}` - value = value.replace(item, f"refs['{ref_id}']") + value = value.replace(item, f"{cls._vars}['{ref_id}']") elif value == item: # the whole content is "@XXX", it will avoid the case that regular string contains "@" value = refs[ref_id] return value - @staticmethod - def find_refs_in_config(config, id: str, refs: Optional[Set[str]] = None) -> Set[str]: + @classmethod + def find_refs_in_config(cls, config, id: str, refs: Optional[Set[str]] = None) -> Set[str]: """ Recursively search all the content of input config item to get the ids of references. References mean: the IDs of other config items (``"@XXX"`` in this config item), or the @@ -198,18 +225,18 @@ def find_refs_in_config(config, id: str, refs: Optional[Set[str]] = None) -> Set """ refs_: Set[str] = refs or set() if isinstance(config, str): - return refs_.union(ReferenceResolver.match_refs_pattern(value=config)) + return refs_.union(cls.match_refs_pattern(value=config)) if not isinstance(config, (list, dict)): return refs_ for k, v in config.items() if isinstance(config, dict) else enumerate(config): - sub_id = f"{id}#{k}" if id != "" else f"{k}" + sub_id = f"{id}{cls.sep}{k}" if id != "" else f"{k}" if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v): refs_.add(sub_id) - refs_ = ReferenceResolver.find_refs_in_config(v, sub_id, refs_) + refs_ = cls.find_refs_in_config(v, sub_id, refs_) return refs_ - @staticmethod - def update_config_with_refs(config, id: str, refs: Optional[Dict] = None): + @classmethod + def update_config_with_refs(cls, config, id: str, refs: Optional[Dict] = None): """ With all the references in ``refs``, update the input config content with references and return the new config. @@ -222,15 +249,15 @@ def update_config_with_refs(config, id: str, refs: Optional[Dict] = None): """ refs_: Dict = refs or {} if isinstance(config, str): - return ReferenceResolver.update_refs_pattern(config, refs_) + return cls.update_refs_pattern(config, refs_) if not isinstance(config, (list, dict)): return config ret = type(config)() for idx, v in config.items() if isinstance(config, dict) else enumerate(config): - sub_id = f"{id}#{idx}" if id != "" else f"{idx}" + sub_id = f"{id}{cls.sep}{idx}" if id != "" else f"{idx}" if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v): - updated = ReferenceResolver.update_config_with_refs(v, sub_id, refs_) + updated = refs_[sub_id] else: - updated = ReferenceResolver.update_config_with_refs(v, sub_id, refs_) + updated = cls.update_config_with_refs(v, sub_id, refs_) ret.update({idx: updated}) if isinstance(ret, dict) else ret.append(updated) return ret diff --git a/monai/utils/module.py b/monai/utils/module.py index 8b7745c3ee..de2152d182 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -10,13 +10,13 @@ # limitations under the License. import enum -import inspect import os import re import sys import warnings from functools import partial, wraps from importlib import import_module +from inspect import isclass, isfunction, ismethod from pkgutil import walk_packages from pydoc import locate from re import match @@ -47,7 +47,7 @@ ] -def look_up_option(opt_str, supported: Union[Collection, enum.EnumMeta], default="no_default"): +def look_up_option(opt_str, supported: Union[Collection, enum.EnumMeta], default="no_default", print_all_options=True): """ Look up the option in the supported collection and return the matched item. Raise a value error possibly with a guess of the closest match. @@ -58,6 +58,7 @@ def look_up_option(opt_str, supported: Union[Collection, enum.EnumMeta], default default: If it is given, this method will return `default` when `opt_str` is not found, instead of raising a `ValueError`. Otherwise, it defaults to `"no_default"`, so that the method may raise a `ValueError`. + print_all_options: whether to print all available options when `opt_str` is not found. Defaults to True Examples: @@ -113,12 +114,12 @@ class Color(Enum): if edit_dist <= 3: edit_dists[key] = edit_dist - supported_msg = f"Available options are {set_to_check}.\n" + supported_msg = f"Available options are {set_to_check}.\n" if print_all_options else "" if edit_dists: guess_at_spelling = min(edit_dists, key=edit_dists.get) # type: ignore raise ValueError( f"By '{opt_str}', did you mean '{guess_at_spelling}'?\n" - + f"'{opt_str}' is not a valid option.\n" + + f"'{opt_str}' is not a valid value.\n" + supported_msg ) raise ValueError(f"Unsupported option '{opt_str}', " + supported_msg) @@ -212,9 +213,10 @@ def instantiate(path: str, **kwargs): component = locate(path) if component is None: raise ModuleNotFoundError(f"Cannot locate '{path}'.") - if inspect.isclass(component): + if isclass(component): return component(**kwargs) - if inspect.isfunction(component): + # support regular function, static method and class method + if isfunction(component) or (ismethod(component) and isclass(getattr(component, "__self__", None))): return partial(component, **kwargs) warnings.warn(f"Component to instantiate must represent a valid class or function, but got {path}.") diff --git a/tests/test_component_locator.py b/tests/test_component_locator.py index eafb2152d1..ebb2cca7b3 100644 --- a/tests/test_component_locator.py +++ b/tests/test_component_locator.py @@ -12,7 +12,7 @@ import unittest from pydoc import locate -from monai.apps.manifest import ComponentLocator +from monai.bundle import ComponentLocator from monai.utils import optional_import _, has_ignite = optional_import("ignite") diff --git a/tests/test_config_item.py b/tests/test_config_item.py index b2c2fec6c6..1284efab56 100644 --- a/tests/test_config_item.py +++ b/tests/test_config_item.py @@ -17,12 +17,12 @@ from parameterized import parameterized import monai -from monai.apps import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem +from monai.bundle import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem from monai.data import DataLoader, Dataset from monai.transforms import LoadImaged, RandTorchVisiond -from monai.utils import optional_import +from monai.utils import min_version, optional_import -_, has_tv = optional_import("torchvision") +_, has_tv = optional_import("torchvision", "0.8.0", min_version) TEST_CASE_1 = [{"lr": 0.001}, 0.0001] diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py new file mode 100644 index 0000000000..5b5aa2b816 --- /dev/null +++ b/tests/test_config_parser.py @@ -0,0 +1,128 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest import skipUnless + +from parameterized import parameterized + +from monai.bundle.config_parser import ConfigParser +from monai.data import DataLoader, Dataset +from monai.transforms import Compose, LoadImaged, RandTorchVisiond +from monai.utils import min_version, optional_import + +_, has_tv = optional_import("torchvision", "0.8.0", min_version) + +# test the resolved and parsed instances +TEST_CASE_1 = [ + { + "transform": { + "": "Compose", + "": { + "transforms": [ + {"": "LoadImaged", "": {"keys": "image"}}, + { + "": "RandTorchVisiond", + "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}, + }, + ] + }, + }, + "dataset": {"": "Dataset", "": {"data": [1, 2], "transform": "@transform"}}, + "dataloader": { + "": "DataLoader", + "": {"dataset": "@dataset", "batch_size": 2, "collate_fn": "monai.data.list_data_collate"}, + }, + }, + ["transform", "transform##transforms#0", "transform##transforms#1", "dataset", "dataloader"], + [Compose, LoadImaged, RandTorchVisiond, Dataset, DataLoader], +] + + +class TestClass: + @staticmethod + def compute(a, b, func=lambda x, y: x + y): + return func(a, b) + + @classmethod + def cls_compute(cls, a, b, func=lambda x, y: x + y): + return cls.compute(a, b, func) + + def __call__(self, a, b): + return self.compute(a, b) + + +TEST_CASE_2 = [ + { + "basic_func": "$lambda x, y: x + y", + "static_func": "$TestClass.compute", + "cls_func": "$TestClass.cls_compute", + "lambda_static_func": "$lambda x, y: TestClass.compute(x, y)", + "lambda_cls_func": "$lambda x, y: TestClass.cls_compute(x, y)", + "compute": {"": "tests.test_config_parser.TestClass.compute", "": {"func": "@basic_func"}}, + "cls_compute": {"": "tests.test_config_parser.TestClass.cls_compute", "": {"func": "@basic_func"}}, + "call_compute": {"": "tests.test_config_parser.TestClass"}, + "error_func": "$TestClass.__call__", + "": "$lambda x, y: x + y", + } +] + + +class TestConfigComponent(unittest.TestCase): + def test_config_content(self): + test_config = {"preprocessing": [{"": "LoadImage"}], "dataset": {"": "Dataset"}} + parser = ConfigParser(config=test_config) + # test `get`, `set`, `__getitem__`, `__setitem__` + self.assertEqual(str(parser.get()), str(test_config)) + parser.set(config=test_config) + self.assertListEqual(parser["preprocessing"], test_config["preprocessing"]) + parser["dataset"] = {"": "CacheDataset"} + self.assertEqual(parser["dataset"][""], "CacheDataset") + # test nested ids + parser["dataset#"] = "Dataset" + self.assertEqual(parser["dataset#"], "Dataset") + # test int id + parser.set(["test1", "test2", "test3"]) + parser[1] = "test4" + self.assertEqual(parser[1], "test4") + + @parameterized.expand([TEST_CASE_1]) + @skipUnless(has_tv, "Requires torchvision >= 0.8.0.") + def test_parse(self, config, expected_ids, output_types): + parser = ConfigParser(config=config, globals={"monai": "monai"}) + # test lazy instantiation with original config content + parser["transform"][""]["transforms"][0][""]["keys"] = "label1" + self.assertEqual(parser.get_parsed_content(id="transform##transforms#0").keys[0], "label1") + # test nested id + parser["transform##transforms#0##keys"] = "label2" + self.assertEqual(parser.get_parsed_content(id="transform##transforms#0").keys[0], "label2") + for id, cls in zip(expected_ids, output_types): + self.assertTrue(isinstance(parser.get_parsed_content(id), cls)) + # test root content + root = parser.get_parsed_content(id="") + for v, cls in zip(root.values(), [Compose, Dataset, DataLoader]): + self.assertTrue(isinstance(v, cls)) + + @parameterized.expand([TEST_CASE_2]) + def test_function(self, config): + parser = ConfigParser(config=config, globals={"TestClass": TestClass}) + for id in config: + func = parser.get_parsed_content(id=id) + self.assertTrue(id in parser.ref_resolver.resolved_content) + if id == "error_func": + with self.assertRaises(TypeError): + func(1, 2) + continue + self.assertEqual(func(1, 2), 3) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_reference_resolver.py b/tests/test_reference_resolver.py index a62d6befd9..e16a795c40 100644 --- a/tests/test_reference_resolver.py +++ b/tests/test_reference_resolver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,13 +15,13 @@ from parameterized import parameterized import monai -from monai.apps import ConfigComponent, ReferenceResolver -from monai.apps.manifest.config_item import ComponentLocator, ConfigExpression, ConfigItem +from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem +from monai.bundle.reference_resolver import ReferenceResolver from monai.data import DataLoader from monai.transforms import LoadImaged, RandTorchVisiond -from monai.utils import optional_import +from monai.utils import min_version, optional_import -_, has_tv = optional_import("torchvision") +_, has_tv = optional_import("torchvision", "0.8.0", min_version) # test instance with no dependencies TEST_CASE_1 = [