From 9d7fbea897eaa4960bcd04982251f9c7d411ea18 Mon Sep 17 00:00:00 2001 From: AllentDan Date: Wed, 17 Aug 2022 16:35:43 +0800 Subject: [PATCH 1/6] add swin for cls --- docs/en/03-benchmark/supported_models.md | 1 + .../mmcls/models/backbones/__init__.py | 5 + .../models/backbones/swin_transformer.py | 169 ++++++++++++++++++ 3 files changed, 175 insertions(+) create mode 100644 mmdeploy/codebase/mmcls/models/backbones/swin_transformer.py diff --git a/docs/en/03-benchmark/supported_models.md b/docs/en/03-benchmark/supported_models.md index 0be5c21923..70d257dee5 100644 --- a/docs/en/03-benchmark/supported_models.md +++ b/docs/en/03-benchmark/supported_models.md @@ -26,6 +26,7 @@ The table below lists the models that are guaranteed to be exportable to other b | ShuffleNetV1 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v1) | | ShuffleNetV2 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v2) | | VisionTransformer | MMClassification | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/vision_transformer) | +| SwinTransformer | MMClassification | Y | Y | Y | N | ? | N | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/swin_transformer) | | FCN | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fcn) | | PSPNet[\*static](#note) | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/pspnet) | | DeepLabV3 | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3) | diff --git a/mmdeploy/codebase/mmcls/models/backbones/__init__.py b/mmdeploy/codebase/mmcls/models/backbones/__init__.py index 52e4af6bfd..ce248b8837 100644 --- a/mmdeploy/codebase/mmcls/models/backbones/__init__.py +++ b/mmdeploy/codebase/mmcls/models/backbones/__init__.py @@ -1,8 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. from .shufflenet_v2 import shufflenetv2_backbone__forward__default +from .swin_transformer import (shift_window_msa__forward__default, + shift_window_msa__get_attn_mask__default, + window_msa__forward__default) from .vision_transformer import visiontransformer__forward__ncnn __all__ = [ 'shufflenetv2_backbone__forward__default', 'visiontransformer__forward__ncnn', + 'shift_window_msa__get_attn_mask__default', + 'shift_window_msa__forward__default', 'window_msa__forward__default' ] diff --git a/mmdeploy/codebase/mmcls/models/backbones/swin_transformer.py b/mmdeploy/codebase/mmcls/models/backbones/swin_transformer.py new file mode 100644 index 0000000000..e68438a6d1 --- /dev/null +++ b/mmdeploy/codebase/mmcls/models/backbones/swin_transformer.py @@ -0,0 +1,169 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.core.rewriters.rewriter_utils import LibVersionChecker + + +@FUNCTION_REWRITER.register_rewriter( + func_name= # noqa: E251 + 'mmcls.models.utils.attention.WindowMSA.forward', + extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0')) +def window_msa__forward__default(ctx, self, x, mask=None): + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + # replace the gather operation with the split + q, k, v = [i.squeeze(0) for i in torch.split(qkv, 1, 0)] + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + + means = torch.mean(attn, self.softmax.dim, keepdim=True)[0] + attn_exp = torch.exp(attn - means) + attn_exp_sum = torch.sum(attn_exp, self.softmax.dim, keepdim=True) + attn = attn_exp / attn_exp_sum + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +@FUNCTION_REWRITER.register_rewriter( + func_name= # noqa: E251 + 'mmcls.models.utils.ShiftWindowMSA.forward', + extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0')) +def shift_window_msa__forward__default(ctx, self, query, hw_shape): + # return ctx.origin_func(self, query, hw_shape) + B, L, C = query.shape + H, W = hw_shape + assert L == H * W, f"The query length {L} doesn't match the input "\ + f'shape ({H}, {W}).' + query = query.view(B, H, W, C) + + window_size = self.window_size + shift_size = self.shift_size + + if min(H, W) == window_size: + # If not pad small feature map, avoid shifting when the window size + # is equal to the size of feature map. It's to align with the + # behavior of the original implementation. + shift_size = shift_size if self.pad_small_map else 0 + elif min(H, W) < window_size: + # In the original implementation, the window size will be shrunk + # to the size of feature map. The behavior is different with + # swin-transformer for downstream tasks. To support dynamic input + # shape, we don't allow this feature. + assert self.pad_small_map, \ + f'The input shape ({H}, {W}) is smaller than the window ' \ + f'size ({window_size}). Please set `pad_small_map=True`, or ' \ + 'decrease the `window_size`.' + + # pad feature maps to multiples of window size + query = query.permute(0, 3, 1, 2).contiguous() + # query = torch.nn.ZeroPad2d([0, self.window_size, 0, self.window_size])( + # query) + query = torch.cat([query, query.new_zeros(B, C, H, window_size)], dim=-1) + query = torch.cat( + [query, query.new_zeros(B, C, window_size, query.shape[-1])], dim=-2) + slice_h = (H + window_size - 1) // window_size * window_size + slice_w = (W + window_size - 1) // window_size * window_size + query = query[:, :, :slice_h, :slice_w] + query = query.permute(0, 2, 3, 1).contiguous() + H_pad, W_pad = query.shape[1], query.shape[2] + + # cyclic shift + if shift_size > 0: + query = torch.roll( + query, shifts=(-shift_size, -shift_size), dims=(1, 2)) + + attn_mask = self.get_attn_mask((H_pad, W_pad), + window_size=window_size, + shift_size=shift_size, + device=query.device) + + # nW*B, window_size, window_size, C + query_windows = self.window_partition(query, window_size) + # nW*B, window_size*window_size, C + query_windows = query_windows.view(-1, window_size**2, C) + + # W-MSA/SW-MSA (nW*B, window_size*window_size, C) + attn_windows = self.w_msa(query_windows, mask=attn_mask) + + # merge windows + attn_windows = attn_windows.view(-1, window_size, window_size, C) + + # B H' W' C + shifted_x = self.window_reverse(attn_windows, H_pad, W_pad, window_size) + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2)) + else: + x = shifted_x + + if H != H_pad or W != W_pad: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + x = self.drop(x) + + return x + + +@FUNCTION_REWRITER.register_rewriter( + func_name= # noqa: E251 + 'mmcls.models.utils.ShiftWindowMSA.get_attn_mask', + extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0')) +def shift_window_msa__get_attn_mask__default(ctx, + self, + hw_shape, + window_size, + shift_size, + device=None): + if shift_size > 0: + # calculate attention mask for SW-MSA + w_mask = torch.cat([ + torch.zeros((hw_shape[1] - window_size), + dtype=torch.int64, + device=device), + torch.full((window_size - shift_size, ), 1, device=device), + torch.full((shift_size, ), 2, device=device) + ]) + h_mask = torch.cat([ + torch.zeros((hw_shape[0] - window_size), + dtype=torch.int64, + device=device), + torch.full((window_size - shift_size, ), 3, device=device), + torch.full((shift_size, ), 6, device=device) + ]) + + img_mask = w_mask.unsqueeze(0) + h_mask.unsqueeze(1) + img_mask = img_mask.unsqueeze(0) + img_mask = img_mask.unsqueeze(-1) + # nW, window_size, window_size, 1 + from mmcls.models.utils import ShiftWindowMSA + mask_windows = ShiftWindowMSA.window_partition(img_mask, window_size) + mask_windows = mask_windows.view(-1, window_size * window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0) + attn_mask = attn_mask.masked_fill(attn_mask == 0, 0.0) + else: + attn_mask = None + return attn_mask From 1fc49fb4c42b6f9c0e7876fed18760b19a62d9bf Mon Sep 17 00:00:00 2001 From: AllentDan Date: Thu, 18 Aug 2022 15:14:23 +0800 Subject: [PATCH 2/6] add ut and doc --- docs/en/03-benchmark/benchmark.md | 21 +++ docs/en/03-benchmark/supported_models.md | 150 +++++++++--------- docs/en/04-supported-codebases/mmcls.md | 19 +-- docs/zh_cn/03-benchmark/supported_models.md | 148 ++++++++--------- docs/zh_cn/04-supported-codebases/mmcls.md | 19 +-- .../test_mmcls/test_mmcls_models.py | 81 ++++++++++ 6 files changed, 273 insertions(+), 165 deletions(-) diff --git a/docs/en/03-benchmark/benchmark.md b/docs/en/03-benchmark/benchmark.md index 79024effc4..f76f3b2c16 100644 --- a/docs/en/03-benchmark/benchmark.md +++ b/docs/en/03-benchmark/benchmark.md @@ -580,6 +580,27 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ - - + + Swin Transformer + top-1 + 81.18 + 81.18 + 81.18 + 81.18 + 81.18 + - + - + + + top-5 + 95.61 + 95.61 + 95.61 + 95.61 + 95.61 + - + - + diff --git a/docs/en/03-benchmark/supported_models.md b/docs/en/03-benchmark/supported_models.md index 70d257dee5..16ce0ee01e 100644 --- a/docs/en/03-benchmark/supported_models.md +++ b/docs/en/03-benchmark/supported_models.md @@ -2,80 +2,80 @@ The table below lists the models that are guaranteed to be exportable to other backends. -| Model | Codebase | TorchScript | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO | Model config | -| :-------------------------- | :--------------- | :---------: | :---------: | :------: | :--: | :---: | :------: | :---------------------------------------------------------------------------------------------: | -| RetinaNet | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet) | -| Faster R-CNN | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) | -| YOLOv3 | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolo) | -| YOLOX | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox) | -| FCOS | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fcos) | -| FSAF | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fsaf) | -| Mask R-CNN | MMDetection | Y | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn) | -| SSD[\*](#note) | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd) | -| FoveaBox | MMDetection | Y | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/foveabox) | -| ATSS | MMDetection | N | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/atss) | -| GFL | MMDetection | N | Y | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) | -| Cascade R-CNN | MMDetection | N | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) | -| Cascade Mask R-CNN | MMDetection | N | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) | -| Swin Transformer[\*](#note) | MMDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/swin) | -| VFNet | MMDetection | N | N | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/vfnet) | -| ResNet | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet) | -| ResNeXt | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext) | -| SE-ResNet | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet) | -| MobileNetV2 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2) | -| ShuffleNetV1 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v1) | -| ShuffleNetV2 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v2) | -| VisionTransformer | MMClassification | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/vision_transformer) | -| SwinTransformer | MMClassification | Y | Y | Y | N | ? | N | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/swin_transformer) | -| FCN | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fcn) | -| PSPNet[\*static](#note) | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/pspnet) | -| DeepLabV3 | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3) | -| DeepLabV3+ | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3plus) | -| Fast-SCNN[\*static](#note) | MMSegmentation | Y | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastscnn) | -| UNet | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) | -| ANN[\*](#note) | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ann) | -| APCNet | MMSegmentation | ? | Y | Y | Y | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/apcnet) | -| BiSeNetV1 | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv1) | -| BiSeNetV2 | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv2) | -| CGNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/cgnet) | -| DMNet | MMSegmentation | ? | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/dmnet) | -| DNLNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/dnlnet) | -| EMANet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/emanet) | -| EncNet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/encnet) | -| ERFNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/erfnet) | -| FastFCN | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastfcn) | -| GCNet | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/gcnet) | -| ICNet[\*](#note) | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/icnet) | -| ISANet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/isanet) | -| NonLocal Net | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/nonlocal_net) | -| OCRNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ocrnet) | -| PointRend | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/point_rend) | -| Semantic FPN | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/sem_fpn) | -| STDC | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/stdc) | -| UPerNet[\*](#note) | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/upernet) | -| DANet | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/danet) | -| Segmenter | MMSegmentation | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/segmenter) | -| SRCNN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srcnn) | -| ESRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/esrgan) | -| SRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) | -| SRResNet | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) | -| Real-ESRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/real_esrgan) | -| EDSR | MMEditing | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/edsr) | -| RDN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/rdn) | -| DBNet | MMOCR | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textdet/dbnet) | -| PANet | MMOCR | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textdet/panet) | -| DBNet | MMOCR | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textdet/psenet) | -| CRNN | MMOCR | Y | Y | Y | Y | Y | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/crnn) | -| SAR | MMOCR | N | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/sar) | -| SATRN | MMOCR | Y | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/satrn) | -| HRNet | MMPose | N | Y | Y | Y | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#hrnet-cvpr-2019) | -| MSPN | MMPose | N | Y | Y | Y | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#mspn-arxiv-2019) | -| LiteHRNet | MMPose | N | Y | Y | N | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#litehrnet-cvpr-2021) | -| PointPillars | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) | -| CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) | -| RotatedRetinaNet | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) | -| Oriented RCNN | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) | -| Gliding Vertex | RotatedDetection | N | N | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/gliding_vertex/README.md) | +| Model | Codebase | TorchScript | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO | Model config | +| :------------------------- | :--------------- | :---------: | :---------: | :------: | :--: | :---: | :------: | :---------------------------------------------------------------------------------------------: | +| RetinaNet | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet) | +| Faster R-CNN | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) | +| YOLOv3 | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolo) | +| YOLOX | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox) | +| FCOS | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fcos) | +| FSAF | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fsaf) | +| Mask R-CNN | MMDetection | Y | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn) | +| SSD[\*](#note) | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd) | +| FoveaBox | MMDetection | Y | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/foveabox) | +| ATSS | MMDetection | N | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/atss) | +| GFL | MMDetection | N | Y | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) | +| Cascade R-CNN | MMDetection | N | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) | +| Cascade Mask R-CNN | MMDetection | N | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) | +| Swin Transformer | MMDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/swin) | +| VFNet | MMDetection | N | N | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/vfnet) | +| ResNet | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet) | +| ResNeXt | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext) | +| SE-ResNet | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet) | +| MobileNetV2 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2) | +| ShuffleNetV1 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v1) | +| ShuffleNetV2 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v2) | +| VisionTransformer | MMClassification | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/vision_transformer) | +| SwinTransformer | MMClassification | Y | Y | Y | N | ? | N | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/swin_transformer) | +| FCN | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fcn) | +| PSPNet[\*static](#note) | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/pspnet) | +| DeepLabV3 | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3) | +| DeepLabV3+ | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3plus) | +| Fast-SCNN[\*static](#note) | MMSegmentation | Y | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastscnn) | +| UNet | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) | +| ANN[\*](#note) | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ann) | +| APCNet | MMSegmentation | ? | Y | Y | Y | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/apcnet) | +| BiSeNetV1 | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv1) | +| BiSeNetV2 | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv2) | +| CGNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/cgnet) | +| DMNet | MMSegmentation | ? | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/dmnet) | +| DNLNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/dnlnet) | +| EMANet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/emanet) | +| EncNet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/encnet) | +| ERFNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/erfnet) | +| FastFCN | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastfcn) | +| GCNet | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/gcnet) | +| ICNet[\*](#note) | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/icnet) | +| ISANet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/isanet) | +| NonLocal Net | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/nonlocal_net) | +| OCRNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ocrnet) | +| PointRend | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/point_rend) | +| Semantic FPN | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/sem_fpn) | +| STDC | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/stdc) | +| UPerNet[\*](#note) | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/upernet) | +| DANet | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/danet) | +| Segmenter | MMSegmentation | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/segmenter) | +| SRCNN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srcnn) | +| ESRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/esrgan) | +| SRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) | +| SRResNet | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) | +| Real-ESRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/real_esrgan) | +| EDSR | MMEditing | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/edsr) | +| RDN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/rdn) | +| DBNet | MMOCR | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textdet/dbnet) | +| PANet | MMOCR | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textdet/panet) | +| PSENet | MMOCR | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textdet/psenet) | +| CRNN | MMOCR | Y | Y | Y | Y | Y | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/crnn) | +| SAR[\*](#note) | MMOCR | N | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/sar) | +| SATRN | MMOCR | Y | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/satrn) | +| HRNet | MMPose | N | Y | Y | Y | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#hrnet-cvpr-2019) | +| MSPN | MMPose | N | Y | Y | Y | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#mspn-arxiv-2019) | +| LiteHRNet | MMPose | N | Y | Y | N | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#litehrnet-cvpr-2021) | +| PointPillars | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) | +| CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) | +| RotatedRetinaNet | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) | +| Oriented RCNN | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) | +| Gliding Vertex | RotatedDetection | N | N | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/gliding_vertex/README.md) | ### Note @@ -83,4 +83,4 @@ The table below lists the models that are guaranteed to be exportable to other b - static: This model only support static export. Please use `static` deploy config, just like $MMDEPLOY_DIR/configs/mmseg/segmentation_tensorrt_static-1024x2048.py. - SSD: When you convert SSD model, you need to use min shape deploy config just like 300x300-512x512 rather than 320x320-1344x1344, for example $MMDEPLOY_DIR/configs/mmdet/detection/detection_tensorrt_dynamic-300x300-512x512.py. - YOLOX: YOLOX with ncnn only supports static shape. -- Swin Transformer: For TensorRT, only version 8.4+ is supported. +- SAR: Chinese text recognition model is not supported as the protobuf size of ONNX is limited. diff --git a/docs/en/04-supported-codebases/mmcls.md b/docs/en/04-supported-codebases/mmcls.md index 582a2cae25..d192ecc320 100644 --- a/docs/en/04-supported-codebases/mmcls.md +++ b/docs/en/04-supported-codebases/mmcls.md @@ -8,12 +8,13 @@ Please refer to [install.md](https://github.com/open-mmlab/mmclassification/blob ## List of MMClassification models supported by MMDeploy -| Model | ONNX Runtime | TensorRT | ncnn | PPLNN | OpenVINO | Model config | -| :---------------- | :----------: | :------: | :--: | :---: | :------: | :---------------------------------------------------------------------------------------------: | -| ResNet | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet) | -| ResNeXt | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext) | -| SE-ResNet | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet) | -| MobileNetV2 | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2) | -| ShuffleNetV1 | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v1) | -| ShuffleNetV2 | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v2) | -| VisionTransformer | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/vision_transformer) | +| Model | TorchScript | ONNX Runtime | TensorRT | ncnn | PPLNN | OpenVINO | Model config | +| :---------------- | :---------: | :----------: | :------: | :--: | :---: | :------: | :---------------------------------------------------------------------------------------------: | +| ResNet | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet) | +| ResNeXt | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext) | +| SE-ResNet | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet) | +| MobileNetV2 | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2) | +| ShuffleNetV1 | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v1) | +| ShuffleNetV2 | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v2) | +| VisionTransformer | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/vision_transformer) | +| SwinTransformer | Y | Y | Y | N | ? | N | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/swin_transformer) | diff --git a/docs/zh_cn/03-benchmark/supported_models.md b/docs/zh_cn/03-benchmark/supported_models.md index 256f96f1a2..3afa8a1a28 100644 --- a/docs/zh_cn/03-benchmark/supported_models.md +++ b/docs/zh_cn/03-benchmark/supported_models.md @@ -2,77 +2,81 @@ 自测完成的 model-backend 组合: -| Model | Codebase | TorchScript | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO | Model config | -| :-------------------------- | :--------------- | :---------: | :---------: | :------: | :--: | :---: | :------: | :---------------------------------------------------------------------------------------------: | -| RetinaNet | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet) | -| Faster R-CNN | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) | -| YOLOv3 | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolo) | -| YOLOX | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox) | -| FCOS | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fcos) | -| FSAF | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fsaf) | -| Mask R-CNN | MMDetection | Y | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn) | -| SSD[\*](#note) | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd) | -| FoveaBox | MMDetection | Y | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/foveabox) | -| ATSS | MMDetection | N | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/atss) | -| GFL | MMDetection | N | Y | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) | -| Cascade R-CNN | MMDetection | N | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) | -| Cascade Mask R-CNN | MMDetection | N | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) | -| Swin Transformer[\*](#note) | MMDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/swin) | -| VFNet | MMDetection | N | N | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/vfnet) | -| RepPoints | MMDetection | N | N | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/reppoints) | -| ResNet | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet) | -| ResNeXt | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext) | -| SE-ResNet | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet) | -| MobileNetV2 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2) | -| ShuffleNetV1 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v1) | -| ShuffleNetV2 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v2) | -| VisionTransformer | MMClassification | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/vision_transformer) | -| FCN | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fcn) | -| PSPNet[\*static](#note) | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/pspnet) | -| DeepLabV3 | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3) | -| DeepLabV3+ | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3plus) | -| Fast-SCNN[\*static](#note) | MMSegmentation | Y | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastscnn) | -| UNet | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) | -| ANN[\*](#note) | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ann) | -| APCNet | MMSegmentation | ? | Y | Y | Y | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/apcnet) | -| BiSeNetV1 | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv1) | -| BiSeNetV2 | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv2) | -| CGNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/cgnet) | -| DMNet | MMSegmentation | ? | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/dmnet) | -| DNLNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/dnlnet) | -| EMANet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/emanet) | -| EncNet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/encnet) | -| ERFNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/erfnet) | -| FastFCN | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastfcn) | -| GCNet | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/gcnet) | -| ICNet[\*](#note) | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/icnet) | -| ISANet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/isanet) | -| NonLocal Net | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/nonlocal_net) | -| OCRNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ocrnet) | -| PointRend | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/point_rend) | -| Semantic FPN | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/sem_fpn) | -| STDC | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/stdc) | -| UPerNet[\*](#note) | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/upernet) | -| DANet | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/danet) | -| Segmenter | MMSegmentation | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/segmenter) | -| SRCNN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srcnn) | -| ESRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/esrgan) | -| SRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) | -| SRResNet | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) | -| Real-ESRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/real_esrgan) | -| EDSR | MMEditing | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/edsr) | -| RDN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/rdn) | -| DBNet | MMOCR | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textdet/dbnet) | -| CRNN | MMOCR | Y | Y | Y | Y | Y | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/crnn) | -| SAR | MMOCR | N | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/sar) | -| HRNet | MMPose | N | Y | Y | Y | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#hrnet-cvpr-2019) | -| MSPN | MMPose | N | Y | Y | Y | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#mspn-arxiv-2019) | -| LiteHRNet | MMPose | N | Y | Y | N | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#litehrnet-cvpr-2021) | -| PointPillars | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) | -| CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) | -| RotatedRetinaNet | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) | -| Oriented RCNN | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) | -| Gliding Vertex | RotatedDetection | N | N | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/gliding_vertex/README.md) | +| Model | Codebase | TorchScript | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO | Model config | +| :------------------------- | :--------------- | :---------: | :---------: | :------: | :--: | :---: | :------: | :---------------------------------------------------------------------------------------------: | +| RetinaNet | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet) | +| Faster R-CNN | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) | +| YOLOv3 | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolo) | +| YOLOX | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox) | +| FCOS | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fcos) | +| FSAF | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fsaf) | +| Mask R-CNN | MMDetection | Y | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn) | +| SSD[\*](#note) | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd) | +| FoveaBox | MMDetection | Y | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/foveabox) | +| ATSS | MMDetection | N | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/atss) | +| GFL | MMDetection | N | Y | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) | +| Cascade R-CNN | MMDetection | N | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) | +| Cascade Mask R-CNN | MMDetection | N | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) | +| Swin Transformer | MMDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/swin) | +| VFNet | MMDetection | N | N | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/vfnet) | +| RepPoints | MMDetection | N | N | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/reppoints) | +| ResNet | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet) | +| ResNeXt | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext) | +| SE-ResNet | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet) | +| MobileNetV2 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2) | +| ShuffleNetV1 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v1) | +| ShuffleNetV2 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v2) | +| VisionTransformer | MMClassification | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/vision_transformer) | +| SwinTransformer | MMClassification | Y | Y | Y | N | ? | N | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/swin_transformer) | +| FCN | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fcn) | +| PSPNet[\*static](#note) | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/pspnet) | +| DeepLabV3 | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3) | +| DeepLabV3+ | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3plus) | +| Fast-SCNN[\*static](#note) | MMSegmentation | Y | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastscnn) | +| UNet | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) | +| ANN[\*](#note) | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ann) | +| APCNet | MMSegmentation | ? | Y | Y | Y | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/apcnet) | +| BiSeNetV1 | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv1) | +| BiSeNetV2 | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv2) | +| CGNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/cgnet) | +| DMNet | MMSegmentation | ? | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/dmnet) | +| DNLNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/dnlnet) | +| EMANet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/emanet) | +| EncNet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/encnet) | +| ERFNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/erfnet) | +| FastFCN | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastfcn) | +| GCNet | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/gcnet) | +| ICNet[\*](#note) | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/icnet) | +| ISANet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/isanet) | +| NonLocal Net | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/nonlocal_net) | +| OCRNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ocrnet) | +| PointRend | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/point_rend) | +| Semantic FPN | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/sem_fpn) | +| STDC | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/stdc) | +| UPerNet[\*](#note) | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/upernet) | +| DANet | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/danet) | +| Segmenter | MMSegmentation | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/segmenter) | +| SRCNN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srcnn) | +| ESRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/esrgan) | +| SRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) | +| SRResNet | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) | +| Real-ESRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/real_esrgan) | +| EDSR | MMEditing | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/edsr) | +| RDN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/rdn) | +| DBNet | MMOCR | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textdet/dbnet) | +| PANet | MMOCR | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textdet/panet) | +| PSENet | MMOCR | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textdet/psenet) | +| CRNN | MMOCR | Y | Y | Y | Y | Y | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/crnn) | +| SAR[\*](#note) | MMOCR | N | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/sar) | +| SATRN | MMOCR | Y | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/satrn) | +| HRNet | MMPose | N | Y | Y | Y | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#hrnet-cvpr-2019) | +| MSPN | MMPose | N | Y | Y | Y | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#mspn-arxiv-2019) | +| LiteHRNet | MMPose | N | Y | Y | N | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#litehrnet-cvpr-2021) | +| PointPillars | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) | +| CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) | +| RotatedRetinaNet | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) | +| Oriented RCNN | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) | +| Gliding Vertex | RotatedDetection | N | N | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/gliding_vertex/README.md) | ## Note @@ -80,4 +84,4 @@ - static: This model only support static export. Please use `static` deploy config, just like $MMDEPLOY_DIR/configs/mmseg/segmentation_tensorrt_static-1024x2048.py. - SSD: When you convert SSD model, you need to use min shape deploy config just like 300x300-512x512 rather than 320x320-1344x1344, for example $MMDEPLOY_DIR/configs/mmdet/detection/detection_tensorrt_dynamic-300x300-512x512.py. - YOLOX: YOLOX with ncnn only supports static shape. -- Swin Transformer: For TensorRT, only version 8.4+ is supported. +- SAR: Chinese text recognition model is not supported as the protobuf size of ONNX is limited. diff --git a/docs/zh_cn/04-supported-codebases/mmcls.md b/docs/zh_cn/04-supported-codebases/mmcls.md index 1bfa37118d..3e7860ad2b 100644 --- a/docs/zh_cn/04-supported-codebases/mmcls.md +++ b/docs/zh_cn/04-supported-codebases/mmcls.md @@ -8,12 +8,13 @@ ## 支持列表 -| Model | ONNX Runtime | TensorRT | ncnn | PPLNN | OpenVINO | Model config | -| :---------------- | :----------: | :------: | :--: | :---: | :------: | :---------------------------------------------------------------------------------------------: | -| ResNet | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet) | -| ResNeXt | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext) | -| SE-ResNet | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet) | -| MobileNetV2 | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2) | -| ShuffleNetV1 | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v1) | -| ShuffleNetV2 | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v2) | -| VisionTransformer | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/vision_transformer) | +| Model | TorchScript | ONNX Runtime | TensorRT | ncnn | PPLNN | OpenVINO | Model config | +| :---------------- | :---------: | :----------: | :------: | :--: | :---: | :------: | :---------------------------------------------------------------------------------------------: | +| ResNet | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet) | +| ResNeXt | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext) | +| SE-ResNet | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet) | +| MobileNetV2 | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2) | +| ShuffleNetV1 | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v1) | +| ShuffleNetV2 | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v2) | +| VisionTransformer | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/vision_transformer) | +| SwinTransformer | Y | Y | Y | N | ? | N | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/swin_transformer) | diff --git a/tests/test_codebase/test_mmcls/test_mmcls_models.py b/tests/test_codebase/test_mmcls/test_mmcls_models.py index 4637e62649..1670600ff1 100644 --- a/tests/test_codebase/test_mmcls/test_mmcls_models.py +++ b/tests/test_codebase/test_mmcls/test_mmcls_models.py @@ -248,3 +248,84 @@ def test_gap__forward(backend_type: Backend, inputs: list): rewrite_output = rewrite_output.cpu().numpy() assert np.allclose( model_output, rewrite_output, rtol=1e-03, atol=1e-05) + + +@pytest.mark.skipif( + reason='Only support GPU test', condition=not torch.cuda.is_available()) +@pytest.mark.parametrize('backend_type', [(Backend.TENSORRT)]) +def test_windows_msa_cls(backend_type: Backend): + check_backend(backend_type) + from mmcls.models.utils.attention import WindowMSA + model = WindowMSA(96, (7, 7), 3) + model.cuda().eval() + output_names = ['output'] + + deploy_cfg = mmcv.Config( + dict( + backend_config=dict( + type=backend_type.value, + common_config=dict(fp16_mode=True, max_workspace_size=1 << 20), + model_inputs=[ + dict( + input_shapes=dict( + x=dict( + min_shape=[12, 49, 96], + opt_shape=[12, 49, 96], + max_shape=[12, 49, 96]), + mask=dict( + min_shape=[12, 49, 49], + opt_shape=[12, 49, 49], + max_shape=[12, 49, 49]))) + ]), + onnx_config=dict( + input_shape=None, + input_names=['x', 'mask'], + output_names=output_names))) + + x = torch.randn([12, 49, 96]).cuda() + mask = torch.randn([12, 49, 49]).cuda() + wrapped_model = WrapModel(model, 'forward') + rewrite_inputs = {'x': x, 'mask': mask} + _ = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg) + + +@pytest.mark.skipif( + reason='Only support GPU test', condition=not torch.cuda.is_available()) +@pytest.mark.parametrize('backend_type', [(Backend.TENSORRT)]) +def test_shift_windows_msa_cls(backend_type: Backend): + check_backend(backend_type) + from mmcls.models.utils import ShiftWindowMSA + model = ShiftWindowMSA(96, 3, 7) + model.cuda().eval() + output_names = ['output'] + + deploy_cfg = mmcv.Config( + dict( + backend_config=dict( + type=backend_type.value, + model_inputs=[ + dict( + input_shapes=dict( + query=dict( + min_shape=[1, 60800, 96], + opt_shape=[1, 60800, 96], + max_shape=[1, 60800, 96]))) + ]), + onnx_config=dict( + input_shape=None, + input_names=['query'], + output_names=output_names))) + + query = torch.randn([1, 60800, 96]).cuda() + hw_shape = (torch.tensor(200), torch.tensor(304)) + + wrapped_model = WrapModel(model, 'forward') + rewrite_inputs = {'query': query, 'hw_shape': hw_shape} + _ = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg, + run_with_backend=False) From 200954c02305c977a6fe9649e97ed35effab64eb Mon Sep 17 00:00:00 2001 From: AllentDan Date: Fri, 19 Aug 2022 11:04:00 +0800 Subject: [PATCH 3/6] reduce trt batch size --- .../classification_tensorrt-fp16_dynamic-224x224-224x224.py | 2 +- .../classification_tensorrt-int8_dynamic-224x224-224x224.py | 2 +- .../mmcls/classification_tensorrt_dynamic-224x224-224x224.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/configs/mmcls/classification_tensorrt-fp16_dynamic-224x224-224x224.py b/configs/mmcls/classification_tensorrt-fp16_dynamic-224x224-224x224.py index d71ceef990..3bf484bb3d 100644 --- a/configs/mmcls/classification_tensorrt-fp16_dynamic-224x224-224x224.py +++ b/configs/mmcls/classification_tensorrt-fp16_dynamic-224x224-224x224.py @@ -9,5 +9,5 @@ input=dict( min_shape=[1, 3, 224, 224], opt_shape=[4, 3, 224, 224], - max_shape=[64, 3, 224, 224]))) + max_shape=[32, 3, 224, 224]))) ]) diff --git a/configs/mmcls/classification_tensorrt-int8_dynamic-224x224-224x224.py b/configs/mmcls/classification_tensorrt-int8_dynamic-224x224-224x224.py index a0a5871b5d..e9e43d365b 100644 --- a/configs/mmcls/classification_tensorrt-int8_dynamic-224x224-224x224.py +++ b/configs/mmcls/classification_tensorrt-int8_dynamic-224x224-224x224.py @@ -9,5 +9,5 @@ input=dict( min_shape=[1, 3, 224, 224], opt_shape=[4, 3, 224, 224], - max_shape=[64, 3, 224, 224]))) + max_shape=[32, 3, 224, 224]))) ]) diff --git a/configs/mmcls/classification_tensorrt_dynamic-224x224-224x224.py b/configs/mmcls/classification_tensorrt_dynamic-224x224-224x224.py index 1e091713ad..66a5c16ea4 100644 --- a/configs/mmcls/classification_tensorrt_dynamic-224x224-224x224.py +++ b/configs/mmcls/classification_tensorrt_dynamic-224x224-224x224.py @@ -9,5 +9,5 @@ input=dict( min_shape=[1, 3, 224, 224], opt_shape=[4, 3, 224, 224], - max_shape=[64, 3, 224, 224]))) + max_shape=[32, 3, 224, 224]))) ]) From e03ff8e0526dbf6c9301fd871315b12ec28b7ecd Mon Sep 17 00:00:00 2001 From: AllentDan Date: Fri, 19 Aug 2022 11:47:03 +0800 Subject: [PATCH 4/6] add regression test --- tests/regression/mmcls.yml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/regression/mmcls.yml b/tests/regression/mmcls.yml index 78943b63d5..a6f817f484 100644 --- a/tests/regression/mmcls.yml +++ b/tests/regression/mmcls.yml @@ -213,3 +213,12 @@ models: - *pipeline_ort_dynamic_fp32 - *pipeline_trt_static_fp16_384x384 - *pipeline_ncnn_static_fp32 + + - name: SwinTransformer + metafile: configs/swin_transformer/metafile.yml + model_configs: + - configs/swin_transformer/swin-tiny_16xb64_in1k.py + pipelines: + - *pipeline_ts_fp32 + - *pipeline_ort_dynamic_fp32 + - *pipeline_trt_static_fp16 From bb4a6d34da40a42c00ce44fdf073a358ad7fbb31 Mon Sep 17 00:00:00 2001 From: AllentDan Date: Fri, 19 Aug 2022 17:48:55 +0800 Subject: [PATCH 5/6] resolve comments --- ...n_tensorrt-fp16_dynamic-224x224-224x224.py | 2 +- ...n_tensorrt-int8_dynamic-224x224-224x224.py | 2 +- ...cation_tensorrt_dynamic-224x224-224x224.py | 2 +- .../mmcls/models/backbones/__init__.py | 7 +- .../models/backbones/swin_transformer.py | 169 ----------------- .../codebase/mmcls/models/utils/__init__.py | 11 +- .../codebase/mmcls/models/utils/attention.py | 171 +++++++++++++++++- 7 files changed, 183 insertions(+), 181 deletions(-) delete mode 100644 mmdeploy/codebase/mmcls/models/backbones/swin_transformer.py diff --git a/configs/mmcls/classification_tensorrt-fp16_dynamic-224x224-224x224.py b/configs/mmcls/classification_tensorrt-fp16_dynamic-224x224-224x224.py index 3bf484bb3d..72f5764123 100644 --- a/configs/mmcls/classification_tensorrt-fp16_dynamic-224x224-224x224.py +++ b/configs/mmcls/classification_tensorrt-fp16_dynamic-224x224-224x224.py @@ -9,5 +9,5 @@ input=dict( min_shape=[1, 3, 224, 224], opt_shape=[4, 3, 224, 224], - max_shape=[32, 3, 224, 224]))) + max_shape=[8, 3, 224, 224]))) ]) diff --git a/configs/mmcls/classification_tensorrt-int8_dynamic-224x224-224x224.py b/configs/mmcls/classification_tensorrt-int8_dynamic-224x224-224x224.py index e9e43d365b..30b4d71dd1 100644 --- a/configs/mmcls/classification_tensorrt-int8_dynamic-224x224-224x224.py +++ b/configs/mmcls/classification_tensorrt-int8_dynamic-224x224-224x224.py @@ -9,5 +9,5 @@ input=dict( min_shape=[1, 3, 224, 224], opt_shape=[4, 3, 224, 224], - max_shape=[32, 3, 224, 224]))) + max_shape=[8, 3, 224, 224]))) ]) diff --git a/configs/mmcls/classification_tensorrt_dynamic-224x224-224x224.py b/configs/mmcls/classification_tensorrt_dynamic-224x224-224x224.py index 66a5c16ea4..d77c853ce6 100644 --- a/configs/mmcls/classification_tensorrt_dynamic-224x224-224x224.py +++ b/configs/mmcls/classification_tensorrt_dynamic-224x224-224x224.py @@ -9,5 +9,5 @@ input=dict( min_shape=[1, 3, 224, 224], opt_shape=[4, 3, 224, 224], - max_shape=[32, 3, 224, 224]))) + max_shape=[8, 3, 224, 224]))) ]) diff --git a/mmdeploy/codebase/mmcls/models/backbones/__init__.py b/mmdeploy/codebase/mmcls/models/backbones/__init__.py index ce248b8837..fd9d7d3d41 100644 --- a/mmdeploy/codebase/mmcls/models/backbones/__init__.py +++ b/mmdeploy/codebase/mmcls/models/backbones/__init__.py @@ -1,13 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from .shufflenet_v2 import shufflenetv2_backbone__forward__default -from .swin_transformer import (shift_window_msa__forward__default, - shift_window_msa__get_attn_mask__default, - window_msa__forward__default) from .vision_transformer import visiontransformer__forward__ncnn __all__ = [ 'shufflenetv2_backbone__forward__default', - 'visiontransformer__forward__ncnn', - 'shift_window_msa__get_attn_mask__default', - 'shift_window_msa__forward__default', 'window_msa__forward__default' + 'visiontransformer__forward__ncnn' ] diff --git a/mmdeploy/codebase/mmcls/models/backbones/swin_transformer.py b/mmdeploy/codebase/mmcls/models/backbones/swin_transformer.py deleted file mode 100644 index e68438a6d1..0000000000 --- a/mmdeploy/codebase/mmcls/models/backbones/swin_transformer.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch - -from mmdeploy.core import FUNCTION_REWRITER -from mmdeploy.core.rewriters.rewriter_utils import LibVersionChecker - - -@FUNCTION_REWRITER.register_rewriter( - func_name= # noqa: E251 - 'mmcls.models.utils.attention.WindowMSA.forward', - extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0')) -def window_msa__forward__default(ctx, self, x, mask=None): - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, - C // self.num_heads).permute(2, 0, 3, 1, 4) - # replace the gather operation with the split - q, k, v = [i.squeeze(0) for i in torch.split(qkv, 1, 0)] - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[ - self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], - self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute( - 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, - N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - - means = torch.mean(attn, self.softmax.dim, keepdim=True)[0] - attn_exp = torch.exp(attn - means) - attn_exp_sum = torch.sum(attn_exp, self.softmax.dim, keepdim=True) - attn = attn_exp / attn_exp_sum - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -@FUNCTION_REWRITER.register_rewriter( - func_name= # noqa: E251 - 'mmcls.models.utils.ShiftWindowMSA.forward', - extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0')) -def shift_window_msa__forward__default(ctx, self, query, hw_shape): - # return ctx.origin_func(self, query, hw_shape) - B, L, C = query.shape - H, W = hw_shape - assert L == H * W, f"The query length {L} doesn't match the input "\ - f'shape ({H}, {W}).' - query = query.view(B, H, W, C) - - window_size = self.window_size - shift_size = self.shift_size - - if min(H, W) == window_size: - # If not pad small feature map, avoid shifting when the window size - # is equal to the size of feature map. It's to align with the - # behavior of the original implementation. - shift_size = shift_size if self.pad_small_map else 0 - elif min(H, W) < window_size: - # In the original implementation, the window size will be shrunk - # to the size of feature map. The behavior is different with - # swin-transformer for downstream tasks. To support dynamic input - # shape, we don't allow this feature. - assert self.pad_small_map, \ - f'The input shape ({H}, {W}) is smaller than the window ' \ - f'size ({window_size}). Please set `pad_small_map=True`, or ' \ - 'decrease the `window_size`.' - - # pad feature maps to multiples of window size - query = query.permute(0, 3, 1, 2).contiguous() - # query = torch.nn.ZeroPad2d([0, self.window_size, 0, self.window_size])( - # query) - query = torch.cat([query, query.new_zeros(B, C, H, window_size)], dim=-1) - query = torch.cat( - [query, query.new_zeros(B, C, window_size, query.shape[-1])], dim=-2) - slice_h = (H + window_size - 1) // window_size * window_size - slice_w = (W + window_size - 1) // window_size * window_size - query = query[:, :, :slice_h, :slice_w] - query = query.permute(0, 2, 3, 1).contiguous() - H_pad, W_pad = query.shape[1], query.shape[2] - - # cyclic shift - if shift_size > 0: - query = torch.roll( - query, shifts=(-shift_size, -shift_size), dims=(1, 2)) - - attn_mask = self.get_attn_mask((H_pad, W_pad), - window_size=window_size, - shift_size=shift_size, - device=query.device) - - # nW*B, window_size, window_size, C - query_windows = self.window_partition(query, window_size) - # nW*B, window_size*window_size, C - query_windows = query_windows.view(-1, window_size**2, C) - - # W-MSA/SW-MSA (nW*B, window_size*window_size, C) - attn_windows = self.w_msa(query_windows, mask=attn_mask) - - # merge windows - attn_windows = attn_windows.view(-1, window_size, window_size, C) - - # B H' W' C - shifted_x = self.window_reverse(attn_windows, H_pad, W_pad, window_size) - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2)) - else: - x = shifted_x - - if H != H_pad or W != W_pad: - x = x[:, :H, :W, :].contiguous() - - x = x.view(B, H * W, C) - - x = self.drop(x) - - return x - - -@FUNCTION_REWRITER.register_rewriter( - func_name= # noqa: E251 - 'mmcls.models.utils.ShiftWindowMSA.get_attn_mask', - extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0')) -def shift_window_msa__get_attn_mask__default(ctx, - self, - hw_shape, - window_size, - shift_size, - device=None): - if shift_size > 0: - # calculate attention mask for SW-MSA - w_mask = torch.cat([ - torch.zeros((hw_shape[1] - window_size), - dtype=torch.int64, - device=device), - torch.full((window_size - shift_size, ), 1, device=device), - torch.full((shift_size, ), 2, device=device) - ]) - h_mask = torch.cat([ - torch.zeros((hw_shape[0] - window_size), - dtype=torch.int64, - device=device), - torch.full((window_size - shift_size, ), 3, device=device), - torch.full((shift_size, ), 6, device=device) - ]) - - img_mask = w_mask.unsqueeze(0) + h_mask.unsqueeze(1) - img_mask = img_mask.unsqueeze(0) - img_mask = img_mask.unsqueeze(-1) - # nW, window_size, window_size, 1 - from mmcls.models.utils import ShiftWindowMSA - mask_windows = ShiftWindowMSA.window_partition(img_mask, window_size) - mask_windows = mask_windows.view(-1, window_size * window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0) - attn_mask = attn_mask.masked_fill(attn_mask == 0, 0.0) - else: - attn_mask = None - return attn_mask diff --git a/mmdeploy/codebase/mmcls/models/utils/__init__.py b/mmdeploy/codebase/mmcls/models/utils/__init__.py index a3b76e8d72..2de6b6f67f 100644 --- a/mmdeploy/codebase/mmcls/models/utils/__init__.py +++ b/mmdeploy/codebase/mmcls/models/utils/__init__.py @@ -1,4 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .attention import multiheadattention__forward__ncnn +from .attention import (multiheadattention__forward__ncnn, + shift_window_msa__forward__default, + shift_window_msa__get_attn_mask__default, + window_msa__forward__default) -__all__ = ['multiheadattention__forward__ncnn'] +__all__ = [ + 'multiheadattention__forward__ncnn', + 'shift_window_msa__get_attn_mask__default', + 'shift_window_msa__forward__default', 'window_msa__forward__default' +] diff --git a/mmdeploy/codebase/mmcls/models/utils/attention.py b/mmdeploy/codebase/mmcls/models/utils/attention.py index 2088af1f56..f5d65f6a10 100644 --- a/mmdeploy/codebase/mmcls/models/utils/attention.py +++ b/mmdeploy/codebase/mmcls/models/utils/attention.py @@ -1,8 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. +import torch + from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.core.rewriters.rewriter_utils import LibVersionChecker from mmdeploy.mmcv.cnn import MultiHeadAttentionop -from mmdeploy.utils import Backend +from mmdeploy.utils import Backend, get_dynamic_axes @FUNCTION_REWRITER.register_rewriter( @@ -44,3 +47,169 @@ def multiheadattention__forward__ncnn(ctx, self, qkv_input): v_bias, o_weight, o_bias, self.embed_dims, self.num_heads) return out + + +@FUNCTION_REWRITER.register_rewriter( + func_name= # noqa: E251 + 'mmcls.models.utils.attention.WindowMSA.forward', + extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0')) +def window_msa__forward__default(ctx, self, x, mask=None): + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + # replace the gather operation with the split + q, k, v = [i.squeeze(0) for i in torch.split(qkv, 1, 0)] + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + + means = torch.mean(attn, self.softmax.dim, keepdim=True)[0] + attn_exp = torch.exp(attn - means) + attn_exp_sum = torch.sum(attn_exp, self.softmax.dim, keepdim=True) + attn = attn_exp / attn_exp_sum + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +@FUNCTION_REWRITER.register_rewriter( + func_name= # noqa: E251 + 'mmcls.models.utils.ShiftWindowMSA.forward', + extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0')) +def shift_window_msa__forward__default(ctx, self, query, hw_shape): + if get_dynamic_axes(ctx.cfg) is None: + # avoid the weird bug of torch to onnx + return ctx.origin_func(self, query, hw_shape) + B, L, C = query.shape + H, W = hw_shape + assert L == H * W, f"The query length {L} doesn't match the input "\ + f'shape ({H}, {W}).' + query = query.view(B, H, W, C) + + window_size = self.window_size + shift_size = self.shift_size + + if min(H, W) == window_size: + # If not pad small feature map, avoid shifting when the window size + # is equal to the size of feature map. It's to align with the + # behavior of the original implementation. + shift_size = shift_size if self.pad_small_map else 0 + elif min(H, W) < window_size: + # In the original implementation, the window size will be shrunk + # to the size of feature map. The behavior is different with + # swin-transformer for downstream tasks. To support dynamic input + # shape, we don't allow this feature. + assert self.pad_small_map, \ + f'The input shape ({H}, {W}) is smaller than the window ' \ + f'size ({window_size}). Please set `pad_small_map=True`, or ' \ + 'decrease the `window_size`.' + + # pad feature maps to multiples of window size + query = query.permute(0, 3, 1, 2).contiguous() + # query = torch.nn.ZeroPad2d([0, self.window_size, 0, self.window_size])( + # query) + query = torch.cat([query, query.new_zeros(B, C, H, window_size)], dim=-1) + query = torch.cat( + [query, query.new_zeros(B, C, window_size, query.shape[-1])], dim=-2) + slice_h = (H + window_size - 1) // window_size * window_size + slice_w = (W + window_size - 1) // window_size * window_size + query = query[:, :, :slice_h, :slice_w] + query = query.permute(0, 2, 3, 1).contiguous() + H_pad, W_pad = query.shape[1], query.shape[2] + + # cyclic shift + if shift_size > 0: + query = torch.roll( + query, shifts=(-shift_size, -shift_size), dims=(1, 2)) + + attn_mask = self.get_attn_mask((H_pad, W_pad), + window_size=window_size, + shift_size=shift_size, + device=query.device) + + # nW*B, window_size, window_size, C + query_windows = self.window_partition(query, window_size) + # nW*B, window_size*window_size, C + query_windows = query_windows.view(-1, window_size**2, C) + + # W-MSA/SW-MSA (nW*B, window_size*window_size, C) + attn_windows = self.w_msa(query_windows, mask=attn_mask) + + # merge windows + attn_windows = attn_windows.view(-1, window_size, window_size, C) + + # B H' W' C + shifted_x = self.window_reverse(attn_windows, H_pad, W_pad, window_size) + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2)) + else: + x = shifted_x + + if H != H_pad or W != W_pad: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + x = self.drop(x) + + return x + + +@FUNCTION_REWRITER.register_rewriter( + func_name= # noqa: E251 + 'mmcls.models.utils.ShiftWindowMSA.get_attn_mask', + extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0')) +def shift_window_msa__get_attn_mask__default(ctx, + self, + hw_shape, + window_size, + shift_size, + device=None): + if shift_size > 0: + # calculate attention mask for SW-MSA + w_mask = torch.cat([ + torch.zeros((hw_shape[1] - window_size), + dtype=torch.int64, + device=device), + torch.full((window_size - shift_size, ), 1, device=device), + torch.full((shift_size, ), 2, device=device) + ]) + h_mask = torch.cat([ + torch.zeros((hw_shape[0] - window_size), + dtype=torch.int64, + device=device), + torch.full((window_size - shift_size, ), 3, device=device), + torch.full((shift_size, ), 6, device=device) + ]) + + img_mask = w_mask.unsqueeze(0) + h_mask.unsqueeze(1) + img_mask = img_mask.unsqueeze(0) + img_mask = img_mask.unsqueeze(-1) + # nW, window_size, window_size, 1 + from mmcls.models.utils import ShiftWindowMSA + mask_windows = ShiftWindowMSA.window_partition(img_mask, window_size) + mask_windows = mask_windows.view(-1, window_size * window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0) + attn_mask = attn_mask.masked_fill(attn_mask == 0, 0.0) + else: + attn_mask = None + return attn_mask From b0c3fe917d77131d996de3e1679c0fc49e48fd0a Mon Sep 17 00:00:00 2001 From: AllentDan Date: Tue, 23 Aug 2022 14:53:44 +0800 Subject: [PATCH 6/6] remove useless rewriting logic --- .../codebase/mmcls/models/utils/__init__.py | 5 +- .../codebase/mmcls/models/utils/attention.py | 49 ++++--------------- .../test_mmcls/test_mmcls_models.py | 42 ---------------- 3 files changed, 11 insertions(+), 85 deletions(-) diff --git a/mmdeploy/codebase/mmcls/models/utils/__init__.py b/mmdeploy/codebase/mmcls/models/utils/__init__.py index 2de6b6f67f..3d0a179949 100644 --- a/mmdeploy/codebase/mmcls/models/utils/__init__.py +++ b/mmdeploy/codebase/mmcls/models/utils/__init__.py @@ -1,11 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from .attention import (multiheadattention__forward__ncnn, shift_window_msa__forward__default, - shift_window_msa__get_attn_mask__default, - window_msa__forward__default) + shift_window_msa__get_attn_mask__default) __all__ = [ 'multiheadattention__forward__ncnn', 'shift_window_msa__get_attn_mask__default', - 'shift_window_msa__forward__default', 'window_msa__forward__default' + 'shift_window_msa__forward__default' ] diff --git a/mmdeploy/codebase/mmcls/models/utils/attention.py b/mmdeploy/codebase/mmcls/models/utils/attention.py index f5d65f6a10..0d8a9234a5 100644 --- a/mmdeploy/codebase/mmcls/models/utils/attention.py +++ b/mmdeploy/codebase/mmcls/models/utils/attention.py @@ -49,51 +49,16 @@ def multiheadattention__forward__ncnn(ctx, self, qkv_input): return out -@FUNCTION_REWRITER.register_rewriter( - func_name= # noqa: E251 - 'mmcls.models.utils.attention.WindowMSA.forward', - extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0')) -def window_msa__forward__default(ctx, self, x, mask=None): - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, - C // self.num_heads).permute(2, 0, 3, 1, 4) - # replace the gather operation with the split - q, k, v = [i.squeeze(0) for i in torch.split(qkv, 1, 0)] - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[ - self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], - self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute( - 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, - N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - - means = torch.mean(attn, self.softmax.dim, keepdim=True)[0] - attn_exp = torch.exp(attn - means) - attn_exp_sum = torch.sum(attn_exp, self.softmax.dim, keepdim=True) - attn = attn_exp / attn_exp_sum - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - @FUNCTION_REWRITER.register_rewriter( func_name= # noqa: E251 'mmcls.models.utils.ShiftWindowMSA.forward', extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0')) def shift_window_msa__forward__default(ctx, self, query, hw_shape): + """Rewrite forward function of ShiftWindowMSA class for TensorRT. + + 1. replace dynamic padding with static padding and dynamic slice. + 2. always do slice `x = x[:, :H, :W, :].contiguous()` for stability. + """ if get_dynamic_axes(ctx.cfg) is None: # avoid the weird bug of torch to onnx return ctx.origin_func(self, query, hw_shape) @@ -183,6 +148,10 @@ def shift_window_msa__get_attn_mask__default(ctx, window_size, shift_size, device=None): + """Rewrite get_attn_mask function of ShiftWindowMSA class. + + Replace the loop of setitem with a simpler logic. + """ if shift_size > 0: # calculate attention mask for SW-MSA w_mask = torch.cat([ diff --git a/tests/test_codebase/test_mmcls/test_mmcls_models.py b/tests/test_codebase/test_mmcls/test_mmcls_models.py index 1670600ff1..05318369c7 100644 --- a/tests/test_codebase/test_mmcls/test_mmcls_models.py +++ b/tests/test_codebase/test_mmcls/test_mmcls_models.py @@ -250,48 +250,6 @@ def test_gap__forward(backend_type: Backend, inputs: list): model_output, rewrite_output, rtol=1e-03, atol=1e-05) -@pytest.mark.skipif( - reason='Only support GPU test', condition=not torch.cuda.is_available()) -@pytest.mark.parametrize('backend_type', [(Backend.TENSORRT)]) -def test_windows_msa_cls(backend_type: Backend): - check_backend(backend_type) - from mmcls.models.utils.attention import WindowMSA - model = WindowMSA(96, (7, 7), 3) - model.cuda().eval() - output_names = ['output'] - - deploy_cfg = mmcv.Config( - dict( - backend_config=dict( - type=backend_type.value, - common_config=dict(fp16_mode=True, max_workspace_size=1 << 20), - model_inputs=[ - dict( - input_shapes=dict( - x=dict( - min_shape=[12, 49, 96], - opt_shape=[12, 49, 96], - max_shape=[12, 49, 96]), - mask=dict( - min_shape=[12, 49, 49], - opt_shape=[12, 49, 49], - max_shape=[12, 49, 49]))) - ]), - onnx_config=dict( - input_shape=None, - input_names=['x', 'mask'], - output_names=output_names))) - - x = torch.randn([12, 49, 96]).cuda() - mask = torch.randn([12, 49, 49]).cuda() - wrapped_model = WrapModel(model, 'forward') - rewrite_inputs = {'x': x, 'mask': mask} - _ = get_rewrite_outputs( - wrapped_model=wrapped_model, - model_inputs=rewrite_inputs, - deploy_cfg=deploy_cfg) - - @pytest.mark.skipif( reason='Only support GPU test', condition=not torch.cuda.is_available()) @pytest.mark.parametrize('backend_type', [(Backend.TENSORRT)])