Skip to content

Commit

Permalink
[Feature] Support DCNv1 on Ascend device (#2480)
Browse files Browse the repository at this point in the history
* update lately npu modification--DCNv1

update lately npu modification--DCNv1

* update lately npu modification--DCNv1

* update lately npu modification--DCNv1

* update lately npu modification--DCNv1

* update lately npu modification--DCNv1

* update lately npu modification--DCNv1

* check code

* Add ops to EN/ZH documents
  • Loading branch information
MiniTIckW authored Jan 6, 2023
1 parent a953537 commit f76de90
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ We implement common ops used in detection, segmentation, etc.
| ConvexIoU | || | | |
| CornerPool | || | | |
| Correlation | || | | |
| Deformable Convolution v1/v2 ||| | | |
| Deformable Convolution v1/v2 ||| | | |
| Deformable RoIPool | ||| ||
| DiffIoURotated | || | | |
| DynamicScatter | || | | |
Expand Down
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| ConvexIoU | || | | |
| CornerPool | || | | |
| Correlation | || | | |
| Deformable Convolution v1/v2 ||| | | |
| Deformable Convolution v1/v2 ||| | | |
| Deformable RoIPool | ||| ||
| DiffIoURotated | || | | |
| DynamicScatter | || | | |
Expand Down
28 changes: 28 additions & 0 deletions mmcv/ops/deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from mmcv.utils import deprecated_api_warning
from ..cnn import CONV_LAYERS
from ..utils import ext_loader, print_log
from .modulated_deform_conv import ModulatedDeformConv2dFunction

ext_module = ext_loader.load_ext('_ext', [
'deform_conv_forward', 'deform_conv_backward_input',
Expand Down Expand Up @@ -46,6 +47,23 @@ def symbolic(g,
bias_i=bias,
im2col_step_i=im2col_step)

@staticmethod
def _npu_backward(ctx, grad_output):
input_tensor, weight, offset_out, offset_all, sort_index_for_npu_bp = \
ctx.saved_tensors
grad_input, grad_weight, grad_offset_all, grad_bias = \
torch.npu_deformable_conv2dbk(
input_tensor, grad_output, offset_out, weight, offset_all,
kernel_size=[weight.shape[3], weight.shape[2]],
stride=[1, 1, ctx.stride[0], ctx.stride[1]],
padding=[1, 1, ctx.padding[0], ctx.padding[1]],
dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]],
groups=ctx.groups, deformable_groups=ctx.deform_groups,
modulated=True)
grad_offset = grad_offset_all.index_select(1, sort_index_for_npu_bp)
return grad_input, grad_offset, grad_weight, \
None, None, None, None, None, None, None

@staticmethod
def forward(ctx,
input: Tensor,
Expand All @@ -69,6 +87,7 @@ def forward(ctx,
ctx.groups = groups
ctx.deform_groups = deform_groups
ctx.im2col_step = im2col_step
ctx.device = input.device.type

# When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
# amp won't cast the type of model (float32), but "offset" is cast
Expand All @@ -79,6 +98,13 @@ def forward(ctx,
# whatever the pytorch version is.
input = input.type_as(offset)
weight = weight.type_as(input)
if ctx.device == 'npu':
mask_shape, _ = torch.chunk(offset, 2, dim=1)
mask = torch.ones_like(mask_shape).to(input.device)
bias = input.new_empty(0)
output = ModulatedDeformConv2dFunction._npu_forward(
ctx, input, offset, mask, weight, bias)
return output
ctx.save_for_backward(input, offset, weight)

output = input.new_empty(
Expand Down Expand Up @@ -115,6 +141,8 @@ def backward(
ctx, grad_output: Tensor
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], None,
None, None, None, None, None, None]:
if ctx.device == 'npu':
return DeformConv2dFunction._npu_backward(ctx, grad_output)
input, offset, weight = ctx.saved_tensors

grad_input = grad_offset = grad_weight = None
Expand Down
2 changes: 1 addition & 1 deletion mmcv/ops/modulated_deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _calculate_sort_index(kernel_h, kernel_w, deformable_group):
split_num = deformable_group * 2 * kernel_h * kernel_w
sort_index = list(range(split_num))
sort_index_fp = (sort_index[1::2] + sort_index[::2])
sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index)}
sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index_fp)}
sort_index_bp = [sort_index_bp_dict[i] for i in sort_index]
sort_index_fp = torch.IntTensor(sort_index_fp)
sort_index_bp = torch.IntTensor(sort_index_bp)
Expand Down

0 comments on commit f76de90

Please sign in to comment.