diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index 262711d21c..9784e459b4 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -18,7 +18,7 @@ We implement common ops used in detection, segmentation, etc. | ConvexIoU | | √ | | | | | CornerPool | | √ | | | | | Correlation | | √ | | | | -| Deformable Convolution v1/v2 | √ | √ | | | | +| Deformable Convolution v1/v2 | √ | √ | | | √ | | Deformable RoIPool | | √ | √ | | √ | | DiffIoURotated | | √ | | | | | DynamicScatter | | √ | | | | diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index a15392e186..715b38a7fc 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -18,7 +18,7 @@ MMCV 提供了检测、分割等任务中常用的算子 | ConvexIoU | | √ | | | | | CornerPool | | √ | | | | | Correlation | | √ | | | | -| Deformable Convolution v1/v2 | √ | √ | | | | +| Deformable Convolution v1/v2 | √ | √ | | | √ | | Deformable RoIPool | | √ | √ | | √ | | DiffIoURotated | | √ | | | | | DynamicScatter | | √ | | | | diff --git a/mmcv/ops/deform_conv.py b/mmcv/ops/deform_conv.py index 85f665cd32..dcc0abb6e7 100644 --- a/mmcv/ops/deform_conv.py +++ b/mmcv/ops/deform_conv.py @@ -12,6 +12,7 @@ from mmcv.utils import deprecated_api_warning from ..cnn import CONV_LAYERS from ..utils import ext_loader, print_log +from .modulated_deform_conv import ModulatedDeformConv2dFunction ext_module = ext_loader.load_ext('_ext', [ 'deform_conv_forward', 'deform_conv_backward_input', @@ -46,6 +47,23 @@ def symbolic(g, bias_i=bias, im2col_step_i=im2col_step) + @staticmethod + def _npu_backward(ctx, grad_output): + input_tensor, weight, offset_out, offset_all, sort_index_for_npu_bp = \ + ctx.saved_tensors + grad_input, grad_weight, grad_offset_all, grad_bias = \ + torch.npu_deformable_conv2dbk( + input_tensor, grad_output, offset_out, weight, offset_all, + kernel_size=[weight.shape[3], weight.shape[2]], + stride=[1, 1, ctx.stride[0], ctx.stride[1]], + padding=[1, 1, ctx.padding[0], ctx.padding[1]], + dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]], + groups=ctx.groups, deformable_groups=ctx.deform_groups, + modulated=True) + grad_offset = grad_offset_all.index_select(1, sort_index_for_npu_bp) + return grad_input, grad_offset, grad_weight, \ + None, None, None, None, None, None, None + @staticmethod def forward(ctx, input: Tensor, @@ -69,6 +87,7 @@ def forward(ctx, ctx.groups = groups ctx.deform_groups = deform_groups ctx.im2col_step = im2col_step + ctx.device = input.device.type # When pytorch version >= 1.6.0, amp is adopted for fp16 mode; # amp won't cast the type of model (float32), but "offset" is cast @@ -79,6 +98,13 @@ def forward(ctx, # whatever the pytorch version is. input = input.type_as(offset) weight = weight.type_as(input) + if ctx.device == 'npu': + mask_shape, _ = torch.chunk(offset, 2, dim=1) + mask = torch.ones_like(mask_shape).to(input.device) + bias = input.new_empty(0) + output = ModulatedDeformConv2dFunction._npu_forward( + ctx, input, offset, mask, weight, bias) + return output ctx.save_for_backward(input, offset, weight) output = input.new_empty( @@ -115,6 +141,8 @@ def backward( ctx, grad_output: Tensor ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], None, None, None, None, None, None, None]: + if ctx.device == 'npu': + return DeformConv2dFunction._npu_backward(ctx, grad_output) input, offset, weight = ctx.saved_tensors grad_input = grad_offset = grad_weight = None diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py index 6a5173cb4f..782e810fc6 100644 --- a/mmcv/ops/modulated_deform_conv.py +++ b/mmcv/ops/modulated_deform_conv.py @@ -39,7 +39,7 @@ def _calculate_sort_index(kernel_h, kernel_w, deformable_group): split_num = deformable_group * 2 * kernel_h * kernel_w sort_index = list(range(split_num)) sort_index_fp = (sort_index[1::2] + sort_index[::2]) - sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index)} + sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index_fp)} sort_index_bp = [sort_index_bp_dict[i] for i in sort_index] sort_index_fp = torch.IntTensor(sort_index_fp) sort_index_bp = torch.IntTensor(sort_index_bp)