diff --git a/docs/en/codebases/mmseg.md b/docs/en/codebases/mmseg.md index a2b6859cac..8cb30994b0 100644 --- a/docs/en/codebases/mmseg.md +++ b/docs/en/codebases/mmseg.md @@ -15,7 +15,7 @@ Please refer to [get_started.md](https://github.com/open-mmlab/mmsegmentation/bl | DeepLabV3 | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3) | | DeepLabV3+ | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3plus) | | Fast-SCNN[*](#static_shape) | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastscnn) | -| UNet[*](#static_shape) | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) | +| UNet | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) | | ANN[*](#static_shape) | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ann) | | APCNet | Y | Y | Y | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/apcnet) | | BiSeNetV1 | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv1) | diff --git a/docs/en/supported_models.md b/docs/en/supported_models.md index edf51c6591..fb6c42b254 100644 --- a/docs/en/supported_models.md +++ b/docs/en/supported_models.md @@ -29,7 +29,7 @@ The table below lists the models that are guaranteed to be exportable to other b | DeepLabV3 | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3) | | DeepLabV3+ | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3plus) | | Fast-SCNN[*static](#note) | MMSegmentation | Y | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastscnn) | -| UNet[*static](#note) | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) | +| UNet | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) | | ANN[*](#note) | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ann) | | APCNet | MMSegmentation | ? | Y | Y | Y | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/apcnet) | | BiSeNetV1 | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv1) | diff --git a/mmdeploy/codebase/mmseg/models/__init__.py b/mmdeploy/codebase/mmseg/models/__init__.py index 77b260b1c3..f8c63589a9 100644 --- a/mmdeploy/codebase/mmseg/models/__init__.py +++ b/mmdeploy/codebase/mmseg/models/__init__.py @@ -1,3 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. from .decode_heads import * # noqa: F401,F403 from .segmentors import * # noqa: F401,F403 +from .utils import * # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmseg/models/utils/__init__.py b/mmdeploy/codebase/mmseg/models/utils/__init__.py new file mode 100644 index 0000000000..954eaa3487 --- /dev/null +++ b/mmdeploy/codebase/mmseg/models/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .up_conv_block import up_conv_block__forward + +__all__ = ['up_conv_block__forward'] diff --git a/mmdeploy/codebase/mmseg/models/utils/up_conv_block.py b/mmdeploy/codebase/mmseg/models/utils/up_conv_block.py new file mode 100644 index 0000000000..2ca7592851 --- /dev/null +++ b/mmdeploy/codebase/mmseg/models/utils/up_conv_block.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import torch + +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.utils import is_dynamic_shape + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmseg.models.utils.UpConvBlock.forward') +def up_conv_block__forward(ctx, self, skip, x): + """Rewrite `forward` for default backend. + + To support dynamic shape for UNet backbone, + upsample feature maps with `size` instead of `scale_factor` + + Args: + ctx (ContextCaller): The context with additional information. + self: The instance of the original class. + skip (Tensor): Skip branch feature. + x (Tensor): Input feature to be upsampled. + + Returns: + Tensor: Upsampled output feature map. + """ + from mmcv.cnn import ConvModule + + # only valid when self.upsample is from build_upsample_layer + if is_dynamic_shape(ctx.cfg) and not isinstance(self.upsample, ConvModule): + # upsample with `size` instead of `scale_factor` + from mmseg.ops import Upsample + for c in self.upsample.interp_upsample: + if isinstance(c, Upsample): + c.size = skip.shape[-2:] + c.scale_factor = None + + x = self.upsample(x) + out = torch.cat([skip, x], dim=1) + out = self.conv_block(out) + return out diff --git a/mmdeploy/utils/test.py b/mmdeploy/utils/test.py index 5234d3e495..6f1decaa6b 100644 --- a/mmdeploy/utils/test.py +++ b/mmdeploy/utils/test.py @@ -333,9 +333,12 @@ def get_flatten_inputs( if isinstance(value, torch.Tensor): flatten_inputs[name] = value elif isinstance(value, (list, tuple)): - for i, tensor in enumerate(value): - name_i = f'{name}_{i}' - flatten_inputs[name_i] = tensor + if len(value) == 1: + flatten_inputs[name] = value[0] + else: + for i, tensor in enumerate(value): + name_i = f'{name}_{i}' + flatten_inputs[name_i] = tensor return flatten_inputs @@ -358,15 +361,29 @@ def get_onnx_model(wrapped_model: nn.Module, patched_model = patch_model( wrapped_model, cfg=deploy_cfg, backend=backend.value) flatten_model_inputs = get_flatten_inputs(model_inputs) - input_names = [k for k, v in flatten_model_inputs.items() if k != 'ctx'] + input_names = onnx_cfg.get('input_names', None) + if input_names is None: + input_names = [ + k for k, v in flatten_model_inputs.items() if k != 'ctx' + ] output_names = onnx_cfg.get('output_names', None) dynamic_axes = get_dynamic_axes(deploy_cfg, input_names) + class DummyModel(torch.nn.Module): + + def __init__(self): + super(DummyModel, self).__init__() + self.model = patched_model + + def forward(self, inputs: dict): + return self.model(**inputs) + + model = DummyModel().eval() + with RewriterContext( cfg=deploy_cfg, backend=backend.value, opset=11), torch.no_grad(): torch.onnx.export( - patched_model, - tuple([v for k, v in model_inputs.items()]), + model, (model_inputs, {}), onnx_file_path, export_params=True, input_names=input_names, @@ -421,8 +438,13 @@ def get_backend_outputs(ir_file_path: str, """ backend = get_backend(deploy_cfg) flatten_model_inputs = get_flatten_inputs(model_inputs) - input_names = [k for k, v in flatten_model_inputs.items() if k != 'ctx'] - output_names = get_ir_config(deploy_cfg).get('output_names', None) + ir_config = get_ir_config(deploy_cfg) + input_names = ir_config.get('input_names', None) + output_names = ir_config.get('output_names', None) + if input_names is None: + input_names = [ + k for k, v in flatten_model_inputs.items() if k != 'ctx' + ] # prepare backend model and input features if backend == Backend.TENSORRT: diff --git a/tests/test_codebase/test_mmseg/test_mmseg_models.py b/tests/test_codebase/test_mmseg/test_mmseg_models.py index a29abb898f..dfcd5b4cdb 100644 --- a/tests/test_codebase/test_mmseg/test_mmseg_models.py +++ b/tests/test_codebase/test_mmseg/test_mmseg_models.py @@ -9,7 +9,7 @@ from mmseg.models.decode_heads.decode_head import BaseDecodeHead from mmdeploy.codebase import import_codebase -from mmdeploy.utils import Backend, Codebase +from mmdeploy.utils import Backend, Codebase, Task from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs, get_rewrite_outputs) @@ -261,3 +261,56 @@ def test_emamodule_forward(backend): model_outputs.shape) assert torch.allclose( rewrite_outputs, model_outputs, rtol=1e-03, atol=1e-05) + + +@pytest.mark.parametrize('is_dynamic_shape', [True, False]) +@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME]) +def test_upconvblock_forward(backend, is_dynamic_shape): + check_backend(backend) + from mmseg.models.backbones.unet import BasicConvBlock + from mmseg.models.utils import UpConvBlock + + head = UpConvBlock(BasicConvBlock, 16, 8, 8).eval() + dynamic_axes = { + 'x': { + 0: 'b', + 2: 'h', + 3: 'w' + }, + 'skip': { + 0: 'b', + 2: 'h', + 3: 'w' + }, + 'output': { + 0: 'b', + 2: 'h', + 3: 'w' + }, + } if is_dynamic_shape else None + deploy_cfg = mmcv.Config( + dict( + backend_config=dict(type=backend.value), + onnx_config=dict( + input_names=['skip', 'x'], + output_names=['output'], + dynamic_axes=dynamic_axes), + codebase_config=dict( + type=Codebase.MMSEG.value, task=Task.SEGMENTATION.value))) + x = torch.randn(1, 16, 16, 16) + skip = torch.randn(1, 8, 32, 32) + model_inputs = {'x': x, 'skip': skip} + with torch.no_grad(): + model_outputs = get_model_outputs(head, 'forward', model_inputs) + + wrapped_model = WrapModel(head, 'forward') + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=model_inputs, + deploy_cfg=deploy_cfg) + if is_backend_output: + rewrite_outputs = rewrite_outputs[0] + rewrite_outputs = rewrite_outputs.to(model_outputs).reshape( + model_outputs.shape) + assert torch.allclose( + rewrite_outputs, model_outputs, rtol=1e-03, atol=1e-05)