Skip to content

Commit

Permalink
Fix bug (open-mmlab#230)
Browse files Browse the repository at this point in the history
  • Loading branch information
SingleZombie authored Nov 25, 2021
1 parent 3755966 commit 64ee8db
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
24 changes: 13 additions & 11 deletions mmdeploy/core/rewriters/rewriter_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class RewriterManager:

def __init__(self):
self.module_rewriter = ModuleRewriter()
self.function_rewrite = FunctionRewriter()
self.function_rewriter = FunctionRewriter()
self.symbolic_rewriter = SymbolicRewriter()

def add_backend(self, backend: str):
Expand All @@ -23,7 +23,7 @@ def add_backend(self, backend: str):
backend (str): The backend to support.
"""
self.module_rewriter.add_backend(backend)
self.function_rewrite.add_backend(backend)
self.function_rewriter.add_backend(backend)
self.symbolic_rewriter.add_backend(backend)


Expand All @@ -32,7 +32,7 @@ def add_backend(self, backend: str):
REWRITER_MANAGER.add_backend(backend.value)

MODULE_REWRITER = REWRITER_MANAGER.module_rewriter
FUNCTION_REWRITER = REWRITER_MANAGER.function_rewrite
FUNCTION_REWRITER = REWRITER_MANAGER.function_rewriter
SYMBOLIC_REWRITER = REWRITER_MANAGER.symbolic_rewriter


Expand Down Expand Up @@ -81,24 +81,26 @@ class RewriterContext:
def __init__(self,
cfg: Dict = dict(),
backend: str = Backend.DEFAULT.value,
rewrite_manager: RewriterManager = REWRITER_MANAGER,
rewriter_manager: RewriterManager = REWRITER_MANAGER,
**kwargs):
self._cfg = cfg
self._backend = backend
self._kwargs = kwargs
self._rewrite_manager = rewrite_manager
self._rewriter_manager = rewriter_manager

def enter(self):
"""Call the enter() of rewriters."""
self._rewrite_manager.function_rewrite.enter(self._cfg, self._backend,
**self._kwargs)
self._rewrite_manager.symbolic_rewriter.enter(self._cfg, self._backend,
**self._kwargs)
self._rewriter_manager.function_rewriter.enter(self._cfg,
self._backend,
**self._kwargs)
self._rewriter_manager.symbolic_rewriter.enter(self._cfg,
self._backend,
**self._kwargs)

def exit(self):
"""Call the exit() of rewriters."""
self._rewrite_manager.function_rewrite.exit()
self._rewrite_manager.symbolic_rewriter.exit()
self._rewriter_manager.function_rewriter.exit()
self._rewriter_manager.symbolic_rewriter.exit()

def __enter__(self):
"""Call enter()"""
Expand Down
2 changes: 2 additions & 0 deletions mmdeploy/utils/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import torch
from torch import nn

# Register the rewrite functions
import mmdeploy.codebase # noqa: F401,F403
from mmdeploy.core import RewriterContext, patch_model
from mmdeploy.utils import Backend, get_backend, get_onnx_config

Expand Down

0 comments on commit 64ee8db

Please sign in to comment.