Skip to content

Commit

Permalink
modify with commit suggestion
Browse files Browse the repository at this point in the history
  • Loading branch information
mengpenghui committed Dec 2, 2022
1 parent 2c859b0 commit a7bb388
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 80 deletions.
28 changes: 16 additions & 12 deletions mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import IS_MLU_AVAILABLE
from .active_rotated_filter import active_rotated_filter
from .assign_score_withk import assign_score_withk
from .ball_query import ball_query
Expand Down Expand Up @@ -37,7 +38,6 @@
from .min_area_polygons import min_area_polygons
from .modulated_deform_conv import (ModulatedDeformConv2d,
ModulatedDeformConv2dPack,
ModulatedDeformConv2dPack_MLU,
modulated_deform_conv2d)
from .multi_scale_deform_attn import MultiScaleDeformableAttention
from .nms import batched_nms, nms, nms_match, nms_quadri, nms_rotated, soft_nms
Expand Down Expand Up @@ -81,17 +81,17 @@
'get_compiler_version', 'get_compiling_cuda_version',
'get_onnxruntime_op_path', 'MaskedConv2d', 'masked_conv2d',
'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack',
'ModulatedDeformConv2dPack_MLU', 'modulated_deform_conv2d', 'batched_nms',
'nms', 'soft_nms', 'nms_match', 'RoIAlign', 'roi_align', 'RoIPool',
'roi_pool', 'SyncBatchNorm', 'Conv2d', 'ConvTranspose2d', 'Linear',
'MaxPool2d', 'CrissCrossAttention', 'PSAMask', 'point_sample',
'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign', 'SAConv2d', 'TINShift',
'tin_shift', 'assign_score_withk', 'box_iou_rotated', 'box_iou_quadri',
'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query', 'upfirdn2d',
'FusedBiasLeakyReLU', 'fused_bias_leakyrelu', 'rotated_feature_align',
'RiRoIAlignRotated', 'riroi_align_rotated', 'RoIAlignRotated',
'roi_align_rotated', 'pixel_group', 'QueryAndGroup', 'GroupAll',
'grouping_operation', 'contour_expand', 'three_nn', 'three_interpolate',
'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match',
'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk',
'box_iou_rotated', 'box_iou_quadri', 'RoIPointPool3d', 'nms_rotated',
'knn', 'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU',
'fused_bias_leakyrelu', 'rotated_feature_align', 'RiRoIAlignRotated',
'riroi_align_rotated', 'RoIAlignRotated', 'roi_align_rotated',
'pixel_group', 'QueryAndGroup', 'GroupAll', 'grouping_operation',
'contour_expand', 'three_nn', 'three_interpolate',
'MultiScaleDeformableAttention', 'BorderAlign', 'border_align',
'gather_points', 'furthest_point_sample', 'nms_quadri',
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
Expand All @@ -107,3 +107,7 @@
'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance',
'PrRoIPool', 'prroi_pool'
]

if IS_MLU_AVAILABLE:
from .modulated_deform_conv import ModulatedDeformConv2dPack_MLU # noqa:F401
__all__.append('ModulatedDeformConv2dPack_MLU')
142 changes: 80 additions & 62 deletions mmcv/ops/modulated_deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,67 +357,85 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
error_msgs)


@CONV_LAYERS.register_module(
'DCNv2' if IS_MLU_AVAILABLE else 'DCNv2_disabled',
force=True if IS_MLU_AVAILABLE else False)
class ModulatedDeformConv2dPack_MLU(nn.modules.Module):

def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int]],
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
deform_groups: int = 1,
bias: Union[bool, str] = True):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.deform_groups = deform_groups
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels, *self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
bias=True)
self.init_weights()
if IS_MLU_AVAILABLE:

def init_weights(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.zero_()
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
@CONV_LAYERS.register_module('DCNv2', force=True)
class ModulatedDeformConv2dPack_MLU(nn.modules.Module):
"""This class is the DCNv2 implementation of the MLU device. The MLU
backend support of the operator has been implemented in torchvision.
The mmcv registration mechanism is used for multiplexing here. The
torchvision implementation of DCNv2 is called.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int): Same as nn.Conv2d, while tuple is not supported.
padding (int): Same as nn.Conv2d, while tuple is not supported.
dilation (int): Same as nn.Conv2d, while tuple is not supported.
groups (int): Same as nn.Conv2d.
bias (bool or str): If specified as `auto`, it will be decided by
the norm_cfg. Bias will be set as True if norm_cfg is None,
otherwise False.
"""

def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int]],
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
deform_groups: int = 1,
bias: Union[bool, str] = True):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.deform_groups = deform_groups
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels, *self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deform_groups * 3 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
bias=True)
self.init_weights()

def init_weights(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.zero_()
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()

def forward(self, x):
out = self.conv_offset(x)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return tv_deform_conv2d(
x,
offset,
self.weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
mask=mask)
def forward(self, x):
out = self.conv_offset(x)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return tv_deform_conv2d(
x,
offset,
self.weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
mask=mask)
12 changes: 6 additions & 6 deletions tests/test_ops/test_modulated_deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@

class TestMdconv:

def _test_mdconv(self,
dtype=torch.float,
device='mlu' if IS_MLU_AVAILABLE else 'cuda'):
def _test_mdconv(self, dtype=torch.float, device='cuda'):
if not torch.cuda.is_available() and device == 'cuda':
pytest.skip('test requires GPU')
if device == 'mlu':
Expand Down Expand Up @@ -117,9 +115,11 @@ def _test_amp_mdconv(self, input_dtype=torch.float):
def test_mdconv(self):
self._test_mdconv(torch.double, device='cpu')
self._test_mdconv(torch.float, device='cpu')
self._test_mdconv(torch.double)
self._test_mdconv(torch.float)
self._test_mdconv(torch.half)

device = 'mlu' if IS_MLU_AVAILABLE else 'cuda'
self._test_mdconv(torch.double, device=device)
self._test_mdconv(torch.float, device=device)
self._test_mdconv(torch.half, device=device)

# test amp when torch version >= '1.6.0', the type of
# input data for mdconv might be torch.float or torch.half
Expand Down

0 comments on commit a7bb388

Please sign in to comment.