Skip to content

Commit

Permalink
[Feature] Support deform_conv on 2.x branch with cambricon MLU backend
Browse files Browse the repository at this point in the history
  • Loading branch information
qipengh authored and huangqipeng committed Jun 6, 2023
1 parent e8b1250 commit eae6e3e
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 15 deletions.
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ We implement common ops used in detection, segmentation, etc.
| ConvexIoU | || | | |
| CornerPool | || | | |
| Correlation | || | | |
| Deformable Convolution v1/v2 ||| | ||
| Deformable Convolution v1/v2 ||| | ||
| Deformable RoIPool | ||| ||
| DiffIoURotated | || | | |
| DynamicScatter | || | | |
Expand Down
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| ConvexIoU | || | | |
| CornerPool | || | | |
| Correlation | || | | |
| Deformable Convolution v1/v2 ||| | ||
| Deformable Convolution v1/v2 ||| | ||
| Deformable RoIPool | ||| ||
| DiffIoURotated | || | | |
| DynamicScatter | || | | |
Expand Down
3 changes: 2 additions & 1 deletion mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
]

if IS_MLU_AVAILABLE:
from .deform_conv import DeformConv2dPack_MLU # noqa:F401
from .modulated_deform_conv import \
ModulatedDeformConv2dPack_MLU # noqa:F401
__all__.append('ModulatedDeformConv2dPack_MLU')
__all__.extend(['ModulatedDeformConv2dPack_MLU', 'DeformConv2dPack_MLU'])
61 changes: 61 additions & 0 deletions mmcv/ops/deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair, _single

from mmcv.utils import IS_MLU_AVAILABLE
from ..utils import ext_loader
from .modulated_deform_conv import ModulatedDeformConv2dFunction

Expand Down Expand Up @@ -437,3 +438,63 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
error_msgs)


if IS_MLU_AVAILABLE:
import torchvision
from mmengine.utils import digit_version
from torchvision.ops import deform_conv2d as tv_deform_conv2d

@MODELS.register_module('DCN', force=True)
class DeformConv2dPack_MLU(DeformConv2d):
"""This class is the DCN implementation of the MLU device. The MLU
backend support of the operator has been implemented in torchvision.
The mmcv registration mechanism is used for multiplexing here. The
torchvision implementation of DCN is called.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int): Same as nn.Conv2d, while tuple is not supported.
padding (int): Same as nn.Conv2d, while tuple is not supported.
dilation (int): Same as nn.Conv2d, while tuple is not supported.
groups (int): Same as nn.Conv2d.
bias (bool or str): If specified as `auto`, it will be decided by
the norm_cfg. Bias will be set as True if norm_cfg is None,
otherwise False.
im2col_step (int): Number of samples processed by
im2col_cuda_kernel per call. It will work when ``batch_size``
> ``im2col_step``, but ``batch_size`` must be divisible by
``im2col_step``. Default: 32. `New in version 1.7.2.
Currently not supported on MLU devices.`
"""

def __init__(self, *args, **kwargs):
assert digit_version(torchvision.__version__) >= digit_version(
'0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
super().__init__(*args, **kwargs)

self.conv_offset = nn.Conv2d(
self.in_channels,
self.deform_groups * 2 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
dilation=_pair(self.dilation),
bias=True)
self.init_offset()

def init_offset(self):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()

def forward(self, x: Tensor) -> Tensor: # type: ignore
cur_im2col_step = min(self.im2col_step, x.size(0))
assert (x.size(0) % cur_im2col_step
) == 0, 'batch size must be divisible by im2col_step'
offset = self.conv_offset(x)
x = x.type_as(offset)
weight = self.weight.type_as(x)
return tv_deform_conv2d(x, offset, weight, None, self.stride,
self.padding, self.dilation)
41 changes: 29 additions & 12 deletions tests/test_ops/test_deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

from mmcv.utils import IS_MLU_AVAILABLE

try:
# If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
# would be imported and used; we should test if our modules support it.
Expand Down Expand Up @@ -45,7 +47,10 @@ def _test_deformconv(self,
im2col_step=2):
if not torch.cuda.is_available() and device == 'cuda':
pytest.skip('test requires GPU')
from mmcv.ops import DeformConv2dPack
if device == 'mlu':
from mmcv.ops import DeformConv2dPack_MLU as DeformConv2dPack
else:
from mmcv.ops import DeformConv2dPack
c_in = 1
c_out = 1
batch_size = 10
Expand All @@ -69,6 +74,8 @@ def _test_deformconv(self,
torch.Tensor(deform_weight).reshape(1, 1, 2, 2))
if device == 'cuda':
model.cuda()
elif device == 'mlu':
model.mlu()
model.type(dtype)

out = model(x)
Expand Down Expand Up @@ -108,6 +115,7 @@ def _test_deformconv(self,
def _test_amp_deformconv(self,
input_dtype,
threshold=1e-3,
device='cuda',
batch_size=10,
im2col_step=2):
"""The function to test amp released on pytorch 1.6.0.
Expand All @@ -120,15 +128,18 @@ def _test_amp_deformconv(self,
input_dtype: torch.float or torch.half.
threshold: the same as above function.
"""
if not torch.cuda.is_available():
if not torch.cuda.is_available() and device == 'cuda':
return
from mmcv.ops import DeformConv2dPack
if device == 'mlu':
from mmcv.ops import DeformConv2dPack_MLU as DeformConv2dPack
else:
from mmcv.ops import DeformConv2dPack
c_in = 1
c_out = 1
repeated_input = np.repeat(input, batch_size, axis=0)
repeated_gt_out = np.repeat(gt_out, batch_size, axis=0)
repeated_gt_x_grad = np.repeat(gt_x_grad, batch_size, axis=0)
x = torch.Tensor(repeated_input).cuda().type(input_dtype)
x = torch.Tensor(repeated_input).to(device).type(input_dtype)
x.requires_grad = True
model = DeformConv2dPack(
in_channels=c_in,
Expand All @@ -143,7 +154,10 @@ def _test_amp_deformconv(self,
torch.Tensor(offset_bias).reshape(8))
model.weight.data = torch.nn.Parameter(
torch.Tensor(deform_weight).reshape(1, 1, 2, 2))
model.cuda()
if device == 'cuda':
model.cuda()
elif device == 'mlu':
model.mlu()

out = model(x)
out.backward(torch.ones_like(out))
Expand Down Expand Up @@ -180,21 +194,24 @@ def _test_amp_deformconv(self,
def test_deformconv(self):
self._test_deformconv(torch.double, device='cpu')
self._test_deformconv(torch.float, device='cpu', threshold=1e-1)
self._test_deformconv(torch.double)
self._test_deformconv(torch.float)
self._test_deformconv(torch.half, threshold=1e-1)
device = 'mlu' if IS_MLU_AVAILABLE else 'cuda'
self._test_deformconv(torch.double, device=device)
self._test_deformconv(torch.float, device=device)
self._test_deformconv(torch.half, threshold=1e-1, device=device)
# test batch_size < im2col_step
self._test_deformconv(torch.float, batch_size=1, im2col_step=2)
self._test_deformconv(
torch.float, batch_size=1, im2col_step=2, device=device)
# test bach_size % im2col_step != 0
with pytest.raises(
AssertionError,
match='batch size must be divisible by im2col_step'):
self._test_deformconv(torch.float, batch_size=10, im2col_step=3)
self._test_deformconv(
torch.float, batch_size=10, im2col_step=3, device=device)

# test amp when torch version >= '1.6.0', the type of
# input data for deformconv might be torch.float or torch.half
if (TORCH_VERSION != 'parrots'
and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
with autocast(enabled=True):
self._test_amp_deformconv(torch.float, 1e-1)
self._test_amp_deformconv(torch.half, 1e-1)
self._test_amp_deformconv(torch.float, 1e-1, device)
self._test_amp_deformconv(torch.half, 1e-1, device)

0 comments on commit eae6e3e

Please sign in to comment.