diff --git a/configs/mmseg/segmentation_pplnn_static-512x1024.py b/configs/mmseg/segmentation_pplnn_static-512x1024.py new file mode 100644 index 000000000..f1b006fdb --- /dev/null +++ b/configs/mmseg/segmentation_pplnn_static-512x1024.py @@ -0,0 +1,5 @@ +_base_ = ['./segmentation_static.py', '../_base_/backends/pplnn.py'] + +onnx_config = dict(input_shape=[1024, 512]) + +backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 512, 1024])) diff --git a/docs/en/codebases/mmseg.md b/docs/en/codebases/mmseg.md index 53a0d312c..30bbba208 100644 --- a/docs/en/codebases/mmseg.md +++ b/docs/en/codebases/mmseg.md @@ -15,7 +15,7 @@ Please refer to [get_started.md](https://github.com/open-mmlab/mmsegmentation/bl | DeepLabV3 | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3) | | DeepLabV3+ | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3plus) | | Fast-SCNN[*](#static_shape) | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastscnn) | -| UNet | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) | +| UNet[*](#static_shape) | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) | ### Reminder diff --git a/docs/en/supported_models.md b/docs/en/supported_models.md index da7368edb..8f13fe93d 100644 --- a/docs/en/supported_models.md +++ b/docs/en/supported_models.md @@ -28,7 +28,7 @@ The table below lists the models that are guaranteed to be exportable to other b | DeepLabV3 | MMSegmentation | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3) | | DeepLabV3+ | MMSegmentation | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3plus) | | Fast-SCNN[*static](#note) | MMSegmentation | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastscnn) | -| UNet | MMSegmentation | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) | +| UNet[*static](#note) | MMSegmentation | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) | | SRCNN | MMEditing | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srcnn) | | ESRGAN | MMEditing | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/esrgan) | | SRGAN | MMEditing | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) | diff --git a/mmdeploy/codebase/mmseg/deploy/__init__.py b/mmdeploy/codebase/mmseg/deploy/__init__.py index 730191c03..2d50427b2 100644 --- a/mmdeploy/codebase/mmseg/deploy/__init__.py +++ b/mmdeploy/codebase/mmseg/deploy/__init__.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from .mmsegmentation import MMSegmentation from .segmentation import Segmentation -from .utils import convert_syncbatchnorm -__all__ = ['convert_syncbatchnorm', 'MMSegmentation', 'Segmentation'] +__all__ = ['MMSegmentation', 'Segmentation'] diff --git a/mmdeploy/codebase/mmseg/deploy/segmentation.py b/mmdeploy/codebase/mmseg/deploy/segmentation.py index a81244abd..62528b77a 100644 --- a/mmdeploy/codebase/mmseg/deploy/segmentation.py +++ b/mmdeploy/codebase/mmseg/deploy/segmentation.py @@ -89,10 +89,9 @@ def init_pytorch_model(self, codebases. """ from mmseg.apis import init_segmentor - from mmdeploy.codebase.mmseg.deploy import convert_syncbatchnorm + from mmcv.cnn.utils import revert_sync_batchnorm model = init_segmentor(self.model_cfg, model_checkpoint, self.device) - model = convert_syncbatchnorm(model) - + model = revert_sync_batchnorm(model) return model.eval() def create_input(self, diff --git a/mmdeploy/codebase/mmseg/deploy/utils.py b/mmdeploy/codebase/mmseg/deploy/utils.py deleted file mode 100644 index 05c7e9fdc..000000000 --- a/mmdeploy/codebase/mmseg/deploy/utils.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch - - -def convert_syncbatchnorm(module: torch.nn.Module): - """Convert sync batch-norm to batch-norm for inference. - - Args: - module (nn.Module): Input PyTorch model. - - Returns: - nn.Module: PyTorch model without sync batch-norm. - """ - module_output = module - if isinstance(module, torch.nn.SyncBatchNorm): - module_output = torch.nn.BatchNorm2d(module.num_features, module.eps, - module.momentum, module.affine, - module.track_running_stats) - if module.affine: - module_output.weight.data = module.weight.data.clone().detach() - module_output.bias.data = module.bias.data.clone().detach() - # keep requires_grad unchanged - module_output.weight.requires_grad = module.weight.requires_grad - module_output.bias.requires_grad = module.bias.requires_grad - module_output.running_mean = module.running_mean - module_output.running_var = module.running_var - module_output.num_batches_tracked = module.num_batches_tracked - for name, child in module.named_children(): - module_output.add_module(name, convert_syncbatchnorm(child)) - del module - return module_output diff --git a/tests/test_codebase/test_mmseg/test_mmseg_utils.py b/tests/test_codebase/test_mmseg/test_mmseg_utils.py deleted file mode 100644 index 11f71814b..000000000 --- a/tests/test_codebase/test_mmseg/test_mmseg_utils.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch -import torch.nn as nn - -from mmdeploy.codebase import import_codebase -from mmdeploy.codebase.mmseg.deploy import convert_syncbatchnorm -from mmdeploy.utils import Codebase - -import_codebase(Codebase.MMSEG) - - -def test_convert_syncbatchnorm(): - - class ExampleModel(nn.Module): - - def __init__(self): - super(ExampleModel, self).__init__() - self.model = nn.Sequential( - nn.Linear(2, 4), nn.SyncBatchNorm(4), nn.Sigmoid(), - nn.Linear(4, 6), nn.SyncBatchNorm(6), nn.Sigmoid()) - - def forward(self, x): - return self.model(x) - - model = ExampleModel() - out_model = convert_syncbatchnorm(model) - assert isinstance(out_model.model[1], - torch.nn.modules.batchnorm.BatchNorm2d) and isinstance( - out_model.model[4], - torch.nn.modules.batchnorm.BatchNorm2d)