Skip to content

Commit

Permalink
Deploy the Swin Transformer on TensorRT. (#652)
Browse files Browse the repository at this point in the history
* resolve conflicts

* update ut and docs

* fix ut

* refine docstring

* add comments and refine UT

* resolve comments

* resolve comments

* update doc

* add roll export

* check backend

* update regression test
  • Loading branch information
AllentDan authored Jun 30, 2022
1 parent 8f1508e commit efd3995
Show file tree
Hide file tree
Showing 13 changed files with 640 additions and 140 deletions.
23 changes: 23 additions & 0 deletions docs/en/03-benchmark/benchmark.md
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,29 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../
<td align="center">-</td>
<td align="center">-</td>
</tr>
<tr>
<td align="center" rowspan="2"><a href="https://github.com/open-mmlab/mmdetection/blob/master/configs/swin/mask_rcnn_swin-t-p4-w7_fpn_1x_coco.py">Swin-Transformer</a></td>
<td align="center" rowspan="2">Instance Segmentation</td>
<td align="center" rowspan="2">COCO2017</td>
<td align="center">box AP</td>
<td align="center">42.7</td>
<td align="center">-</td>
<td align="center">42.7</td>
<td align="center">42.5</td>
<td align="center">37.7</td>
<td align="center">-</td>
<td align="center">-</td>
</tr>
<tr>
<td align="center">mask AP</td>
<td align="center">39.3</td>
<td align="center">-</td>
<td align="center">39.3</td>
<td align="center">39.3</td>
<td align="center">35.4</td>
<td align="center">-</td>
<td align="center">-</td>
</tr>
</tbody>
</table>
</div>
Expand Down
142 changes: 72 additions & 70 deletions docs/en/03-benchmark/supported_models.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions docs/en/04-supported-codebases/mmdet.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ Please refer to [get_started.md](https://github.com/open-mmlab/mmdetection/blob/
| RepPoints | ObjectDetection | N | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/reppoints) |
| Cascade Mask R-CNN | InstanceSegmentation | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
| Mask R-CNN | InstanceSegmentation | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn) |
| Swin Transformer | InstanceSegmentation | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/swin) |
23 changes: 23 additions & 0 deletions docs/zh_cn/03-benchmark/benchmark.md
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,29 @@ GPU: ncnn, TensorRT, PPLNN
<td align="center">-</td>
<td align="center">-</td>
</tr>
<tr>
<td align="center" rowspan="2"><a href="https://github.com/open-mmlab/mmdetection/blob/master/configs/swin/mask_rcnn_swin-t-p4-w7_fpn_1x_coco.py">Swin-Transformer</a></td>
<td align="center" rowspan="2">Instance Segmentation</td>
<td align="center" rowspan="2">COCO2017</td>
<td align="center">box AP</td>
<td align="center">42.7</td>
<td align="center">-</td>
<td align="center">42.7</td>
<td align="center">42.5</td>
<td align="center">37.7</td>
<td align="center">-</td>
<td align="center">-</td>
</tr>
<tr>
<td align="center">mask AP</td>
<td align="center">39.3</td>
<td align="center">-</td>
<td align="center">39.3</td>
<td align="center">39.3</td>
<td align="center">35.4</td>
<td align="center">-</td>
<td align="center">-</td>
</tr>
</tbody>
</table>
</div>
Expand Down
140 changes: 71 additions & 69 deletions docs/zh_cn/03-benchmark/supported_models.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions mmdeploy/codebase/mmdet/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .detectors import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .roi_heads import * # noqa: F401,F403
from .transformer import * # noqa: F401,F403
200 changes: 200 additions & 0 deletions mmdeploy/codebase/mmdet/models/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch

from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import get_common_config


@FUNCTION_REWRITER.register_rewriter(
Expand Down Expand Up @@ -63,3 +64,202 @@ def focus__forward__ncnn(ctx, self, x):
x = x.reshape(_b, c * 4, h // 2, w // 2)

return self.conv(x)


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.swin.WindowMSA.forward',
backend='tensorrt')
def windowmsa__forward__tensorrt(ctx, self, x, mask=None):
"""Rewrite forward function of WindowMSA class for TensorRT.
1. replace Gather operation of qkv with split.
2. replace SoftMax operation with a workaround done by PyTorch.
Args:
x (tensor): input features with shape of (num_windows*B, N, C)
mask (tensor | None, Optional): mask with shape of (num_windows,
Wh*Ww, Wh*Ww), value should be between (-inf, 0].
"""
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
-1).permute(2, 0, 3, 1, 4).contiguous()

# replace the gather operation with the split
q, k, v = [i.squeeze(0) for i in torch.split(qkv, 1, 0)]

q = q * self.scale

attn = (q @ k.transpose(-2, -1))

relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)

if mask is not None:
nW = mask.shape[0]
attn = attn.view(-1, nW, self.num_heads, N,
N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)

# replace softmax with a workaround
# weird bug from TensorRT. softmax cannot be used here for fp32 and it
# can be used in fp16, but softmax fp16 performance is not as good as
# exp and log_softmax. Besides, only the UT of exp and log_softmax passed.
fp16_mode = get_common_config(ctx.cfg).get('fp16_mode', False)
if fp16_mode:
attn = torch.exp(torch.log_softmax(attn, dim=self.softmax.dim))
else:
means = torch.mean(attn, self.softmax.dim, keepdim=True)[0]
attn_exp = torch.exp(attn - means)
attn_exp_sum = torch.sum(attn_exp, self.softmax.dim, keepdim=True)
attn = attn_exp / attn_exp_sum

attn = self.attn_drop(attn)

x = (attn @ v).transpose(1, 2).contiguous().reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.swin.ShiftWindowMSA.window_reverse',
backend='tensorrt')
def shift_window_msa__window_reverse__tensorrt(ctx, self, windows, H, W):
"""Rewrite window_reverse function of ShiftWindowMSA class for TensorRT.
For TensorRT, seems radical shape transformations are not allowed. Replace
them with soft ones.
Args:
windows: (num_windows*B, window_size, window_size, C)
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
window_size = self.window_size
B = int(windows.shape[0] / (H * W / window_size / window_size))

# x = windows.view(B, H // window_size, W // window_size, window_size,
# window_size, -1)
x = windows.view(B, -1, W, window_size, windows.shape[-1])
x = x.view(B, x.shape[1], -1, window_size, window_size, x.shape[-1])
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, x.shape[-1])
return x


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.swin.ShiftWindowMSA.window_partition',
backend='tensorrt')
def shift_window_msa__window_partition__tensorrt(ctx, self, x):
"""Rewrite window_partition function of ShiftWindowMSA class for TensorRT.
For TensorRT, seems radical shape transformations are not allowed. Replace
them with soft ones.
Args:
x: (B, H, W, C)
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
window_size = self.window_size
x = x.view(B, H, -1, window_size, C)
x = x.view(B, -1, window_size, x.shape[-3], window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
windows = windows.view(-1, window_size, window_size, C)
return windows


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.swin.ShiftWindowMSA.forward',
backend='tensorrt')
def shift_window_msa__forward__tensorrt(ctx, self, query, hw_shape):
"""Rewrite forward function of ShiftWindowMSA class for TensorRT.
1. replace dynamic padding with static padding and dynamic slice.
2. always do slice `x = x[:, :H, :W, :].contiguous()` for stability.
"""
B, L, C = query.shape
H, W = hw_shape
assert L == H * W, 'input feature has wrong size'
query = query.view(B, H, W, C)

# pad feature maps to multiples of window size
query = query.permute(0, 3, 1, 2).contiguous()
# query = torch.nn.ZeroPad2d([0, self.window_size, 0, self.window_size])(
# query)
query = torch.cat(
[query, query.new_zeros(B, C, H, self.window_size)], dim=-1)
query = torch.cat(
[query,
query.new_zeros(B, C, self.window_size, query.shape[-1])],
dim=-2)
slice_h = (H + self.window_size - 1) // self.window_size * self.window_size
slice_w = (W + self.window_size - 1) // self.window_size * self.window_size
query = query[:, :, :slice_h, :slice_w]
query = query.permute(0, 2, 3, 1).contiguous()
H_pad, W_pad = query.shape[1], query.shape[2]

# cyclic shift
if self.shift_size > 0:
shifted_query = torch.roll(
query, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))

# calculate attention mask for SW-MSA
w_mask = torch.cat([
shifted_query.new_zeros(W_pad - self.window_size),
shifted_query.new_full((self.window_size - self.shift_size, ), 1),
shifted_query.new_full((self.shift_size, ), 2)
])
h_mask = torch.cat([
shifted_query.new_zeros(H_pad - self.window_size),
shifted_query.new_full((self.window_size - self.shift_size, ), 3),
shifted_query.new_full((self.shift_size, ), 6)
])

img_mask = w_mask.unsqueeze(0) + h_mask.unsqueeze(1)
img_mask = img_mask.unsqueeze(0)
img_mask = img_mask.unsqueeze(-1)

# nW, window_size, window_size, 1
mask_windows = self.window_partition(img_mask)
mask_windows = mask_windows.view(-1,
self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0,
float(-100.0)).masked_fill(
attn_mask == 0, float(0.0))
else:
shifted_query = query
attn_mask = None

# nW*B, window_size, window_size, C
query_windows = self.window_partition(shifted_query)
# nW*B, window_size*window_size, C
query_windows = query_windows.view(-1, self.window_size**2, C)

# W-MSA/SW-MSA (nW*B, window_size*window_size, C)
attn_windows = self.w_msa(query_windows, mask=attn_mask)

# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)

# B H' W' C
shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(
shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x

x = x[:, :H, :W, :].contiguous()

x = x.view(B, H * W, C)

x = self.drop(x)
return x
50 changes: 50 additions & 0 deletions mmdeploy/codebase/mmdet/models/transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.utils.transformer.PatchMerging.forward',
backend='tensorrt')
def patch_merging__forward__tensorrt(ctx, self, x, input_size):
"""Rewrite forward function of PatchMerging class for TensorRT.
In original implementation, mmdet applies nn.unfold to accelerate the
inferece. However, the onnx graph of it can not be parsed correctly by
TensorRT. In mmdeploy, it is replaced.
Args:
x (Tensor): Has shape (B, H*W, C_in).
input_size (tuple[int]): The spatial shape of x, arrange as (H, W).
Default: None.
Returns:
tuple: Contains merged results and its spatial shape.
- x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
- out_size (tuple[int]): Spatial shape of x, arrange as
(Merged_H, Merged_W).
"""
H, W = input_size
B, L, C = x.shape
assert L == H * W, 'input feature has wrong size'
assert H % 2 == 0 and W % 2 == 0, f'x size ({H}*{W}) are not even.'

x = x.view(B, H, W, C)

x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x2 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = x.view(x.shape[0], x.shape[1], 4,
-1).permute(0, 1, 3, 2).reshape(x.shape[0], x.shape[1], -1)
x = self.norm(x) if self.norm else x
x = self.reduction(x)
out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
(self.sampler.kernel_size[0] - 1) -
1) // self.sampler.stride[0] + 1
out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
(self.sampler.kernel_size[1] - 1) -
1) // self.sampler.stride[1] + 1

output_size = (out_h, out_w)
return x, output_size
5 changes: 4 additions & 1 deletion mmdeploy/pytorch/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@
from .layer_norm import layer_norm__ncnn
from .linear import linear__ncnn
from .lstm import generic_rnn__ncnn
from .pad import _prepare_onnx_paddings__tensorrt
from .roll import roll_default
from .squeeze import squeeze__default

__all__ = [
'adaptive_avg_pool1d__default', 'adaptive_avg_pool2d__default',
'adaptive_avg_pool3d__default', 'grid_sampler__default',
'hardsigmoid__default', 'instance_norm__tensorrt', 'generic_rnn__ncnn',
'squeeze__default', 'adaptive_avg_pool2d__ncnn', 'gelu__ncnn',
'layer_norm__ncnn', 'linear__ncnn'
'layer_norm__ncnn', 'linear__ncnn', '_prepare_onnx_paddings__tensorrt',
'roll_default'
]
Loading

0 comments on commit efd3995

Please sign in to comment.