
[Enhancement]: Support fcn_unet deployment with dynamic shape (#251)
* support mmseg fcn+unet dynamic shape

* add test

* fix ci

* fix units

* resolve comments
RunningLeon authored Mar 24, 2022
1 parent d9eeaba commit ed2ec9d
Showing 7 changed files with 131 additions and 11 deletions.
2 changes: 1 addition & 1 deletion docs/en/codebases/mmseg.md
@@ -15,7 +15,7 @@ Please refer to [get_started.md](https://github.com/open-mmlab/mmsegmentation/bl
| DeepLabV3 | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3) |
| DeepLabV3+ | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3plus) |
| Fast-SCNN[*](#static_shape) | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastscnn) |
-| UNet[*](#static_shape) | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) |
+| UNet | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) |
| ANN[*](#static_shape) | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ann) |
| APCNet | Y | Y | Y | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/apcnet) |
| BiSeNetV1 | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv1) |
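With the static-shape footnote removed, UNet can now be converted with a dynamic-shape deploy config. As a rough illustration only (nothing below is taken from this commit), the ONNX section of such a config might declare dynamic axes along these lines; the tensor names 'input' and 'output' are assumptions, not values from the diff:

onnx_config = dict(
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch', 2: 'height', 3: 'width'},
        'output': {0: 'batch', 2: 'height', 3: 'width'},
    })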
2 changes: 1 addition & 1 deletion docs/en/supported_models.md
@@ -29,7 +29,7 @@ The table below lists the models that are guaranteed to be exportable to other b
| DeepLabV3 | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3) |
| DeepLabV3+ | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3plus) |
| Fast-SCNN[*static](#note) | MMSegmentation | Y | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastscnn) |
-| UNet[*static](#note) | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) |
+| UNet | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) |
| ANN[*](#note) | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ann) |
| APCNet | MMSegmentation | ? | Y | Y | Y | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/apcnet) |
| BiSeNetV1 | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv1) |
1 change: 1 addition & 0 deletions mmdeploy/codebase/mmseg/models/__init__.py
@@ -1,3 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .decode_heads import * # noqa: F401,F403
from .segmentors import * # noqa: F401,F403
from .utils import * # noqa: F401,F403
4 changes: 4 additions & 0 deletions mmdeploy/codebase/mmseg/models/utils/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .up_conv_block import up_conv_block__forward

__all__ = ['up_conv_block__forward']
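For context (not part of the diff): FUNCTION_REWRITER.register_rewriter only takes effect once the module defining the rewriter has been imported, so re-exporting the new utils package from the models package is what makes the UpConvBlock rewrite visible to the exporter. A hedged sketch of how that import is normally triggered; the standalone call below is illustrative rather than code from this commit:

from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Codebase

# Importing the mmseg codebase pulls in mmdeploy.codebase.mmseg.models.utils,
# which registers up_conv_block__forward as the UpConvBlock.forward rewrite.
import_codebase(Codebase.MMSEG)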
40 changes: 40 additions & 0 deletions mmdeploy/codebase/mmseg/models/utils/up_conv_block.py
@@ -0,0 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch

from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import is_dynamic_shape


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmseg.models.utils.UpConvBlock.forward')
def up_conv_block__forward(ctx, self, skip, x):
    """Rewrite `forward` for default backend.

    To support dynamic shape for UNet backbone,
    upsample feature maps with `size` instead of `scale_factor`

    Args:
        ctx (ContextCaller): The context with additional information.
        self: The instance of the original class.
        skip (Tensor): Skip branch feature.
        x (Tensor): Input feature to be upsampled.

    Returns:
        Tensor: Upsampled output feature map.
    """
    from mmcv.cnn import ConvModule

    # only valid when self.upsample is from build_upsample_layer
    if is_dynamic_shape(ctx.cfg) and not isinstance(self.upsample, ConvModule):
        # upsample with `size` instead of `scale_factor`
        from mmseg.ops import Upsample
        for c in self.upsample.interp_upsample:
            if isinstance(c, Upsample):
                c.size = skip.shape[-2:]
                c.scale_factor = None

    x = self.upsample(x)
    out = torch.cat([skip, x], dim=1)
    out = self.conv_block(out)
    return out
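The key idea of the rewrite, sketched below in plain PyTorch (an illustration under assumptions, not repository code): resizing with a fixed scale_factor bakes a constant ratio into the exported graph, whereas taking size from the skip tensor keeps the subsequent concatenation valid for whatever resolution arrives at runtime.

import torch
import torch.nn.functional as F


def upsample_like(x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
    # Resize `x` to the spatial size of `skip`, so the channel-wise concat
    # below works for any input resolution (the dynamic-shape case).
    return F.interpolate(
        x, size=skip.shape[-2:], mode='bilinear', align_corners=False)


x = torch.randn(1, 16, 16, 16)
skip = torch.randn(1, 8, 32, 32)
out = torch.cat([skip, upsample_like(x, skip)], dim=1)  # shape (1, 24, 32, 32)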
38 changes: 30 additions & 8 deletions mmdeploy/utils/test.py
@@ -333,9 +333,12 @@ def get_flatten_inputs(
        if isinstance(value, torch.Tensor):
            flatten_inputs[name] = value
        elif isinstance(value, (list, tuple)):
-            for i, tensor in enumerate(value):
-                name_i = f'{name}_{i}'
-                flatten_inputs[name_i] = tensor
+            if len(value) == 1:
+                flatten_inputs[name] = value[0]
+            else:
+                for i, tensor in enumerate(value):
+                    name_i = f'{name}_{i}'
+                    flatten_inputs[name_i] = tensor
    return flatten_inputs
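The new branch keeps the original key when a list or tuple input holds exactly one tensor, so the flattened name matches the declared ONNX input name instead of gaining an '_0' suffix. A hedged illustration of the intended behaviour (not code from the diff):

import torch

model_inputs = {
    'img': [torch.zeros(1, 3, 32, 32)],               # single-element list
    'feats': (torch.zeros(1, 8), torch.zeros(1, 8)),  # multi-element tuple
}
# With the change above, get_flatten_inputs(model_inputs) is expected to
# produce the keys 'img', 'feats_0' and 'feats_1'; previously the
# single-element case would have produced 'img_0'.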


@@ -358,15 +361,29 @@ def get_onnx_model(wrapped_model: nn.Module,
    patched_model = patch_model(
        wrapped_model, cfg=deploy_cfg, backend=backend.value)
    flatten_model_inputs = get_flatten_inputs(model_inputs)
-    input_names = [k for k, v in flatten_model_inputs.items() if k != 'ctx']
+    input_names = onnx_cfg.get('input_names', None)
+    if input_names is None:
+        input_names = [
+            k for k, v in flatten_model_inputs.items() if k != 'ctx'
+        ]
    output_names = onnx_cfg.get('output_names', None)
    dynamic_axes = get_dynamic_axes(deploy_cfg, input_names)

+    class DummyModel(torch.nn.Module):
+
+        def __init__(self):
+            super(DummyModel, self).__init__()
+            self.model = patched_model
+
+        def forward(self, inputs: dict):
+            return self.model(**inputs)
+
+    model = DummyModel().eval()
+
    with RewriterContext(
            cfg=deploy_cfg, backend=backend.value, opset=11), torch.no_grad():
        torch.onnx.export(
-            patched_model,
-            tuple([v for k, v in model_inputs.items()]),
+            model, (model_inputs, {}),
            onnx_file_path,
            export_params=True,
            input_names=input_names,
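Wrapping the patched model in DummyModel lets torch.onnx.export receive the whole input dict as a single positional argument: passing (model_inputs, {}) with a trailing empty dict is the convention that tells the exporter the preceding dict is an argument rather than keyword options, so the exported input names line up with the dict keys. A standalone sketch of the same pattern (assumed example, not repository code):

import torch


class DictWrapper(torch.nn.Module):

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, inputs: dict):
        return self.model(**inputs)


net = DictWrapper(torch.nn.Linear(4, 2)).eval()
# torch.onnx.export(net, ({'input': torch.randn(1, 4)}, {}), 'net.onnx',
#                   input_names=['input'], output_names=['output'])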
@@ -421,8 +438,13 @@ def get_backend_outputs(ir_file_path: str,
    """
    backend = get_backend(deploy_cfg)
    flatten_model_inputs = get_flatten_inputs(model_inputs)
-    input_names = [k for k, v in flatten_model_inputs.items() if k != 'ctx']
-    output_names = get_ir_config(deploy_cfg).get('output_names', None)
+    ir_config = get_ir_config(deploy_cfg)
+    input_names = ir_config.get('input_names', None)
+    output_names = ir_config.get('output_names', None)
+    if input_names is None:
+        input_names = [
+            k for k, v in flatten_model_inputs.items() if k != 'ctx'
+        ]

    # prepare backend model and input features
    if backend == Backend.TENSORRT:
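Both helpers now prefer input_names declared in the IR/ONNX config and fall back to deriving names from the flattened inputs only when none are given, which lets a test pin the exported tensor names. A hedged fragment of the kind of config that exercises the new path (it mirrors the test below rather than adding new repository code):

onnx_config = dict(
    input_names=['skip', 'x'],
    output_names=['output'])
# With explicit names like these, get_onnx_model and get_backend_outputs no
# longer fall back to the auto-generated names from get_flatten_inputs.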
55 changes: 54 additions & 1 deletion tests/test_codebase/test_mmseg/test_mmseg_models.py
@@ -9,7 +9,7 @@
from mmseg.models.decode_heads.decode_head import BaseDecodeHead

from mmdeploy.codebase import import_codebase
-from mmdeploy.utils import Backend, Codebase
+from mmdeploy.utils import Backend, Codebase, Task
from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
                                 get_rewrite_outputs)

@@ -261,3 +261,56 @@ def test_emamodule_forward(backend):
            model_outputs.shape)
    assert torch.allclose(
        rewrite_outputs, model_outputs, rtol=1e-03, atol=1e-05)


@pytest.mark.parametrize('is_dynamic_shape', [True, False])
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
def test_upconvblock_forward(backend, is_dynamic_shape):
    check_backend(backend)
    from mmseg.models.backbones.unet import BasicConvBlock
    from mmseg.models.utils import UpConvBlock

    head = UpConvBlock(BasicConvBlock, 16, 8, 8).eval()
    dynamic_axes = {
        'x': {
            0: 'b',
            2: 'h',
            3: 'w'
        },
        'skip': {
            0: 'b',
            2: 'h',
            3: 'w'
        },
        'output': {
            0: 'b',
            2: 'h',
            3: 'w'
        },
    } if is_dynamic_shape else None
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend.value),
            onnx_config=dict(
                input_names=['skip', 'x'],
                output_names=['output'],
                dynamic_axes=dynamic_axes),
            codebase_config=dict(
                type=Codebase.MMSEG.value, task=Task.SEGMENTATION.value)))
    x = torch.randn(1, 16, 16, 16)
    skip = torch.randn(1, 8, 32, 32)
    model_inputs = {'x': x, 'skip': skip}
    with torch.no_grad():
        model_outputs = get_model_outputs(head, 'forward', model_inputs)

    wrapped_model = WrapModel(head, 'forward')
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=model_inputs,
        deploy_cfg=deploy_cfg)
    if is_backend_output:
        rewrite_outputs = rewrite_outputs[0]
    rewrite_outputs = rewrite_outputs.to(model_outputs).reshape(
        model_outputs.shape)
    assert torch.allclose(
        rewrite_outputs, model_outputs, rtol=1e-03, atol=1e-05)
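The new case exercises the block under both a static and a dynamic deploy config via the is_dynamic_shape parametrisation, with x of shape (1, 16, 16, 16) upsampled to match the (1, 8, 32, 32) skip feature before concatenation. Assuming a checkout with the test requirements and ONNX Runtime installed, it can be selected on its own with pytest tests/test_codebase/test_mmseg/test_mmseg_models.py -k test_upconvblock_forward.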

