Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement]: Support fcn_unet deployment with dynamic shape #251

Merged
merged 5 commits into from
Mar 24, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/en/codebases/mmseg.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Please refer to [get_started.md](https://github.com/open-mmlab/mmsegmentation/bl
| DeepLabV3 | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3) |
| DeepLabV3+ | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3plus) |
| Fast-SCNN[*](#static_shape) | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastscnn) |
| UNet[*](#static_shape) | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) |
| UNet | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) |
| ANN[*](#static_shape) | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ann) |
| APCNet | Y | Y | Y | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/apcnet) |
| BiSeNetV1 | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv1) |
Expand Down
2 changes: 1 addition & 1 deletion docs/en/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ The table below lists the models that are guaranteed to be exportable to other b
| DeepLabV3 | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3) |
| DeepLabV3+ | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3plus) |
| Fast-SCNN[*static](#note) | MMSegmentation | Y | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastscnn) |
| UNet[*static](#note) | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) |
| UNet | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) |
| ANN[*](#note) | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ann) |
| APCNet | MMSegmentation | ? | Y | Y | Y | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/apcnet) |
| BiSeNetV1 | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv1) |
Expand Down
1 change: 1 addition & 0 deletions mmdeploy/codebase/mmseg/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .decode_heads import * # noqa: F401,F403
from .segmentors import * # noqa: F401,F403
from .utils import * # noqa: F401,F403
4 changes: 4 additions & 0 deletions mmdeploy/codebase/mmseg/models/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .up_conv_block import up_conv_block__forward

__all__ = ['up_conv_block__forward']
37 changes: 37 additions & 0 deletions mmdeploy/codebase/mmseg/models/utils/up_conv_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch

from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import is_dynamic_shape


@FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.utils.UpConvBlock.forward')
def up_conv_block__forward(ctx, self, skip, x):
"""Rewrite `forward` for default backend.

To support dynamic shape for UNet backbone,
upsample feature maps with `size` instead of `scale_factor`

Args:
ctx (ContextCaller): The context with additional information.
self: The instance of the original class.
skip (Tensor): Skip branch feature.
x (Tensor): Input feature to be upsampled.

Returns:
Tensor: Upsampled output feature map.
"""
if is_dynamic_shape(ctx.cfg):
# upsample with `size` instead of `scale_factor`
from mmseg.ops import Upsample
for c in self.upsample.interp_upsample:
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(c, Upsample):
c.size = skip.shape[-2:]
c.scale_factor = None

x = self.upsample(x)
out = torch.cat([skip, x], dim=1)
out = self.conv_block(out)
return out
38 changes: 30 additions & 8 deletions mmdeploy/utils/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,9 +333,12 @@ def get_flatten_inputs(
if isinstance(value, torch.Tensor):
flatten_inputs[name] = value
elif isinstance(value, (list, tuple)):
for i, tensor in enumerate(value):
name_i = f'{name}_{i}'
flatten_inputs[name_i] = tensor
if len(value) == 1:
flatten_inputs[name] = value[0]
else:
for i, tensor in enumerate(value):
name_i = f'{name}_{i}'
flatten_inputs[name_i] = tensor
return flatten_inputs


Expand All @@ -358,15 +361,29 @@ def get_onnx_model(wrapped_model: nn.Module,
patched_model = patch_model(
wrapped_model, cfg=deploy_cfg, backend=backend.value)
flatten_model_inputs = get_flatten_inputs(model_inputs)
input_names = [k for k, v in flatten_model_inputs.items() if k != 'ctx']
input_names = onnx_cfg.get('input_names', None)
if input_names is None:
input_names = [
k for k, v in flatten_model_inputs.items() if k != 'ctx'
]
output_names = onnx_cfg.get('output_names', None)
dynamic_axes = get_dynamic_axes(deploy_cfg, input_names)

class DummyModel(torch.nn.Module):

def __init__(self):
super(DummyModel, self).__init__()
self.model = patched_model

def forward(self, inputs: dict):
return self.model(**inputs)

model = DummyModel().eval()

with RewriterContext(
cfg=deploy_cfg, backend=backend.value, opset=11), torch.no_grad():
torch.onnx.export(
patched_model,
tuple([v for k, v in model_inputs.items()]),
model, (model_inputs, {}),
onnx_file_path,
export_params=True,
input_names=input_names,
Expand Down Expand Up @@ -421,8 +438,13 @@ def get_backend_outputs(ir_file_path: str,
"""
backend = get_backend(deploy_cfg)
flatten_model_inputs = get_flatten_inputs(model_inputs)
input_names = [k for k, v in flatten_model_inputs.items() if k != 'ctx']
output_names = get_ir_config(deploy_cfg).get('output_names', None)
ir_config = get_ir_config(deploy_cfg)
input_names = ir_config.get('input_names', None)
output_names = ir_config.get('output_names', None)
if input_names is None:
input_names = [
k for k, v in flatten_model_inputs.items() if k != 'ctx'
]

# prepare backend model and input features
if backend == Backend.TENSORRT:
Expand Down
55 changes: 54 additions & 1 deletion tests/test_codebase/test_mmseg/test_mmseg_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from mmseg.models.decode_heads.decode_head import BaseDecodeHead

from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Backend, Codebase
from mmdeploy.utils import Backend, Codebase, Task
from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
get_rewrite_outputs)

Expand Down Expand Up @@ -261,3 +261,56 @@ def test_emamodule_forward(backend):
model_outputs.shape)
assert torch.allclose(
rewrite_outputs, model_outputs, rtol=1e-03, atol=1e-05)


@pytest.mark.parametrize('is_dynamic_shape', [True, False])
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
def test_upconvblock_forward(backend, is_dynamic_shape):
check_backend(backend)
from mmseg.models.backbones.unet import BasicConvBlock
from mmseg.models.utils import UpConvBlock

head = UpConvBlock(BasicConvBlock, 16, 8, 8).eval()
dynamic_axes = {
'x': {
0: 'b',
2: 'h',
3: 'w'
},
'skip': {
0: 'b',
2: 'h',
3: 'w'
},
'output': {
0: 'b',
2: 'h',
3: 'w'
},
} if is_dynamic_shape else None
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend.value),
onnx_config=dict(
input_names=['skip', 'x'],
output_names=['output'],
dynamic_axes=dynamic_axes),
codebase_config=dict(
type=Codebase.MMSEG.value, task=Task.SEGMENTATION.value)))
x = torch.randn(1, 16, 16, 16)
skip = torch.randn(1, 8, 32, 32)
model_inputs = {'x': x, 'skip': skip}
with torch.no_grad():
model_outputs = get_model_outputs(head, 'forward', model_inputs)

wrapped_model = WrapModel(head, 'forward')
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=model_inputs,
deploy_cfg=deploy_cfg)
if is_backend_output:
rewrite_outputs = rewrite_outputs[0]
rewrite_outputs = rewrite_outputs.to(model_outputs).reshape(
model_outputs.shape)
assert torch.allclose(
rewrite_outputs, model_outputs, rtol=1e-03, atol=1e-05)