Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Make rewriter more powerful #150

Merged
merged 13 commits into from
Mar 1, 2022
35 changes: 17 additions & 18 deletions mmdeploy/core/rewriters/function_rewriter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Dict
from typing import Callable, Dict, List, Optional, Union

from mmdeploy.utils import Backend, get_root_logger
from .rewriter_utils import ContextCaller, RewriterRegistry, import_function
from mmdeploy.utils import IR, Backend, get_root_logger
from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry,
import_function)


def _set_func(origin_func_path: str, rewrite_func: Callable):
Expand Down Expand Up @@ -66,32 +67,30 @@ class FunctionRewriter:
def __init__(self):
self._registry = RewriterRegistry()

def add_backend(self, backend: str):
"""Add a backend by calling the _registry.add_backend."""
self._registry.add_backend(backend)

def register_rewriter(self,
func_name: str,
backend: str = Backend.DEFAULT.value,
**kwargs):
def register_rewriter(
self,
func_name: str,
backend: str = Backend.DEFAULT.value,
ir: IR = IR.DEFAULT,
extra_checkers: Optional[Union[Checker, List[Checker]]] = None,
**kwargs):
"""The interface of function rewriter decorator.

Args:
func_name (str): The function name/path to rewrite.
backend (str): The inference engine name.
backend (str): The rewriter will be activated on which backend.
SingleZombie marked this conversation as resolved.
Show resolved Hide resolved

Returns:
Callable: The process of registering function.
"""

return self._registry.register_object(func_name, backend, **kwargs)
return self._registry.register_object(func_name, backend, ir,
extra_checkers, **kwargs)

def enter(self,
cfg: Dict = dict(),
backend: str = Backend.DEFAULT.value,
**kwargs):
def enter(self, cfg: Dict = dict(), env: Dict = dict(), **kwargs):
"""The implementation of function rewrite."""
# Get current records
functions_records = self._registry.get_records(backend)
functions_records = self._registry.get_records(env)

self._origin_functions = list()
self._additional_functions = list()
Expand Down
35 changes: 20 additions & 15 deletions mmdeploy/core/rewriters/module_rewriter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from typing import Dict, List, Optional, Union

import mmcv
from torch import nn

from mmdeploy.utils.constants import Backend
from .rewriter_utils import RewriterRegistry, eval_with_import
from mmdeploy.utils.constants import IR, Backend
from .rewriter_utils import (Checker, RewriterRegistry, collect_env,
eval_with_import)


class ModuleRewriter:
Expand All @@ -26,30 +28,31 @@ class ModuleRewriter:
def __init__(self):
self._registry = RewriterRegistry()

def add_backend(self, backend: str):
"""Add a backend by calling the _registry.add_backend."""
self._registry.add_backend(backend)

def register_rewrite_module(self,
module_type: str,
backend: str = Backend.DEFAULT.value,
**kwargs):
def register_rewrite_module(
self,
module_type: str,
backend: str = Backend.DEFAULT.value,
ir: IR = IR.DEFAULT,
extra_checkers: Optional[Union[Checker, List[Checker]]] = None,
**kwargs):
"""The interface of module rewriter decorator.

Args:
module_type (str): The module type name to rewrite.
backend (str): The inference engine name.

Returns:
nn.Module: THe rewritten model.
nn.Module: The rewritten model.
"""
return self._registry.register_object(module_type, backend, **kwargs)
return self._registry.register_object(module_type, backend, ir,
extra_checkers, **kwargs)

def patch_model(self,
model: nn.Module,
cfg: mmcv.Config,
backend: str = Backend.DEFAULT.value,
recursive: bool = True,
ir: IR = IR.DEFAULT,
**kwargs) -> nn.Module:
"""Replace the models that was registered.

Expand All @@ -67,7 +70,9 @@ def patch_model(self,
>>> patched_model = patch_model(model, cfg=deploy_cfg,
>>> backend=backend)
"""
self._collect_record(backend)
# TODO: Make the type of parameter backend to Backend
SingleZombie marked this conversation as resolved.
Show resolved Hide resolved
env = collect_env(Backend.get(backend), ir)
self._collect_record(env)
return self._replace_module(model, cfg, recursive, **kwargs)

def _replace_one_module(self, module, cfg, **kwargs):
Expand Down Expand Up @@ -103,9 +108,9 @@ def _replace_module_impl(model, cfg, **kwargs):

return _replace_module_impl(model, cfg, **kwargs)

def _collect_record(self, backend: str):
def _collect_record(self, env: Dict):
"""Collect models in registry."""
self._records = {}
records = self._registry.get_records(backend)
records = self._registry.get_records(env)
for name, kwargs in records:
self._records[eval_with_import(name)] = kwargs
25 changes: 7 additions & 18 deletions mmdeploy/core/rewriters/rewriter_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import mmcv
import torch.nn as nn

from mmdeploy.utils.constants import Backend
from mmdeploy.utils.constants import IR, Backend
from .function_rewriter import FunctionRewriter
from .module_rewriter import ModuleRewriter
from .rewriter_utils import collect_env
from .symbolic_rewriter import SymbolicRewriter


Expand All @@ -18,20 +19,8 @@ def __init__(self):
self.function_rewriter = FunctionRewriter()
self.symbolic_rewriter = SymbolicRewriter()

def add_backend(self, backend: str):
"""Add backend to all rewriters.

Args:
backend (str): The backend to support.
"""
self.module_rewriter.add_backend(backend)
self.function_rewriter.add_backend(backend)
self.symbolic_rewriter.add_backend(backend)


REWRITER_MANAGER = RewriterManager()
for backend in Backend:
REWRITER_MANAGER.add_backend(backend.value)

MODULE_REWRITER = REWRITER_MANAGER.module_rewriter
FUNCTION_REWRITER = REWRITER_MANAGER.function_rewriter
Expand Down Expand Up @@ -84,20 +73,20 @@ class RewriterContext:
def __init__(self,
cfg: Dict = dict(),
backend: str = Backend.DEFAULT.value,
ir: IR = IR.DEFAULT,
rewriter_manager: RewriterManager = REWRITER_MANAGER,
**kwargs):
self._cfg = cfg
self._backend = backend
self._kwargs = kwargs
self._rewriter_manager = rewriter_manager
# TODO: Make the type of parameter backend to Backend
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
self._env = collect_env(Backend.get(backend), ir)

def enter(self):
"""Call the enter() of rewriters."""
self._rewriter_manager.function_rewriter.enter(self._cfg,
self._backend,
self._rewriter_manager.function_rewriter.enter(self._cfg, self._env,
**self._kwargs)
self._rewriter_manager.symbolic_rewriter.enter(self._cfg,
self._backend,
self._rewriter_manager.symbolic_rewriter.enter(self._cfg, self._env,
**self._kwargs)

def exit(self):
Expand Down
Loading