Skip to content

Commit

Permalink
[Feature] Support modulated_deform_conv and deform_conv with cambrico…
Browse files Browse the repository at this point in the history
…n MLU backend (#2823)
  • Loading branch information
qipengh authored Jul 27, 2023
1 parent 987d34b commit 86a38aa
Show file tree
Hide file tree
Showing 8 changed files with 255 additions and 34 deletions.
4 changes: 2 additions & 2 deletions docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ We implement common ops used in detection, segmentation, etc.
| ConvexIoU | || | | |
| CornerPool | || | | |
| Correlation | || | | |
| Deformable Convolution v1/v2 ||| | ||
| Deformable Convolution v1/v2 ||| | ||
| Deformable RoIPool | ||| ||
| DiffIoURotated | ||| | |
| DynamicScatter | ||| | |
Expand All @@ -32,7 +32,7 @@ We implement common ops used in detection, segmentation, etc.
| MaskedConv | ||| ||
| MergeCells | || | | |
| MinAreaPolygon | || | | |
| ModulatedDeformConv2d ||| | ||
| ModulatedDeformConv2d ||| | ||
| MultiScaleDeformableAttn | ||| | |
| NMS |||| ||
| NMSRotated |||| ||
Expand Down
4 changes: 2 additions & 2 deletions docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| ConvexIoU | || | | |
| CornerPool | || | | |
| Correlation | || | | |
| Deformable Convolution v1/v2 ||| | ||
| Deformable Convolution v1/v2 ||| | ||
| Deformable RoIPool | ||| ||
| DiffIoURotated | ||| | |
| DynamicScatter | ||| | |
Expand All @@ -32,7 +32,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| MaskedConv | ||| ||
| MergeCells | || | | |
| MinAreaPolygon | || | | |
| ModulatedDeformConv2d ||| | ||
| ModulatedDeformConv2d ||| | ||
| MultiScaleDeformableAttn | ||| | |
| NMS |||| ||
| NMSRotated |||| ||
Expand Down
7 changes: 7 additions & 0 deletions mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import IS_MLU_AVAILABLE
from .active_rotated_filter import active_rotated_filter
from .assign_score_withk import assign_score_withk
from .ball_query import ball_query
Expand Down Expand Up @@ -109,3 +110,9 @@
'PrRoIPool', 'prroi_pool', 'bias_act', 'filtered_lrelu', 'conv2d',
'conv_transpose2d', 'filter2d', 'upsample2d', 'BezierAlign', 'bezier_align'
]

if IS_MLU_AVAILABLE:
from .deform_conv import DeformConv2dPack_MLU # noqa:F401
from .modulated_deform_conv import \
ModulatedDeformConv2dPack_MLU # noqa:F401
__all__.extend(['ModulatedDeformConv2dPack_MLU', 'DeformConv2dPack_MLU'])
61 changes: 61 additions & 0 deletions mmcv/ops/deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair, _single

from mmcv.utils import IS_MLU_AVAILABLE
from ..utils import ext_loader
from .modulated_deform_conv import ModulatedDeformConv2dFunction

Expand Down Expand Up @@ -438,3 +439,63 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
error_msgs)


if IS_MLU_AVAILABLE:
import torchvision
from mmengine.utils import digit_version
from torchvision.ops import deform_conv2d as tv_deform_conv2d

@MODELS.register_module('DCN', force=True)
class DeformConv2dPack_MLU(DeformConv2d):
"""This class is the DCN implementation of the MLU device. The MLU
backend support of the operator has been implemented in torchvision.
The mmcv registration mechanism is used for multiplexing here. The
torchvision implementation of DCN is called.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int): Same as nn.Conv2d, while tuple is not supported.
padding (int): Same as nn.Conv2d, while tuple is not supported.
dilation (int): Same as nn.Conv2d, while tuple is not supported.
groups (int): Same as nn.Conv2d.
bias (bool or str): If specified as `auto`, it will be decided by
the norm_cfg. Bias will be set as True if norm_cfg is None,
otherwise False.
im2col_step (int): Number of samples processed by
im2col_cuda_kernel per call. It will work when ``batch_size``
> ``im2col_step``, but ``batch_size`` must be divisible by
``im2col_step``. Default: 32. `New in version 1.7.2.
Currently not supported on MLU devices.`
"""

def __init__(self, *args, **kwargs):
assert digit_version(torchvision.__version__) >= digit_version(
'0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
super().__init__(*args, **kwargs)

self.conv_offset = nn.Conv2d(
self.in_channels,
self.deform_groups * 2 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
dilation=_pair(self.dilation),
bias=True)
self.init_offset()

def init_offset(self):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()

def forward(self, x: Tensor) -> Tensor: # type: ignore
cur_im2col_step = min(self.im2col_step, x.size(0))
assert (x.size(0) % cur_im2col_step
) == 0, 'batch size must be divisible by im2col_step'
offset = self.conv_offset(x)
x = x.type_as(offset)
weight = self.weight.type_as(x)
return tv_deform_conv2d(x, offset, weight, None, self.stride,
self.padding, self.dilation)
66 changes: 66 additions & 0 deletions mmcv/ops/modulated_deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair, _single

from mmcv.utils import IS_MLU_AVAILABLE
from ..utils import ext_loader

ext_module = ext_loader.load_ext(
Expand Down Expand Up @@ -358,3 +359,68 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
error_msgs)


if IS_MLU_AVAILABLE:
import torchvision
from mmengine.utils import digit_version
from torchvision.ops import deform_conv2d as tv_deform_conv2d

@MODELS.register_module('DCNv2', force=True)
class ModulatedDeformConv2dPack_MLU(ModulatedDeformConv2d):
"""This class is the DCNv2 implementation of the MLU device.
The MLU backend support of the operator has been implemented
in torchvision. The mmcv registration mechanism is used for
multiplexing here. The torchvision implementation of DCNv2 is called.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int): Same as nn.Conv2d, while tuple is not supported.
padding (int): Same as nn.Conv2d, while tuple is not supported.
dilation (int): Same as nn.Conv2d, while tuple is not supported.
groups (int): Same as nn.Conv2d.
bias (bool or str): If specified as `auto`, it will be decided by
the norm_cfg. Bias will be set as True if norm_cfg is None,
otherwise False.
"""

def __init__(self, *args, **kwargs):
assert digit_version(torchvision.__version__) >= digit_version(
'0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
super().__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deform_groups * 3 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
bias=True)
self.init_weights()

def init_weights(self):
super().init_weights()
if hasattr(self, 'conv_offset'):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()

def forward(self, x):
out = self.conv_offset(x)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
x = x.type_as(offset)
weight = self.weight.type_as(x)
mask = mask.type_as(x)
return tv_deform_conv2d(
x,
offset,
weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
mask=mask)
88 changes: 73 additions & 15 deletions tests/test_ops/test_deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE

if IS_MLU_AVAILABLE:
torch.backends.cnnl.allow_tf32 = False

try:
# If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
# would be imported and used; we should test if our modules support it.
Expand Down Expand Up @@ -45,7 +50,10 @@ def _test_deformconv(self,
im2col_step=2):
if not torch.cuda.is_available() and device == 'cuda':
pytest.skip('test requires GPU')
from mmcv.ops import DeformConv2dPack
if device == 'mlu':
from mmcv.ops import DeformConv2dPack_MLU as DeformConv2dPack
else:
from mmcv.ops import DeformConv2dPack
c_in = 1
c_out = 1
batch_size = 10
Expand All @@ -69,6 +77,8 @@ def _test_deformconv(self,
torch.Tensor(deform_weight).reshape(1, 1, 2, 2))
if device == 'cuda':
model.cuda()
elif device == 'mlu':
model.mlu()
model.type(dtype)

out = model(x)
Expand Down Expand Up @@ -108,6 +118,7 @@ def _test_deformconv(self,
def _test_amp_deformconv(self,
input_dtype,
threshold=1e-3,
device='cuda',
batch_size=10,
im2col_step=2):
"""The function to test amp released on pytorch 1.6.0.
Expand All @@ -120,15 +131,18 @@ def _test_amp_deformconv(self,
input_dtype: torch.float or torch.half.
threshold: the same as above function.
"""
if not torch.cuda.is_available():
if not torch.cuda.is_available() and device == 'cuda':
return
from mmcv.ops import DeformConv2dPack
if device == 'mlu':
from mmcv.ops import DeformConv2dPack_MLU as DeformConv2dPack
else:
from mmcv.ops import DeformConv2dPack
c_in = 1
c_out = 1
repeated_input = np.repeat(input, batch_size, axis=0)
repeated_gt_out = np.repeat(gt_out, batch_size, axis=0)
repeated_gt_x_grad = np.repeat(gt_x_grad, batch_size, axis=0)
x = torch.Tensor(repeated_input).cuda().type(input_dtype)
x = torch.Tensor(repeated_input).to(device).type(input_dtype)
x.requires_grad = True
model = DeformConv2dPack(
in_channels=c_in,
Expand All @@ -143,7 +157,10 @@ def _test_amp_deformconv(self,
torch.Tensor(offset_bias).reshape(8))
model.weight.data = torch.nn.Parameter(
torch.Tensor(deform_weight).reshape(1, 1, 2, 2))
model.cuda()
if device == 'cuda':
model.cuda()
elif device == 'mlu':
model.mlu()

out = model(x)
out.backward(torch.ones_like(out))
Expand Down Expand Up @@ -177,24 +194,65 @@ def _test_amp_deformconv(self,
with pytest.raises(AssertionError):
model = DeformConv2d(3, 4, 3, groups=3)

def test_deformconv(self):
self._test_deformconv(torch.double, device='cpu')
self._test_deformconv(torch.float, device='cpu', threshold=1e-1)
self._test_deformconv(torch.double)
self._test_deformconv(torch.float)
self._test_deformconv(torch.half, threshold=1e-1)
@pytest.mark.parametrize('device, threshold', [
('cpu', 1e-1),
pytest.param(
'cuda',
1e-3,
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
1e-3,
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support')),
])
def test_deformconv_float(self, device, threshold):
self._test_deformconv(torch.float, device=device, threshold=threshold)
# test batch_size < im2col_step
self._test_deformconv(torch.float, batch_size=1, im2col_step=2)
self._test_deformconv(
torch.float, batch_size=1, im2col_step=2, device=device)
# test bach_size % im2col_step != 0
with pytest.raises(
AssertionError,
match='batch size must be divisible by im2col_step'):
self._test_deformconv(torch.float, batch_size=10, im2col_step=3)
self._test_deformconv(
torch.float, batch_size=10, im2col_step=3, device=device)

@pytest.mark.parametrize('device', [
'cpu',
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support')),
])
def test_deformconv_double(self, device):
self._test_deformconv(torch.double, device=device)

@pytest.mark.parametrize('device, threshold', [
pytest.param(
'cuda',
1e-1,
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
1e-1,
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support')),
])
def test_deformconv_half(self, device, threshold):
self._test_deformconv(torch.half, device=device, threshold=threshold)
# test amp when torch version >= '1.6.0', the type of
# input data for deformconv might be torch.float or torch.half
if (TORCH_VERSION != 'parrots'
and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
with autocast(enabled=True):
self._test_amp_deformconv(torch.float, 1e-1)
self._test_amp_deformconv(torch.half, 1e-1)
self._test_amp_deformconv(
torch.float, device=device, threshold=threshold)
self._test_amp_deformconv(
torch.half, device=device, threshold=threshold)
4 changes: 4 additions & 0 deletions tests/test_ops/test_masked_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE

if IS_MLU_AVAILABLE:
torch.backends.cnnl.allow_tf32 = False
torch.backends.mlu.matmul.allow_tf32 = False


class TestMaskedConv2d:

Expand Down
Loading

0 comments on commit 86a38aa

Please sign in to comment.