From 494456474249911c742e76abd3467d9cba08c564 Mon Sep 17 00:00:00 2001 From: Yifan Zhou Date: Tue, 1 Mar 2022 16:02:12 +0800 Subject: [PATCH] [Enhancement] Make rewriter more powerful (#150) * Finish function tests * lint * resolve comments * Fix tests * docstring & fix * Complement informations * lint * Add example * Fix version * Remove todo Co-authored-by: RunningLeon --- mmdeploy/codebase/mmdet/deploy/utils.py | 29 ++ mmdeploy/core/rewriters/function_rewriter.py | 38 +-- mmdeploy/core/rewriters/module_rewriter.py | 41 +-- mmdeploy/core/rewriters/rewriter_manager.py | 29 +- mmdeploy/core/rewriters/rewriter_utils.py | 288 ++++++++++++++++--- mmdeploy/core/rewriters/symbolic_rewriter.py | 25 +- mmdeploy/utils/__init__.py | 6 +- mmdeploy/utils/constants.py | 7 + mmdeploy/utils/env.py | 49 ++++ tests/test_core/test_function_rewriter.py | 22 +- tests/test_core/test_rewriter_registry.py | 59 ---- tests/test_core/test_rewriter_utils.py | 112 ++++++++ tests/test_utils/test_util.py | 23 ++ tools/check_env.py | 81 ++---- 14 files changed, 560 insertions(+), 249 deletions(-) create mode 100644 mmdeploy/utils/env.py delete mode 100644 tests/test_core/test_rewriter_registry.py create mode 100644 tests/test_core/test_rewriter_utils.py diff --git a/mmdeploy/codebase/mmdet/deploy/utils.py b/mmdeploy/codebase/mmdet/deploy/utils.py index 1ecd451e2f..860cb54239 100644 --- a/mmdeploy/codebase/mmdet/deploy/utils.py +++ b/mmdeploy/codebase/mmdet/deploy/utils.py @@ -5,6 +5,8 @@ import torch from torch import Tensor +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.core.rewriters.rewriter_utils import LibVersionChecker from mmdeploy.utils import load_config @@ -69,6 +71,33 @@ def clip_bboxes(x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor, return x1, y1, x2, y2 +@FUNCTION_REWRITER.register_rewriter( + func_name='mmdeploy.codebase.mmdet.deploy.utils.clip_bboxes', + backend='tensorrt', + extra_checkers=LibVersionChecker('tensorrt', min_version='8')) +def clip_bboxes__trt8(ctx, x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor, + max_shape: Union[Tensor, Sequence[int]]): + """Clip bboxes for onnx. From TensorRT 8 we can do the operators on the + tensors directly. + + Args: + ctx (ContextCaller): The context with additional information. + x1 (Tensor): The x1 for bounding boxes. + y1 (Tensor): The y1 for bounding boxes. + x2 (Tensor): The x2 for bounding boxes. + y2 (Tensor): The y2 for bounding boxes. + max_shape (Tensor | Sequence[int]): The (H,W) of original image. + Returns: + tuple(Tensor): The clipped x1, y1, x2, y2. + """ + assert len(max_shape) == 2, '`max_shape` should be [h, w]' + x1 = torch.clamp(x1, 0, max_shape[1]) + y1 = torch.clamp(y1, 0, max_shape[0]) + x2 = torch.clamp(x2, 0, max_shape[1]) + y2 = torch.clamp(y2, 0, max_shape[0]) + return x1, y1, x2, y2 + + def pad_with_value(x: Tensor, pad_dim: int, pad_size: int, diff --git a/mmdeploy/core/rewriters/function_rewriter.py b/mmdeploy/core/rewriters/function_rewriter.py index 674361f634..e80ed41d06 100644 --- a/mmdeploy/core/rewriters/function_rewriter.py +++ b/mmdeploy/core/rewriters/function_rewriter.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Callable, Dict +from typing import Callable, Dict, List, Optional, Union -from mmdeploy.utils import Backend, get_root_logger -from .rewriter_utils import ContextCaller, RewriterRegistry, import_function +from mmdeploy.utils import IR, Backend, get_root_logger +from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry, + import_function) def _set_func(origin_func_path: str, rewrite_func: Callable): @@ -66,32 +67,33 @@ class FunctionRewriter: def __init__(self): self._registry = RewriterRegistry() - def add_backend(self, backend: str): - """Add a backend by calling the _registry.add_backend.""" - self._registry.add_backend(backend) - - def register_rewriter(self, - func_name: str, - backend: str = Backend.DEFAULT.value, - **kwargs): + def register_rewriter( + self, + func_name: str, + backend: str = Backend.DEFAULT.value, + ir: IR = IR.DEFAULT, + extra_checkers: Optional[Union[Checker, List[Checker]]] = None, + **kwargs): """The interface of function rewriter decorator. Args: func_name (str): The function name/path to rewrite. - backend (str): The inference engine name. + backend (str): The rewriter will be activated on which backend. + ir (IR): The rewriter will be activated on which IR. + extra_checkers (Checker | List[Checker] | None): Other requirements + defined by Checker. + Returns: Callable: The process of registering function. """ - return self._registry.register_object(func_name, backend, **kwargs) + return self._registry.register_object(func_name, backend, ir, + extra_checkers, **kwargs) - def enter(self, - cfg: Dict = dict(), - backend: str = Backend.DEFAULT.value, - **kwargs): + def enter(self, cfg: Dict = dict(), env: Dict = dict(), **kwargs): """The implementation of function rewrite.""" # Get current records - functions_records = self._registry.get_records(backend) + functions_records = self._registry.get_records(env) self._origin_functions = list() self._additional_functions = list() diff --git a/mmdeploy/core/rewriters/module_rewriter.py b/mmdeploy/core/rewriters/module_rewriter.py index 43720443c6..d0961809a0 100644 --- a/mmdeploy/core/rewriters/module_rewriter.py +++ b/mmdeploy/core/rewriters/module_rewriter.py @@ -1,11 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. import inspect +from typing import Dict, List, Optional, Union import mmcv from torch import nn -from mmdeploy.utils.constants import Backend -from .rewriter_utils import RewriterRegistry, eval_with_import +from mmdeploy.utils.constants import IR, Backend +from .rewriter_utils import (Checker, RewriterRegistry, collect_env, + eval_with_import) class ModuleRewriter: @@ -26,29 +28,33 @@ class ModuleRewriter: def __init__(self): self._registry = RewriterRegistry() - def add_backend(self, backend: str): - """Add a backend by calling the _registry.add_backend.""" - self._registry.add_backend(backend) - - def register_rewrite_module(self, - module_type: str, - backend: str = Backend.DEFAULT.value, - **kwargs): + def register_rewrite_module( + self, + module_type: str, + backend: str = Backend.DEFAULT.value, + ir: IR = IR.DEFAULT, + extra_checkers: Optional[Union[Checker, List[Checker]]] = None, + **kwargs): """The interface of module rewriter decorator. Args: module_type (str): The module type name to rewrite. - backend (str): The inference engine name. + backend (str): The rewriter will be activated on which backend. + ir (IR): The rewriter will be activated on which IR. + extra_checkers (Checker | List[Checker] | None): Other requirements + defined by Checker. Returns: - nn.Module: THe rewritten model. + nn.Module: The rewritten model. """ - return self._registry.register_object(module_type, backend, **kwargs) + return self._registry.register_object(module_type, backend, ir, + extra_checkers, **kwargs) def patch_model(self, model: nn.Module, cfg: mmcv.Config, backend: str = Backend.DEFAULT.value, + ir: IR = IR.DEFAULT, recursive: bool = True, **kwargs) -> nn.Module: """Replace the models that was registered. @@ -57,6 +63,7 @@ def patch_model(self, model (torch.nn.Module): The model to patch. cfg (Dict): Config dictionary of deployment. backend (str): The inference engine name. + ir (IR): The intermeditate representation name. recursive (bool): The flag to enable recursive patching. Returns: @@ -67,7 +74,9 @@ def patch_model(self, >>> patched_model = patch_model(model, cfg=deploy_cfg, >>> backend=backend) """ - self._collect_record(backend) + # TODO: Make the type of parameter backend to Backend + env = collect_env(Backend.get(backend), ir) + self._collect_record(env) return self._replace_module(model, cfg, recursive, **kwargs) def _replace_one_module(self, module, cfg, **kwargs): @@ -103,9 +112,9 @@ def _replace_module_impl(model, cfg, **kwargs): return _replace_module_impl(model, cfg, **kwargs) - def _collect_record(self, backend: str): + def _collect_record(self, env: Dict): """Collect models in registry.""" self._records = {} - records = self._registry.get_records(backend) + records = self._registry.get_records(env) for name, kwargs in records: self._records[eval_with_import(name)] = kwargs diff --git a/mmdeploy/core/rewriters/rewriter_manager.py b/mmdeploy/core/rewriters/rewriter_manager.py index df7e82703d..de3acaffd2 100644 --- a/mmdeploy/core/rewriters/rewriter_manager.py +++ b/mmdeploy/core/rewriters/rewriter_manager.py @@ -4,9 +4,10 @@ import mmcv import torch.nn as nn -from mmdeploy.utils.constants import Backend +from mmdeploy.utils.constants import IR, Backend from .function_rewriter import FunctionRewriter from .module_rewriter import ModuleRewriter +from .rewriter_utils import collect_env from .symbolic_rewriter import SymbolicRewriter @@ -18,20 +19,8 @@ def __init__(self): self.function_rewriter = FunctionRewriter() self.symbolic_rewriter = SymbolicRewriter() - def add_backend(self, backend: str): - """Add backend to all rewriters. - - Args: - backend (str): The backend to support. - """ - self.module_rewriter.add_backend(backend) - self.function_rewriter.add_backend(backend) - self.symbolic_rewriter.add_backend(backend) - REWRITER_MANAGER = RewriterManager() -for backend in Backend: - REWRITER_MANAGER.add_backend(backend.value) MODULE_REWRITER = REWRITER_MANAGER.module_rewriter FUNCTION_REWRITER = REWRITER_MANAGER.function_rewriter @@ -41,6 +30,7 @@ def add_backend(self, backend: str): def patch_model(model: nn.Module, cfg: mmcv.Config, backend: str = Backend.DEFAULT.value, + ir: IR = IR.DEFAULT, recursive: bool = True, **kwargs) -> nn.Module: """Patch the model, replace the modules that can be rewritten. Note that @@ -50,6 +40,7 @@ def patch_model(model: nn.Module, model (torch.nn.Module): The model to patch. cfg (Dict): Config dictionary of deployment. backend (str): The inference engine name. + ir (IR): The intermeditate representation name. recursive (bool): The flag to enable recursive patching. Returns: @@ -59,7 +50,7 @@ def patch_model(model: nn.Module, >>> from mmdeploy.core import patch_model >>> patched_model = patch_model(model, cfg=deploy_cfg, backend=backend) """ - return MODULE_REWRITER.patch_model(model, cfg, backend, recursive, + return MODULE_REWRITER.patch_model(model, cfg, backend, ir, recursive, **kwargs) @@ -71,6 +62,7 @@ class RewriterContext: Args: cfg (Dict): Config dictionary of deployment. backend (str): The inference engine name. + ir (IR): The intermeditate representation name. rewrite_manager (RewriterManager): An RewriteManager that consists of several rewriters @@ -84,20 +76,19 @@ class RewriterContext: def __init__(self, cfg: Dict = dict(), backend: str = Backend.DEFAULT.value, + ir: IR = IR.DEFAULT, rewriter_manager: RewriterManager = REWRITER_MANAGER, **kwargs): self._cfg = cfg - self._backend = backend self._kwargs = kwargs self._rewriter_manager = rewriter_manager + self._env = collect_env(Backend.get(backend), ir) def enter(self): """Call the enter() of rewriters.""" - self._rewriter_manager.function_rewriter.enter(self._cfg, - self._backend, + self._rewriter_manager.function_rewriter.enter(self._cfg, self._env, **self._kwargs) - self._rewriter_manager.symbolic_rewriter.enter(self._cfg, - self._backend, + self._rewriter_manager.symbolic_rewriter.enter(self._cfg, self._env, **self._kwargs) def exit(self): diff --git a/mmdeploy/core/rewriters/rewriter_utils.py b/mmdeploy/core/rewriters/rewriter_utils.py index 701078144a..a80fd84738 100644 --- a/mmdeploy/core/rewriters/rewriter_utils.py +++ b/mmdeploy/core/rewriters/rewriter_utils.py @@ -1,8 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple +import warnings +from abc import ABCMeta, abstractmethod +from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from mmdeploy.utils.constants import Backend +import mmdeploy +from mmdeploy.utils.constants import IR, Backend def eval_with_import(path: str) -> Any: @@ -56,6 +59,127 @@ def import_function(path: str) -> Tuple[Callable, Optional[type]]: return obj, None +def collect_env(backend: Backend, ir: IR, **kwargs) -> Dict: + """Collect current environment information, including backend, ir, codebase + version, etc. Rewriters will be checked according to env infos. + + Args: + backend (Backend): Current backend. + ir (IR): Current IR. + + Returns: + Dict: Record the value of Backend and IR as well as the versions of + libraries. + """ + from mmdeploy.utils import get_backend_version, get_codebase_version + env = dict(backend=backend, ir=ir) + env['mmdeploy'] = mmdeploy.__version__ + env.update(get_backend_version()) + env.update(get_codebase_version()) + env.update(kwargs) + return env + + +class Checker(metaclass=ABCMeta): + """The interface for checking whether a rewriter is valid.""" + + def __init__(self): + pass + + @abstractmethod + def check(self, env: Dict) -> bool: + """Check the if the rewriter is valid according to environment. + + Args: + env (Dict): The backend, IR info and version info. + """ + pass + + +class BackendChecker(Checker): + """Checker that determines which backend the rewriter must run on. + + Args: + required_backend (Backend): The rewriter will be activated on + which backend. + """ + + def __init__(self, required_backend: Backend): + super().__init__() + self.required_backend = required_backend + + def check(self, env: Dict) -> bool: + """Check the if the rewriter is valid according to backend. + + Args: + env (Dict): The backend, IR info and version info. + """ + return env['backend'] == self.required_backend + + +class IRChecker(Checker): + """Checker that determines which IR the rewriter must run on. + + Args: + required_ir (IR): The rewriter will be activated on which IR. + """ + + def __init__(self, required_ir: IR): + super().__init__() + self.required_ir = required_ir + + def check(self, env: Dict) -> bool: + """Check the if the rewriter is valid according to IR. + + Args: + env (Dict): The backend, IR info and version info. + """ + return env['ir'] == self.required_ir + + +class LibVersionChecker(Checker): + """Checker that determines which IR the rewriter must run on. + + Args: + lib (str): The name of library. + min_version (str | None): The rewriter should no lower than which + version. Default to `None`. + max_version (str | None): The rewriter should no greater than which + version. Default to `None`. + """ + + def __init__(self, + lib: str, + min_version: Optional[str] = None, + max_version: Optional[str] = None): + super().__init__() + self.lib = lib + self.min_version = min_version + self.max_version = max_version + + def check(self, env: Dict) -> bool: + """Check the if the rewriter is valid according to library version. + + Args: + env (Dict): The backend, IR info and version info. + """ + # If the library has not been installed + if env[self.lib] is None: + return False + + from packaging import version + valid = True + # The version should no less than min version and no greater than + # max version. + if self.min_version is not None: + if version.parse(env[self.lib]) < version.parse(self.min_version): + valid = False + if self.max_version is not None: + if version.parse(env[self.lib]) > version.parse(self.max_version): + valid = False + return valid + + class RewriterRegistry: """A registry that recoreds rewrite objects. @@ -75,58 +199,128 @@ class RewriterRegistry: >>> records = FUNCTION_REGISTRY.get_record("default") """ - # TODO: replace backend string with "Backend" constant def __init__(self): self._rewrite_records = dict() - self.add_backend(Backend.DEFAULT.value) - - def _check_backend(self, backend: str): - """Check if a backend has been supported.""" - if backend not in self._rewrite_records: - raise Exception('Backend is not supported by registry.') - - def add_backend(self, backend: str): - """Add a backend dictionary.""" - if backend not in self._rewrite_records: - self._rewrite_records[backend] = dict() - - def get_records(self, backend: str) -> List: - """Get all registered records in record table.""" - self._check_backend(backend) - - if backend != Backend.DEFAULT.value: - # Update dict A with dict B. - # Then convert the result dict to a list, while keeping the order - # of A and B: the elements only belong to B should alwarys come - # after the elements only belong to A. - # The complexity is O(n + m). - dict_a = self._rewrite_records[Backend.DEFAULT.value] - dict_b = self._rewrite_records[backend] - records = [] - for k, v in dict_a.items(): - if k in dict_b: - records.append((k, dict_b[k])) + + def get_records(self, env: Dict) -> List: + """Get all registered records that are valid in the given environment + from record table. + + If the backend and IR of rewriter are set to 'default', then the + rewriter is regarded as default rewriter. The default rewriter will be + activated only when all other rewriters are not valid. If there are + multiple rewriters are valid (except default rewriter), we will + activate the first one (The order is determined by the time when + rewriters are loaded). + + Args: + env (dict): Environment dictionary that includes backend, IR, + codebase version, etc. + + Returns: + List: A list that includes valid records. + """ + default_records = list() + records = list() + + for origin_function, rewriter_records in self._rewrite_records.items(): + default_rewriter = None + final_rewriter = None + for record in rewriter_records: + # Get the checkers of current rewriter + checkers: List[Checker] = record['_checkers'] + + # Check if the rewriter is default rewriter + if len(checkers) == 0: + # Process the default rewriter exceptionally + if default_rewriter is None: + default_rewriter = record + else: + warnings.warn( + 'Detect multiple valid rewriters for' + f'{origin_function}, use the first rewriter.') else: - records.append((k, v)) - for k, v in dict_b.items(): - if k not in dict_a: - records.append((k, v)) - else: - records = list( - self._rewrite_records[Backend.DEFAULT.value].items()) - return records - - def _register(self, name: str, backend: str, **kwargs): + # Check if the checker is valid. + # The checker is valid only if all the checks are passed + valid = True + for checker in checkers: + if not checker.check(env): + valid = False + break + + if valid: + # Check if there are multiple valid rewriters + if final_rewriter is not None: + warnings.warn( + 'Detect multiple valid rewriters for' + f'{origin_function}, use the first rewriter.') + else: + final_rewriter = record + + # Append final rewriter. + # If there is no valid rewriter, try not apply default rewriter + if final_rewriter is not None: + records.append((origin_function, final_rewriter)) + elif default_rewriter is not None: + default_records.append((origin_function, default_rewriter)) + + # Make the default records como to the front of list because we may + # want the non-default records to override them. + return default_records + records + + def _register(self, name: str, backend: Backend, ir: IR, + extra_checkers: List[Checker], **kwargs): """The implementation of register.""" - self._check_backend(backend) - self._rewrite_records[backend][name] = kwargs - def register_object(self, name: str, backend: str, **kwargs) -> Callable: - """The decorator to register an object.""" - self._check_backend(backend) + # Merge checkers to kwargs + record_dict = kwargs + + # Try to create a checker according to 'backend' field + if backend != Backend.DEFAULT: + extra_checkers.append(BackendChecker(backend)) + + # Try to create a checker according to 'ir' field + if ir != IR.DEFAULT: + extra_checkers.append(IRChecker(ir)) + + record_dict['_checkers'] = extra_checkers + + # There may be multiple rewriters of a function/module. We use a list + # to store the rewriters of a function/module. + if name not in self._rewrite_records: + self._rewrite_records[name] = list() + self._rewrite_records[name].append(record_dict) + + def register_object(self, + name: str, + backend: str, + ir: IR, + extra_checkers: Optional[Union[Checker, + List[Checker]]] = None, + **kwargs) -> Callable: + """The decorator to register an object. + + Args: + name (str): The import path to access the function/module. + backend (str): The rewriter will be activated on which backend. + ir (IR): The rewriter will be activated on which ir. + extra_chekcers (None | Checker | List[Checker]): Other requirements + for the rewriters. Default to `None`. + + Returns: + Callable: The decorator. + """ + + if extra_checkers is None: + extra_checkers = [] + elif isinstance(extra_checkers, Checker): + extra_checkers = [extra_checkers] + + backend = Backend.get(backend) def decorator(object): - self._register(name, backend, _object=object, **kwargs) + self._register( + name, backend, ir, extra_checkers, _object=object, **kwargs) return object return decorator diff --git a/mmdeploy/core/rewriters/symbolic_rewriter.py b/mmdeploy/core/rewriters/symbolic_rewriter.py index c9c16d071d..dd47cd8d58 100644 --- a/mmdeploy/core/rewriters/symbolic_rewriter.py +++ b/mmdeploy/core/rewriters/symbolic_rewriter.py @@ -1,13 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Callable, Dict, Optional, Sequence +from typing import Callable, Dict, List, Optional, Sequence, Union from torch.autograd import Function from torch.onnx.symbolic_helper import parse_args from torch.onnx.symbolic_registry import _registry as pytorch_registry from torch.onnx.symbolic_registry import register_op -from mmdeploy.utils import Backend, get_root_logger -from .rewriter_utils import ContextCaller, RewriterRegistry, eval_with_import +from mmdeploy.utils import IR, Backend, get_root_logger +from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry, + eval_with_import) class SymbolicRewriter: @@ -35,25 +36,27 @@ class SymbolicRewriter: def __init__(self) -> None: self._registry = RewriterRegistry() - def add_backend(self, backend: str): - """Add a backend by calling the _registry.add_backend.""" - self._registry.add_backend(backend) - def register_symbolic(self, func_name: str, backend: str = Backend.DEFAULT.value, is_pytorch: bool = False, arg_descriptors: Optional[Sequence[str]] = None, + ir: IR = IR.DEFAULT, + extra_checkers: Optional[Union[ + Checker, List[Checker]]] = None, **kwargs) -> Callable: """The decorator of the custom symbolic. Args: func_name (str): The function name/path to override the symbolic. - backend (str): The inference engine name. + backend (str): The rewriter will be activated on which backend. is_pytorch (bool): Enable this flag if func_name is the name of \ a pytorch builtin function. arg_descriptors (Sequence[str]): The argument descriptors of the \ symbol. + ir (IR): The rewriter will be activated on which IR. + extra_checkers (Checker | List[Checker] | None): Other requirements + defined by Checker. Returns: Callable: The process of registered symbolic. @@ -61,18 +64,20 @@ def register_symbolic(self, return self._registry.register_object( func_name, backend, + ir, + extra_checkers, is_pytorch=is_pytorch, arg_descriptors=arg_descriptors, **kwargs) def enter(self, cfg: Dict = dict(), - backend: str = Backend.DEFAULT.value, + env: Dict = dict(), opset: int = 11, **kwargs): """The implementation of symbolic register.""" # Get current records - symbolic_records = self._registry.get_records(backend) + symbolic_records = self._registry.get_records(env) self._pytorch_symbolic = list() self._extra_symbolic = list() diff --git a/mmdeploy/utils/__init__.py b/mmdeploy/utils/__init__.py index b4b05bd070..4847ba7b09 100644 --- a/mmdeploy/utils/__init__.py +++ b/mmdeploy/utils/__init__.py @@ -6,8 +6,9 @@ get_model_inputs, get_onnx_config, get_partition_config, get_task_type, is_dynamic_batch, is_dynamic_shape, load_config) -from .constants import SDK_TASK_MAP, Backend, Codebase, Task +from .constants import IR, SDK_TASK_MAP, Backend, Codebase, Task from .device import parse_cuda_device_id, parse_device_id +from .env import get_backend_version, get_codebase_version, get_library_version from .utils import get_file_path, get_root_logger, target_wrapper __all__ = [ @@ -18,5 +19,6 @@ 'get_model_inputs', 'cfg_apply_marks', 'get_input_shape', 'parse_device_id', 'parse_cuda_device_id', 'get_codebase_config', 'get_backend_config', 'get_root_logger', 'get_dynamic_axes', - 'target_wrapper', 'SDK_TASK_MAP', 'get_file_path' + 'target_wrapper', 'SDK_TASK_MAP', 'get_library_version', + 'get_codebase_version', 'get_backend_version', 'IR', 'get_file_path' ] diff --git a/mmdeploy/utils/constants.py b/mmdeploy/utils/constants.py index da07cb28e7..086666b26e 100644 --- a/mmdeploy/utils/constants.py +++ b/mmdeploy/utils/constants.py @@ -37,6 +37,13 @@ class Codebase(AdvancedEnum): MMPOSE = 'mmpose' +class IR(AdvancedEnum): + """Define intermediate representation enumerations.""" + ONNX = 'onnx' + TORCHSCRIPT = 'torchscript' + DEFAULT = 'default' + + class Backend(AdvancedEnum): """Define backend enumerations.""" PYTORCH = 'pytorch' diff --git a/mmdeploy/utils/env.py b/mmdeploy/utils/env.py new file mode 100644 index 0000000000..8cc2cbd3d5 --- /dev/null +++ b/mmdeploy/utils/env.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import importlib + +from mmdeploy.utils import Codebase + + +def get_library_version(lib): + """Try to get the version of a library if it has been installed. + + Args: + lib (str): The name of library. + + Returns: + None | str: If the library has been installed, return version. + """ + try: + lib = importlib.import_module(lib) + except Exception: + version = None + else: + version = lib.__version__ + + return version + + +def get_codebase_version(): + """Get the version dictionary of all supported codebases. + + Returns: + Dict: The name and the version of supported codebases. + """ + version_dict = dict() + for enum in Codebase: + codebase = enum.value + version_dict[codebase] = get_library_version(codebase) + return version_dict + + +def get_backend_version(): + """Get the version dictionary of some supported backend. + + Returns: + Dict: The name and the version of some supported backend. + """ + backend_library_list = ['tensorrt', 'onnxruntime', 'ncnn'] + version_dict = dict() + for backend in backend_library_list: + version_dict[backend] = get_library_version(backend) + return version_dict diff --git a/tests/test_core/test_function_rewriter.py b/tests/test_core/test_function_rewriter.py index b9b43fb688..97a814e929 100644 --- a/tests/test_core/test_function_rewriter.py +++ b/tests/test_core/test_function_rewriter.py @@ -3,7 +3,8 @@ from mmdeploy.core import FUNCTION_REWRITER, RewriterContext from mmdeploy.core.rewriters.function_rewriter import FunctionRewriter -from mmdeploy.utils.constants import Backend +from mmdeploy.core.rewriters.rewriter_utils import collect_env +from mmdeploy.utils.constants import IR, Backend def test_function_rewriter(): @@ -97,7 +98,6 @@ def test_rewrite_homonymic_functions(self): assert package.module.func() == 1 function_rewriter = FunctionRewriter() - function_rewriter.add_backend(Backend.NCNN.value) @function_rewriter.register_rewriter(func_name=path1) def func_2(ctx): @@ -108,7 +108,7 @@ def func_2(ctx): def func_3(ctx): return 3 - function_rewriter.enter(backend=Backend.NCNN.value) + function_rewriter.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT)) # This is a feature assert package.func() == 2 assert package.module.func() == 3 @@ -118,7 +118,6 @@ def func_3(ctx): assert package.module.func() == 1 function_rewriter2 = FunctionRewriter() - function_rewriter2.add_backend(Backend.NCNN.value) @function_rewriter2.register_rewriter( func_name=path1, backend=Backend.NCNN.value) @@ -129,7 +128,7 @@ def func_4(ctx): def func_5(ctx): return 5 - function_rewriter2.enter(backend=Backend.NCNN.value) + function_rewriter2.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT)) # This is a feature assert package.func() == 4 assert package.module.func() == 5 @@ -146,7 +145,6 @@ def test_rewrite_homonymic_methods(self): c = package.C() function_rewriter = FunctionRewriter() - function_rewriter.add_backend(Backend.NCNN.value) assert c.method() == 1 @@ -159,14 +157,13 @@ def func_2(ctx, self): def func_3(ctx, self): return 3 - function_rewriter.enter(backend=Backend.NCNN.value) + function_rewriter.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT)) assert c.method() == 3 function_rewriter.exit() assert c.method() == 1 function_rewriter2 = FunctionRewriter() - function_rewriter2.add_backend(Backend.NCNN.value) @function_rewriter2.register_rewriter( func_name=path1, backend=Backend.NCNN.value) @@ -177,7 +174,7 @@ def func_4(ctx, self): def func_5(ctx, self): return 5 - function_rewriter2.enter(backend=Backend.NCNN.value) + function_rewriter2.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT)) assert c.method() == 4 function_rewriter2.exit() @@ -196,7 +193,6 @@ def test_rewrite_derived_methods(): assert derived_obj.method() == 1 function_rewriter = FunctionRewriter() - function_rewriter.add_backend(Backend.NCNN.value) @function_rewriter.register_rewriter(func_name=path1) def func_2(ctx, self): @@ -207,12 +203,12 @@ def func_2(ctx, self): def func_3(ctx, self): return 3 - function_rewriter.enter() + function_rewriter.enter(env=collect_env(Backend.DEFAULT, ir=IR.DEFAULT)) assert base_obj.method() == 2 assert derived_obj.method() == 2 function_rewriter.exit() - function_rewriter.enter(backend=Backend.NCNN.value) + function_rewriter.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT)) assert base_obj.method() == 2 assert derived_obj.method() == 3 function_rewriter.exit() @@ -221,7 +217,7 @@ def func_3(ctx, self): assert derived_obj.method() == 1 # Check if the recovery is correct - function_rewriter.enter() + function_rewriter.enter(env=collect_env(Backend.DEFAULT, ir=IR.DEFAULT)) assert base_obj.method() == 2 assert derived_obj.method() == 2 function_rewriter.exit() diff --git a/tests/test_core/test_rewriter_registry.py b/tests/test_core/test_rewriter_registry.py deleted file mode 100644 index b577d02623..0000000000 --- a/tests/test_core/test_rewriter_registry.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import pytest - -from mmdeploy.core.rewriters.rewriter_utils import RewriterRegistry -from mmdeploy.utils.constants import Backend - - -def test_check_backend(): - with pytest.raises(Exception): - registry = RewriterRegistry() - registry._check_backend(Backend.ONNXRUNTIME.value) - - -def test_add_backend(): - registry = RewriterRegistry() - registry.add_backend(Backend.ONNXRUNTIME.value) - assert Backend.ONNXRUNTIME.value in registry._rewrite_records - assert Backend.DEFAULT.value in registry._rewrite_records - assert Backend.TENSORRT.value not in registry._rewrite_records - - -def test_register_object(): - registry = RewriterRegistry() - - @registry.register_object('add', backend=Backend.DEFAULT.value) - def add(a, b): - return a + b - - records = registry._rewrite_records[Backend.DEFAULT.value] - assert records is not None - assert records['add'] is not None - assert records['add']['_object'] is not None - add_func = records['add']['_object'] - assert add_func(123, 456) == 123 + 456 - - -def test_get_records(): - registry = RewriterRegistry() - registry.add_backend(Backend.TENSORRT.value) - - @registry.register_object('add', backend=Backend.DEFAULT.value) - def add(a, b): - return a + b - - @registry.register_object('minus', backend=Backend.DEFAULT.value) - def minus(a, b): - return a - b - - @registry.register_object('add', backend=Backend.TENSORRT.value) - def fake_add(a, b): - return a * b - - default_records = dict(registry.get_records(Backend.DEFAULT.value)) - assert default_records['add']['_object'](1, 1) == 2 - assert default_records['minus']['_object'](1, 1) == 0 - - tensorrt_records = dict(registry.get_records(Backend.TENSORRT.value)) - assert tensorrt_records['add']['_object'](1, 1) == 1 - assert tensorrt_records['minus']['_object'](1, 1) == 0 diff --git a/tests/test_core/test_rewriter_utils.py b/tests/test_core/test_rewriter_utils.py new file mode 100644 index 0000000000..4954a573d8 --- /dev/null +++ b/tests/test_core/test_rewriter_utils.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmdeploy +import mmdeploy.core.rewriters.rewriter_utils as rewriter_utils +from mmdeploy.core.rewriters.rewriter_utils import (BackendChecker, + RewriterRegistry, + collect_env) +from mmdeploy.utils.constants import IR, Backend + + +def test_collect_env(): + env_dict = collect_env(Backend.ONNXRUNTIME, IR.ONNX, version='1.0') + assert env_dict['backend'] == Backend.ONNXRUNTIME + assert env_dict['ir'] == IR.ONNX + assert env_dict['version'] == '1.0' + assert env_dict['mmdeploy'] == mmdeploy.__version__ + + +class TestChecker: + env = collect_env(Backend.ONNXRUNTIME, IR.ONNX) + + def test_backend_checker(self): + true_checker = rewriter_utils.BackendChecker(Backend.ONNXRUNTIME) + assert true_checker.check(self.env) is True + + false_checker = rewriter_utils.BackendChecker(Backend.TENSORRT) + assert false_checker.check(self.env) is False + + def test_ir_checker(self): + true_checker = rewriter_utils.IRChecker(IR.ONNX) + assert true_checker.check(self.env) is True + + false_checker = rewriter_utils.IRChecker(IR.TORCHSCRIPT) + assert false_checker.check(self.env) is False + + def test_lib_version_checker(self): + true_checker = rewriter_utils.LibVersionChecker( + 'mmdeploy', mmdeploy.__version__, mmdeploy.__version__) + assert true_checker.check(self.env) is True + + false_checker = rewriter_utils.LibVersionChecker( + 'mmdeploy', max_version='0.0.0') + assert false_checker.check(self.env) is False + + +def test_register_object(): + registry = RewriterRegistry() + checker = rewriter_utils.BackendChecker(Backend.ONNXRUNTIME) + + @registry.register_object( + 'add', + backend=Backend.DEFAULT.value, + ir=IR.DEFAULT, + extra_checkers=checker) + def add(a, b): + return a + b + + records = registry._rewrite_records + assert records is not None + assert records['add'] is not None + assert isinstance(records['add'][0]['_checkers'], list) + assert isinstance(records['add'][0]['_checkers'][0], BackendChecker) + assert records['add'][0]['_object'] is not None + add_func = records['add'][0]['_object'] + assert add_func(123, 456) == 123 + 456 + + +def test_get_records(): + registry = RewriterRegistry() + + @registry.register_object( + 'get_num', backend=Backend.ONNXRUNTIME.value, ir=IR.ONNX) + def get_num_1(): + return 1 + + @registry.register_object( + 'get_num', backend=Backend.ONNXRUNTIME.value, ir=IR.TORCHSCRIPT) + def get_num_2(): + return 2 + + @registry.register_object( + 'get_num', backend=Backend.TENSORRT.value, ir=IR.ONNX) + def get_num_3(): + return 3 + + @registry.register_object( + 'get_num', backend=Backend.TENSORRT.value, ir=IR.TORCHSCRIPT) + def get_num_4(): + return 4 + + @registry.register_object( + 'get_num', backend=Backend.DEFAULT.value, ir=IR.DEFAULT) + def get_num_5(): + return 5 + + records = dict( + registry.get_records(collect_env(Backend.ONNXRUNTIME, IR.ONNX))) + assert records['get_num']['_object']() == 1 + + records = dict( + registry.get_records(collect_env(Backend.ONNXRUNTIME, IR.TORCHSCRIPT))) + assert records['get_num']['_object']() == 2 + + records = dict( + registry.get_records(collect_env(Backend.TENSORRT, IR.ONNX))) + assert records['get_num']['_object']() == 3 + + records = dict( + registry.get_records(collect_env(Backend.TENSORRT, IR.TORCHSCRIPT))) + assert records['get_num']['_object']() == 4 + + records = dict(registry.get_records(collect_env(Backend.NCNN, IR.ONNX))) + assert records['get_num']['_object']() == 5 diff --git a/tests/test_utils/test_util.py b/tests/test_utils/test_util.py index e9f5ad33c2..ea36c63a69 100644 --- a/tests/test_utils/test_util.py +++ b/tests/test_utils/test_util.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import importlib import logging import os import tempfile @@ -440,3 +441,25 @@ def test_get_root_logger(): from mmdeploy.utils import get_root_logger logger = get_root_logger() logger.info('This is a test message') + + +def test_get_library_version(): + assert util.get_library_version('abcdefg') is None + try: + lib = importlib.import_module('setuptools') + except ImportError: + pass + else: + assert util.get_library_version('setuptools') == lib.__version__ + + +def test_get_codebase_version(): + versions = util.get_codebase_version() + for k, v in versions.items(): + assert v == util.get_library_version(k) + + +def test_get_backend_version(): + versions = util.get_backend_version() + for k, v in versions.items(): + assert v == util.get_library_version(k) diff --git a/tools/check_env.py b/tools/check_env.py index 68aa2799e7..3718db1bd5 100644 --- a/tools/check_env.py +++ b/tools/check_env.py @@ -4,49 +4,36 @@ from mmcv.utils import get_git_hash import mmdeploy -from mmdeploy.utils import get_root_logger +from mmdeploy.utils import (get_backend_version, get_codebase_version, + get_root_logger) def collect_env(): """Collect the information of the running environments.""" env_info = collect_base_env() - env_info['MMDeployment'] = f'{mmdeploy.__version__}+{get_git_hash()[:7]}' + env_info['MMDeploy'] = f'{mmdeploy.__version__}+{get_git_hash()[:7]}' return env_info def check_backend(): - try: - import onnxruntime as ort - except ImportError: - ort_version = None - else: - ort_version = ort.__version__ + backend_versions = get_backend_version() + ort_version = backend_versions['onnxruntime'] + trt_version = backend_versions['tensorrt'] + ncnn_version = backend_versions['ncnn'] + import mmdeploy.apis.onnxruntime as ort_apis logger = get_root_logger() - logger.info(f'onnxruntime: {ort_version} ops_is_avaliable : ' + logger.info(f'onnxruntime: {ort_version}\tops_is_avaliable : ' f'{ort_apis.is_available()}') - try: - import tensorrt as trt - except ImportError: - trt_version = None - else: - trt_version = trt.__version__ import mmdeploy.apis.tensorrt as trt_apis - logger.info( - f'tensorrt: {trt_version} ops_is_avaliable : {trt_apis.is_available()}' - ) - - try: - import ncnn - except ImportError: - ncnn_version = None - else: - ncnn_version = ncnn.__version__ + logger.info(f'tensorrt: {trt_version}\tops_is_avaliable : ' + f'{trt_apis.is_available()}') + import mmdeploy.apis.ncnn as ncnn_apis logger.info( - f'ncnn: {ncnn_version} ops_is_avaliable : {ncnn_apis.is_available()}') + f'ncnn: {ncnn_version}\tops_is_avaliable : {ncnn_apis.is_available()}') import mmdeploy.apis.pplnn as pplnn_apis logger.info(f'pplnn_is_avaliable: {pplnn_apis.is_available()}') @@ -56,45 +43,9 @@ def check_backend(): def check_codebase(): - try: - import mmcls - except ImportError: - mmcls_version = None - else: - mmcls_version = mmcls.__version__ - logger.info(f'mmcls: {mmcls_version}') - - try: - import mmdet - except ImportError: - mmdet_version = None - else: - mmdet_version = mmdet.__version__ - logger.info(f'mmdet: {mmdet_version}') - - try: - import mmedit - except ImportError: - mmedit_version = None - else: - mmedit_version = mmedit.__version__ - logger.info(f'mmedit: {mmedit_version}') - - try: - import mmocr - except ImportError: - mmocr_version = None - else: - mmocr_version = mmocr.__version__ - logger.info(f'mmocr: {mmocr_version}') - - try: - import mmseg - except ImportError: - mmseg_version = None - else: - mmseg_version = mmseg.__version__ - logger.info(f'mmseg: {mmseg_version}') + codebase_versions = get_codebase_version() + for k, v in codebase_versions.items(): + logger.info(f'{k}:\t{v}') if __name__ == '__main__':