[Enhancement] Make rewriter more powerful #150

Merged (13 commits) on Mar 1, 2022
29 changes: 29 additions & 0 deletions mmdeploy/codebase/mmdet/deploy/utils.py
@@ -5,6 +5,8 @@
import torch
from torch import Tensor

from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.core.rewriters.rewriter_utils import LibVersionChecker
from mmdeploy.utils import load_config


@@ -69,6 +71,33 @@ def clip_bboxes(x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor,
return x1, y1, x2, y2


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdeploy.codebase.mmdet.deploy.utils.clip_bboxes',
backend='tensorrt',
extra_checkers=LibVersionChecker('tensorrt', min_version='8'))
def clip_bboxes__trt8(ctx, x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor,
max_shape: Union[Tensor, Sequence[int]]):
"""Clip bboxes for onnx. From TensorRT 8 we can do the operators on the
tensors directly.

Args:
ctx (ContextCaller): The context with additional information.
x1 (Tensor): The x1 for bounding boxes.
y1 (Tensor): The y1 for bounding boxes.
x2 (Tensor): The x2 for bounding boxes.
y2 (Tensor): The y2 for bounding boxes.
max_shape (Tensor | Sequence[int]): The (H, W) of the original image.
Returns:
tuple(Tensor): The clipped x1, y1, x2, y2.
"""
assert len(max_shape) == 2, '`max_shape` should be [h, w]'
x1 = torch.clamp(x1, 0, max_shape[1])
y1 = torch.clamp(y1, 0, max_shape[0])
x2 = torch.clamp(x2, 0, max_shape[1])
y2 = torch.clamp(y2, 0, max_shape[0])
return x1, y1, x2, y2


def pad_with_value(x: Tensor,
pad_dim: int,
pad_size: int,
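For orientation, a rough sketch of how this TensorRT-8-gated rewrite is expected to be picked up at export time. It assumes RewriterContext is used as a context manager, and the box tensors are placeholders; this is illustrative only and not part of the change.

```python
import torch

from mmdeploy.codebase.mmdet.deploy.utils import clip_bboxes
from mmdeploy.core import RewriterContext

# Placeholder boxes; any float tensors of matching shape would do.
x1, y1, x2, y2 = (torch.rand(1, 100) * 1000 for _ in range(4))

# Inside the context, the registry resolves clip_bboxes to clip_bboxes__trt8
# only when the backend is TensorRT and the installed version is >= 8;
# otherwise the default implementation is kept.
with RewriterContext(cfg=dict(), backend='tensorrt'):
    x1, y1, x2, y2 = clip_bboxes(x1, y1, x2, y2, max_shape=(800, 1344))
```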
38 changes: 20 additions & 18 deletions mmdeploy/core/rewriters/function_rewriter.py
@@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Dict
from typing import Callable, Dict, List, Optional, Union

from mmdeploy.utils import Backend, get_root_logger
from .rewriter_utils import ContextCaller, RewriterRegistry, import_function
from mmdeploy.utils import IR, Backend, get_root_logger
from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry,
import_function)


def _set_func(origin_func_path: str, rewrite_func: Callable):
@@ -66,32 +67,33 @@ class FunctionRewriter:
def __init__(self):
self._registry = RewriterRegistry()

def add_backend(self, backend: str):
"""Add a backend by calling the _registry.add_backend."""
self._registry.add_backend(backend)

def register_rewriter(self,
func_name: str,
backend: str = Backend.DEFAULT.value,
**kwargs):
def register_rewriter(
self,
func_name: str,
backend: str = Backend.DEFAULT.value,
ir: IR = IR.DEFAULT,
extra_checkers: Optional[Union[Checker, List[Checker]]] = None,
**kwargs):
"""The interface of function rewriter decorator.

Args:
func_name (str): The function name/path to rewrite.
backend (str): The inference engine name.
backend (str): The backend on which the rewriter will be activated.
ir (IR): The IR on which the rewriter will be activated.
extra_checkers (Checker | List[Checker] | None): Other requirements
defined by Checker.

Returns:
Callable: The process of registering function.
"""

return self._registry.register_object(func_name, backend, **kwargs)
return self._registry.register_object(func_name, backend, ir,
extra_checkers, **kwargs)

def enter(self,
cfg: Dict = dict(),
backend: str = Backend.DEFAULT.value,
**kwargs):
def enter(self, cfg: Dict = dict(), env: Dict = dict(), **kwargs):
"""The implementation of function rewrite."""
# Get current records
functions_records = self._registry.get_records(backend)
functions_records = self._registry.get_records(env)

self._origin_functions = list()
self._additional_functions = list()
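A minimal sketch of how the extended register_rewriter decorator might be used after this change. The target path 'my_pkg.my_func', the onnxruntime version bound, and the use of ctx.origin_func are illustrative assumptions, not part of this diff.

```python
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.core.rewriters.rewriter_utils import LibVersionChecker
from mmdeploy.utils import Backend, IR


# 'my_pkg.my_func' is a hypothetical target path used only for illustration.
@FUNCTION_REWRITER.register_rewriter(
    func_name='my_pkg.my_func',
    backend=Backend.ONNXRUNTIME.value,   # activate only for this backend
    ir=IR.ONNX,                          # activate only when exporting ONNX
    extra_checkers=LibVersionChecker('onnxruntime', min_version='1.8'))
def my_func__ort(ctx, x):
    # ctx is the ContextCaller passed to the rewrite; origin_func is assumed
    # to expose the original implementation, as in other mmdeploy rewrites.
    return ctx.origin_func(x)
```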
41 changes: 25 additions & 16 deletions mmdeploy/core/rewriters/module_rewriter.py
@@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from typing import Dict, List, Optional, Union

import mmcv
from torch import nn

from mmdeploy.utils.constants import Backend
from .rewriter_utils import RewriterRegistry, eval_with_import
from mmdeploy.utils.constants import IR, Backend
from .rewriter_utils import (Checker, RewriterRegistry, collect_env,
eval_with_import)


class ModuleRewriter:
@@ -26,29 +28,33 @@ class ModuleRewriter:
def __init__(self):
self._registry = RewriterRegistry()

def add_backend(self, backend: str):
"""Add a backend by calling the _registry.add_backend."""
self._registry.add_backend(backend)

def register_rewrite_module(self,
module_type: str,
backend: str = Backend.DEFAULT.value,
**kwargs):
def register_rewrite_module(
self,
module_type: str,
backend: str = Backend.DEFAULT.value,
ir: IR = IR.DEFAULT,
extra_checkers: Optional[Union[Checker, List[Checker]]] = None,
**kwargs):
"""The interface of module rewriter decorator.

Args:
module_type (str): The module type name to rewrite.
backend (str): The inference engine name.
backend (str): The backend on which the rewriter will be activated.
ir (IR): The IR on which the rewriter will be activated.
extra_checkers (Checker | List[Checker] | None): Other requirements
defined by Checker.

Returns:
nn.Module: THe rewritten model.
nn.Module: The rewritten model.
"""
return self._registry.register_object(module_type, backend, **kwargs)
return self._registry.register_object(module_type, backend, ir,
extra_checkers, **kwargs)

def patch_model(self,
model: nn.Module,
cfg: mmcv.Config,
backend: str = Backend.DEFAULT.value,
ir: IR = IR.DEFAULT,
recursive: bool = True,
**kwargs) -> nn.Module:
"""Replace the models that was registered.
@@ -57,6 +63,7 @@ def patch_model(self,
model (torch.nn.Module): The model to patch.
cfg (Dict): Config dictionary of deployment.
backend (str): The inference engine name.
ir (IR): The intermediate representation name.
recursive (bool): The flag to enable recursive patching.

Returns:
@@ -67,7 +74,9 @@
>>> patched_model = patch_model(model, cfg=deploy_cfg,
>>> backend=backend)
"""
self._collect_record(backend)
# TODO: Change the type of the `backend` parameter to Backend
env = collect_env(Backend.get(backend), ir)
self._collect_record(env)
return self._replace_module(model, cfg, recursive, **kwargs)

def _replace_one_module(self, module, cfg, **kwargs):
@@ -103,9 +112,9 @@ def _replace_module_impl(model, cfg, **kwargs):

return _replace_module_impl(model, cfg, **kwargs)

def _collect_record(self, backend: str):
def _collect_record(self, env: Dict):
"""Collect models in registry."""
self._records = {}
records = self._registry.get_records(backend)
records = self._registry.get_records(env)
for name, kwargs in records:
self._records[eval_with_import(name)] = kwargs
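As a quick reference for reviewers, a sketch of calling patch_model with the new ir argument. The toy model and the empty deploy config are placeholders; with nothing registered for these module types, patching is a no-op and the model is returned unchanged.

```python
import mmcv
import torch

from mmdeploy.core import patch_model
from mmdeploy.utils import Backend, IR

# A toy model standing in for a real detector; no rewrite is registered for
# torch.nn.Linear, so this call simply returns the model unchanged.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
deploy_cfg = mmcv.Config(dict())  # minimal placeholder deployment config

patched_model = patch_model(
    model, cfg=deploy_cfg, backend=Backend.TENSORRT.value, ir=IR.ONNX)
```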
30 changes: 11 additions & 19 deletions mmdeploy/core/rewriters/rewriter_manager.py
@@ -4,9 +4,10 @@
import mmcv
import torch.nn as nn

from mmdeploy.utils.constants import Backend
from mmdeploy.utils.constants import IR, Backend
from .function_rewriter import FunctionRewriter
from .module_rewriter import ModuleRewriter
from .rewriter_utils import collect_env
from .symbolic_rewriter import SymbolicRewriter


@@ -18,20 +19,8 @@ def __init__(self):
self.function_rewriter = FunctionRewriter()
self.symbolic_rewriter = SymbolicRewriter()

def add_backend(self, backend: str):
"""Add backend to all rewriters.

Args:
backend (str): The backend to support.
"""
self.module_rewriter.add_backend(backend)
self.function_rewriter.add_backend(backend)
self.symbolic_rewriter.add_backend(backend)


REWRITER_MANAGER = RewriterManager()
for backend in Backend:
REWRITER_MANAGER.add_backend(backend.value)

MODULE_REWRITER = REWRITER_MANAGER.module_rewriter
FUNCTION_REWRITER = REWRITER_MANAGER.function_rewriter
@@ -41,6 +30,7 @@ def add_backend(self, backend: str):
def patch_model(model: nn.Module,
cfg: mmcv.Config,
backend: str = Backend.DEFAULT.value,
ir: IR = IR.DEFAULT,
recursive: bool = True,
**kwargs) -> nn.Module:
"""Patch the model, replace the modules that can be rewritten. Note that
@@ -50,6 +40,7 @@ def patch_model(model: nn.Module,
model (torch.nn.Module): The model to patch.
cfg (Dict): Config dictionary of deployment.
backend (str): The inference engine name.
ir (IR): The intermediate representation name.
recursive (bool): The flag to enable recursive patching.

Returns:
@@ -59,7 +50,7 @@
>>> from mmdeploy.core import patch_model
>>> patched_model = patch_model(model, cfg=deploy_cfg, backend=backend)
"""
return MODULE_REWRITER.patch_model(model, cfg, backend, recursive,
return MODULE_REWRITER.patch_model(model, cfg, backend, ir, recursive,
**kwargs)


@@ -71,6 +62,7 @@ class RewriterContext:
Args:
cfg (Dict): Config dictionary of deployment.
backend (str): The inference engine name.
ir (IR): The intermediate representation name.
rewriter_manager (RewriterManager): A RewriterManager that consists of
several rewriters

@@ -84,20 +76,20 @@
def __init__(self,
cfg: Dict = dict(),
backend: str = Backend.DEFAULT.value,
ir: IR = IR.DEFAULT,
rewriter_manager: RewriterManager = REWRITER_MANAGER,
**kwargs):
self._cfg = cfg
self._backend = backend
self._kwargs = kwargs
self._rewriter_manager = rewriter_manager
# TODO: Change the type of the `backend` parameter to Backend
self._env = collect_env(Backend.get(backend), ir)

def enter(self):
"""Call the enter() of rewriters."""
self._rewriter_manager.function_rewriter.enter(self._cfg,
self._backend,
self._rewriter_manager.function_rewriter.enter(self._cfg, self._env,
**self._kwargs)
self._rewriter_manager.symbolic_rewriter.enter(self._cfg,
self._backend,
self._rewriter_manager.symbolic_rewriter.enter(self._cfg, self._env,
**self._kwargs)

def exit(self):
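Finally, a sketch of how the updated RewriterContext might be used during export with the new ir argument. The linear model, dummy input, and output path are placeholders; the context is assumed to build its env via collect_env(Backend.get(backend), ir), as shown in the diff, and pass it to the function and symbolic rewriters.

```python
import torch

from mmdeploy.core import RewriterContext
from mmdeploy.utils import Backend, IR

model = torch.nn.Linear(4, 2)      # placeholder model
dummy_input = torch.rand(1, 4)     # placeholder input

# Rewrites registered for (onnxruntime, ONNX) are applied while tracing;
# the original functions are restored when the context exits.
with RewriterContext(
        cfg=dict(), backend=Backend.ONNXRUNTIME.value, ir=IR.ONNX):
    torch.onnx.export(model, dummy_input, 'tmp.onnx')
```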