Skip to content

Commit

Permalink
[Enhancement] Make rewriter more powerful (#150)
Browse files Browse the repository at this point in the history
* Finish function tests

* lint

* resolve comments

* Fix tests

* docstring & fix

* Complement informations

* lint

* Add example

* Fix version

* Remove todo

Co-authored-by: RunningLeon <mnsheng@yeah.net>
  • Loading branch information
2 people authored and lvhan028 committed Mar 28, 2022
1 parent 591155c commit 636a97f
Show file tree
Hide file tree
Showing 14 changed files with 560 additions and 249 deletions.
29 changes: 29 additions & 0 deletions mmdeploy/codebase/mmdet/deploy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import torch
from torch import Tensor

from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.core.rewriters.rewriter_utils import LibVersionChecker
from mmdeploy.utils import load_config


Expand Down Expand Up @@ -69,6 +71,33 @@ def clip_bboxes(x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor,
return x1, y1, x2, y2


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdeploy.codebase.mmdet.deploy.utils.clip_bboxes',
backend='tensorrt',
extra_checkers=LibVersionChecker('tensorrt', min_version='8'))
def clip_bboxes__trt8(ctx, x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor,
max_shape: Union[Tensor, Sequence[int]]):
"""Clip bboxes for onnx. From TensorRT 8 we can do the operators on the
tensors directly.
Args:
ctx (ContextCaller): The context with additional information.
x1 (Tensor): The x1 for bounding boxes.
y1 (Tensor): The y1 for bounding boxes.
x2 (Tensor): The x2 for bounding boxes.
y2 (Tensor): The y2 for bounding boxes.
max_shape (Tensor | Sequence[int]): The (H,W) of original image.
Returns:
tuple(Tensor): The clipped x1, y1, x2, y2.
"""
assert len(max_shape) == 2, '`max_shape` should be [h, w]'
x1 = torch.clamp(x1, 0, max_shape[1])
y1 = torch.clamp(y1, 0, max_shape[0])
x2 = torch.clamp(x2, 0, max_shape[1])
y2 = torch.clamp(y2, 0, max_shape[0])
return x1, y1, x2, y2


def pad_with_value(x: Tensor,
pad_dim: int,
pad_size: int,
Expand Down
38 changes: 20 additions & 18 deletions mmdeploy/core/rewriters/function_rewriter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Dict
from typing import Callable, Dict, List, Optional, Union

from mmdeploy.utils import Backend, get_root_logger
from .rewriter_utils import ContextCaller, RewriterRegistry, import_function
from mmdeploy.utils import IR, Backend, get_root_logger
from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry,
import_function)


def _set_func(origin_func_path: str, rewrite_func: Callable):
Expand Down Expand Up @@ -66,32 +67,33 @@ class FunctionRewriter:
def __init__(self):
self._registry = RewriterRegistry()

def add_backend(self, backend: str):
"""Add a backend by calling the _registry.add_backend."""
self._registry.add_backend(backend)

def register_rewriter(self,
func_name: str,
backend: str = Backend.DEFAULT.value,
**kwargs):
def register_rewriter(
self,
func_name: str,
backend: str = Backend.DEFAULT.value,
ir: IR = IR.DEFAULT,
extra_checkers: Optional[Union[Checker, List[Checker]]] = None,
**kwargs):
"""The interface of function rewriter decorator.
Args:
func_name (str): The function name/path to rewrite.
backend (str): The inference engine name.
backend (str): The rewriter will be activated on which backend.
ir (IR): The rewriter will be activated on which IR.
extra_checkers (Checker | List[Checker] | None): Other requirements
defined by Checker.
Returns:
Callable: The process of registering function.
"""

return self._registry.register_object(func_name, backend, **kwargs)
return self._registry.register_object(func_name, backend, ir,
extra_checkers, **kwargs)

def enter(self,
cfg: Dict = dict(),
backend: str = Backend.DEFAULT.value,
**kwargs):
def enter(self, cfg: Dict = dict(), env: Dict = dict(), **kwargs):
"""The implementation of function rewrite."""
# Get current records
functions_records = self._registry.get_records(backend)
functions_records = self._registry.get_records(env)

self._origin_functions = list()
self._additional_functions = list()
Expand Down
41 changes: 25 additions & 16 deletions mmdeploy/core/rewriters/module_rewriter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from typing import Dict, List, Optional, Union

import mmcv
from torch import nn

from mmdeploy.utils.constants import Backend
from .rewriter_utils import RewriterRegistry, eval_with_import
from mmdeploy.utils.constants import IR, Backend
from .rewriter_utils import (Checker, RewriterRegistry, collect_env,
eval_with_import)


class ModuleRewriter:
Expand All @@ -26,29 +28,33 @@ class ModuleRewriter:
def __init__(self):
self._registry = RewriterRegistry()

def add_backend(self, backend: str):
"""Add a backend by calling the _registry.add_backend."""
self._registry.add_backend(backend)

def register_rewrite_module(self,
module_type: str,
backend: str = Backend.DEFAULT.value,
**kwargs):
def register_rewrite_module(
self,
module_type: str,
backend: str = Backend.DEFAULT.value,
ir: IR = IR.DEFAULT,
extra_checkers: Optional[Union[Checker, List[Checker]]] = None,
**kwargs):
"""The interface of module rewriter decorator.
Args:
module_type (str): The module type name to rewrite.
backend (str): The inference engine name.
backend (str): The rewriter will be activated on which backend.
ir (IR): The rewriter will be activated on which IR.
extra_checkers (Checker | List[Checker] | None): Other requirements
defined by Checker.
Returns:
nn.Module: THe rewritten model.
nn.Module: The rewritten model.
"""
return self._registry.register_object(module_type, backend, **kwargs)
return self._registry.register_object(module_type, backend, ir,
extra_checkers, **kwargs)

def patch_model(self,
model: nn.Module,
cfg: mmcv.Config,
backend: str = Backend.DEFAULT.value,
ir: IR = IR.DEFAULT,
recursive: bool = True,
**kwargs) -> nn.Module:
"""Replace the models that was registered.
Expand All @@ -57,6 +63,7 @@ def patch_model(self,
model (torch.nn.Module): The model to patch.
cfg (Dict): Config dictionary of deployment.
backend (str): The inference engine name.
ir (IR): The intermeditate representation name.
recursive (bool): The flag to enable recursive patching.
Returns:
Expand All @@ -67,7 +74,9 @@ def patch_model(self,
>>> patched_model = patch_model(model, cfg=deploy_cfg,
>>> backend=backend)
"""
self._collect_record(backend)
# TODO: Make the type of parameter backend to Backend
env = collect_env(Backend.get(backend), ir)
self._collect_record(env)
return self._replace_module(model, cfg, recursive, **kwargs)

def _replace_one_module(self, module, cfg, **kwargs):
Expand Down Expand Up @@ -103,9 +112,9 @@ def _replace_module_impl(model, cfg, **kwargs):

return _replace_module_impl(model, cfg, **kwargs)

def _collect_record(self, backend: str):
def _collect_record(self, env: Dict):
"""Collect models in registry."""
self._records = {}
records = self._registry.get_records(backend)
records = self._registry.get_records(env)
for name, kwargs in records:
self._records[eval_with_import(name)] = kwargs
29 changes: 10 additions & 19 deletions mmdeploy/core/rewriters/rewriter_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import mmcv
import torch.nn as nn

from mmdeploy.utils.constants import Backend
from mmdeploy.utils.constants import IR, Backend
from .function_rewriter import FunctionRewriter
from .module_rewriter import ModuleRewriter
from .rewriter_utils import collect_env
from .symbolic_rewriter import SymbolicRewriter


Expand All @@ -18,20 +19,8 @@ def __init__(self):
self.function_rewriter = FunctionRewriter()
self.symbolic_rewriter = SymbolicRewriter()

def add_backend(self, backend: str):
"""Add backend to all rewriters.
Args:
backend (str): The backend to support.
"""
self.module_rewriter.add_backend(backend)
self.function_rewriter.add_backend(backend)
self.symbolic_rewriter.add_backend(backend)


REWRITER_MANAGER = RewriterManager()
for backend in Backend:
REWRITER_MANAGER.add_backend(backend.value)

MODULE_REWRITER = REWRITER_MANAGER.module_rewriter
FUNCTION_REWRITER = REWRITER_MANAGER.function_rewriter
Expand All @@ -41,6 +30,7 @@ def add_backend(self, backend: str):
def patch_model(model: nn.Module,
cfg: mmcv.Config,
backend: str = Backend.DEFAULT.value,
ir: IR = IR.DEFAULT,
recursive: bool = True,
**kwargs) -> nn.Module:
"""Patch the model, replace the modules that can be rewritten. Note that
Expand All @@ -50,6 +40,7 @@ def patch_model(model: nn.Module,
model (torch.nn.Module): The model to patch.
cfg (Dict): Config dictionary of deployment.
backend (str): The inference engine name.
ir (IR): The intermeditate representation name.
recursive (bool): The flag to enable recursive patching.
Returns:
Expand All @@ -59,7 +50,7 @@ def patch_model(model: nn.Module,
>>> from mmdeploy.core import patch_model
>>> patched_model = patch_model(model, cfg=deploy_cfg, backend=backend)
"""
return MODULE_REWRITER.patch_model(model, cfg, backend, recursive,
return MODULE_REWRITER.patch_model(model, cfg, backend, ir, recursive,
**kwargs)


Expand All @@ -71,6 +62,7 @@ class RewriterContext:
Args:
cfg (Dict): Config dictionary of deployment.
backend (str): The inference engine name.
ir (IR): The intermeditate representation name.
rewrite_manager (RewriterManager): An RewriteManager that consists of
several rewriters
Expand All @@ -84,20 +76,19 @@ class RewriterContext:
def __init__(self,
cfg: Dict = dict(),
backend: str = Backend.DEFAULT.value,
ir: IR = IR.DEFAULT,
rewriter_manager: RewriterManager = REWRITER_MANAGER,
**kwargs):
self._cfg = cfg
self._backend = backend
self._kwargs = kwargs
self._rewriter_manager = rewriter_manager
self._env = collect_env(Backend.get(backend), ir)

def enter(self):
"""Call the enter() of rewriters."""
self._rewriter_manager.function_rewriter.enter(self._cfg,
self._backend,
self._rewriter_manager.function_rewriter.enter(self._cfg, self._env,
**self._kwargs)
self._rewriter_manager.symbolic_rewriter.enter(self._cfg,
self._backend,
self._rewriter_manager.symbolic_rewriter.enter(self._cfg, self._env,
**self._kwargs)

def exit(self):
Expand Down
Loading

0 comments on commit 636a97f

Please sign in to comment.