Skip to content

Commit

Permalink
[Docs]: Add how_to_support_new_model document (open-mmlab#123)
Browse files Browse the repository at this point in the history
* Add how_to_support_new_model document

* update follow comments

* fix for comment

* rename
  • Loading branch information
grimoire authored Oct 18, 2021
1 parent 26cf2bd commit fef9d9e
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ You can switch between Chinese and English documents in the lower-left corner of
tutorials/how_to_convert_model.md
tutorials/how_to_create_config.md
tutorials/how_to_test_model.md
tutorials/how_to_support_new_model.md
tutorials/how_to_support_new_models.md

.. toctree::
:maxdepth: 1
Expand Down
1 change: 0 additions & 1 deletion docs/tutorials/how_to_support_new_model.md

This file was deleted.

93 changes: 93 additions & 0 deletions docs/tutorials/how_to_support_new_models.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
## How to support new models

We provide several tools to support model conversion.

### Function Rewriter

The PyTorch neural network is written in python that eases the development of the algorithm. But the use of Python control flow and third-party libraries make it difficult to export the network to an intermediate representation. We provide a 'monkey patch' tool to rewrite the unsupported function to another one that can be exported. Here is an example:

```python
from mmdeploy.core import FUNCTION_REWRITER

@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.repeat', backend='tensorrt')
def repeat_static(ctx, input, *size):
origin_func = ctx.origin_func
if input.dim() == 1 and len(size) == 1:
return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
else:
return origin_func(input, *size)
```

It is easy to use the function rewriter. Just add a decorator with arguments:

- `func_name` is the function to override. It can be either a PyTorch function or a custom function. Methods in modules can also be overridden by this tool.
- `backend` is the inference engine. The function will be overridden when the model is exported to this engine. If it is not given, this rewrite will be the default rewrite. The default rewrite will be used if the rewrite of the given backend does not exist.

The arguments are the same as the original function, except a context `ctx` as the first argument. The context provides some useful information such as the deployment config `ctx.cfg` and the original function (which has been overridden) `ctx.origin_func`.

### Module Rewriter

If you want to replace a whole module with another one, we have another rewriter as follows:

```python
@MODULE_REWRITER.register_rewrite_module(
'mmedit.models.backbones.sr_backbones.SRCNN', backend='tensorrt')
class SRCNNWrapper(nn.Module):

def __init__(self,
module,
cfg,
channels=(3, 64, 32, 3),
kernel_sizes=(9, 1, 5),
upscale_factor=4):
super(SRCNNWrapper, self).__init__()

self._module = module

module.img_upsampler = nn.Upsample(
scale_factor=module.upscale_factor,
mode='bilinear',
align_corners=False)

def forward(self, *args, **kwargs):
"""Run forward."""
return self._module(*args, **kwargs)

def init_weights(self, *args, **kwargs):
"""Initialize weights."""
return self._module.init_weights(*args, **kwargs)
```

Just like function rewriter, add a decorator with arguments:

- `module_type` the module class to rewrite.
- `backend` is the inference engine. The function will be overridden when the model is exported to this engine. If it is not given, this rewrite will be the default rewrite. The default rewrite will be used if the rewrite of the given backend does not exist.

All instances of the module in the network will be replaced with instances of this new class. The original module and the deployment config will be passed as the first two arguments.

### Custom Symbolic

The mappings between PyTorch and ONNX are defined in PyTorch with symbolic functions. The custom symbolic function can help us to bypass some ONNX nodes which are unsupported by inference engine.

```python
@SYMBOLIC_REGISTER.register_symbolic('squeeze', is_pytorch=True)
def squeeze_default(ctx, g, self, dim=None):
if dim is None:
dims = []
for i, size in enumerate(self.type().sizes()):
if size == 1:
dims.append(i)
else:
dims = [sym_help._get_const(dim, 'i', 'dim')]
return g.op('Squeeze', self, axes_i=dims)
```

The decorator arguments:

- `func_name` The function name to add symbolic. Use full path if it is a custom `torch.autograd.Function`. Or just a name if it is a PyTorch built-in function.
- `backend` is the inference engine. The function will be overridden when the model is exported to this engine. If it is not given, this rewrite will be the default rewrite. The default rewrite will be used if the rewrite of the given backend does not exist.
- `is_pytorch` True if the function is a PyTorch built-in function.
- `arg_descriptors` the descriptors of the symbolic function arguments. Will be feed to `torch.onnx.symbolic_helper._parse_arg`.

Just like function rewriter, there is a context `ctx` as the first argument. The context provides some useful information such as the deployment config `ctx.cfg` and the original function (which has been overridden) `ctx.origin_func`. Note that the `ctx.origin_func` can be used only when `is_pytorch==False`.
2 changes: 1 addition & 1 deletion docs_zh_CN/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
tutorials/how_to_convert_model.md
tutorials/how_to_create_config.md
tutorials/how_to_test_model.md
tutorials/how_to_support_new_model.md
tutorials/how_to_support_new_models.md

.. toctree::
:maxdepth: 1
Expand Down

0 comments on commit fef9d9e

Please sign in to comment.