Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add support for Ascend devices with nms_rotated #2550

Merged
merged 9 commits into from
Jan 28, 2023
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ We implement common ops used in detection, segmentation, etc.
| ModulatedDeformConv2d | √ | √ | √ | | √ |
| MultiScaleDeformableAttn | | √ | √ | | |
| NMS | √ | √ | √ | | √ |
| NMSRotated | √ | √ | | | |
| NMSRotated | √ | √ | | | |
| NMSQuadri | √ | √ | | | |
| PixelGroup | √ | | | | |
| PointsInBoxes | √ | √ | | | |
Expand Down
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| ModulatedDeformConv2d | √ | √ | √ | | √ |
| MultiScaleDeformableAttn | | √ | √ | | |
| NMS | √ | √ | √ | | √ |
| NMSRotated | √ | √ | | | |
| NMSRotated | √ | √ | | | |
| NMSQuadri | √ | √ | | | |
| PixelGroup | √ | | | | |
| PointsInBoxes | √ | √ | | | |
Expand Down
15 changes: 13 additions & 2 deletions mmcv/ops/csrc/pytorch/nms_rotated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,30 @@ Tensor nms_rotated_cuda(const Tensor dets, const Tensor scores,
const float iou_threshold, const int multi_label);
#endif

#ifdef MMCV_WITH_NPU
Tensor nms_rotated_npu(const Tensor dets, const Tensor scores,
const Tensor labels, const float iou_threshold);
#endif

// Interface for Python
// inline is needed to prevent multiple function definitions when this header is
// included by different cpps
Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
const Tensor dets_sorted, const float iou_threshold,
const int multi_label) {
const Tensor dets_sorted, const Tensor labels,
const float iou_threshold, const int multi_label) {
assert(dets.device().is_cuda() == scores.device().is_cuda());
if (dets.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
return nms_rotated_cuda(dets, scores, order, dets_sorted, iou_threshold,
multi_label);
#else
AT_ERROR("Not compiled with GPU support");
#endif
} else if (dets.device().type() == at::kXLA) {
#ifdef MMCV_WITH_NPU
return nms_rotated_npu(dets, scores, labels, iou_threshold);
#else
AT_ERROR("Not compiled with NPU support");
#endif
}

Expand Down
32 changes: 32 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/nms_rotated_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;

// Rotated-box NMS for Ascend NPU, implemented by launching the "RotatedNMS"
// operator through OpCommand.
//
// Parameters:
//   dets           boxes to filter; converted to float32 before the launch when
//                  stored in another dtype.
//   scores         per-box scores; converted to float32 before the launch when
//                  stored in another dtype.
//   labels         forwarded to the operator unchanged.
//   iou_threshold  forwarded as the operator's "iou_threshold" attribute.
//
// Returns the operator's second output (the selected indices) cast to int64.
Tensor nms_rotated_npu(const Tensor dets, const Tensor scores,
                       const Tensor labels, const float iou_threshold) {
  // Each input is checked against its own dtype. Gating both casts on the
  // dtype of `dets` alone would let a non-float `scores` through uncast
  // whenever `dets` is already float32.
  at::Tensor detsCast = dets;
  if (dets.scalar_type() != at::ScalarType::Float) {
    detsCast = NPUNativeFunctions::npu_dtype_cast(dets, at::kFloat);
  }
  at::Tensor scoresCast = scores;
  if (scores.scalar_type() != at::ScalarType::Float) {
    scoresCast = NPUNativeFunctions::npu_dtype_cast(scores, at::kFloat);
  }

  // Output buffers. The index buffer has one int32 slot per input box.
  // NOTE(review): selectedBox is derived from `dets`, not `detsCast`, so it
  // presumably keeps the caller's dtype even when the inputs were cast --
  // confirm the operator accepts that combination.
  c10::SmallVector<int64_t, SIZE> selectedIndexSize = {dets.size(0)};
  at::Tensor selectedBox = OpPreparation::ApplyTensor(dets);
  at::Tensor selectedIndex = OpPreparation::ApplyTensor(
      selectedIndexSize, dets.options().dtype(at::kInt), dets);

  // Both outputs (positions 0 and 1) are registered with Sync();
  // presumably because their valid extent is only known after the kernel
  // finishes -- TODO confirm against the OpCommand documentation.
  c10::SmallVector<int64_t, N> output_sync_idx = {0, 1};
  OpCommand cmd;
  cmd.Sync(output_sync_idx)
      .Name("RotatedNMS")
      .Input(detsCast)
      .Input(scoresCast)
      .Input(labels)
      .Output(selectedBox)
      .Output(selectedIndex)
      .Attr("iou_threshold", iou_threshold)
      .Run();

  // Only the indices are returned; selectedBox is discarded. The indices are
  // widened from int32 to int64 before returning.
  selectedIndex = NPUNativeFunctions::npu_dtype_cast(selectedIndex, at::kLong);
  return selectedIndex;
}
6 changes: 3 additions & 3 deletions mmcv/ops/csrc/pytorch/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,8 @@ void box_iou_rotated(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned);

Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
const Tensor dets_sorted, const float iou_threshold,
const int multi_label);
const Tensor dets_sorted, const Tensor labels,
const float iou_threshold, const int multi_label);

Tensor upfirdn2d(const Tensor &input, const Tensor &kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0,
Expand Down Expand Up @@ -748,7 +748,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("mode_flag"), py::arg("aligned"));
m.def("nms_rotated", &nms_rotated, "NMS for rotated boxes", py::arg("dets"),
py::arg("scores"), py::arg("order"), py::arg("dets_sorted"),
py::arg("iou_threshold"), py::arg("multi_label"));
py::arg("labels"), py::arg("iou_threshold"), py::arg("multi_label"));
m.def("ball_query_forward", &ball_query_forward, "ball_query_forward",
py::arg("new_xyz_tensor"), py::arg("xyz_tensor"), py::arg("idx_tensor"),
py::arg("b"), py::arg("n"), py::arg("m"), py::arg("min_radius"),
Expand Down
17 changes: 16 additions & 1 deletion mmcv/ops/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,19 @@ def nms_rotated(dets: Tensor,
else:
dets_cw = dets
multi_label = labels is not None
if labels is None:
input_labels = scores.new_empty(0, dtype=torch.int)
else:
input_labels = labels
if dets.device.type == 'npu':
order = scores.new_empty(0, dtype=torch.long)
keep_inds = ext_module.nms_rotated(dets_cw, scores, order, dets_cw,
input_labels, iou_threshold,
multi_label)
dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
dim=1)
return dets, keep_inds

if multi_label:
dets_wl = torch.cat((dets_cw, labels.unsqueeze(1)), 1) # type: ignore
else:
Expand All @@ -467,11 +480,13 @@ def nms_rotated(dets: Tensor,
scores,
order,
dets_sorted,
input_labels,
iou_threshold=iou_threshold,
multi_label=multi_label)
else:
keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
iou_threshold, multi_label)
input_labels, iou_threshold,
multi_label)
dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
dim=1)
return dets, keep_inds
Expand Down
35 changes: 27 additions & 8 deletions tests/test_ops/test_nms_rotated.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,22 @@
import pytest
import torch

from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE


@pytest.mark.skipif(
not torch.cuda.is_available(),
reason='GPU is required to test NMSRotated op')
class TestNmsRotated:

def test_ml_nms_rotated(self):
@pytest.mark.parametrize('device', [
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support')),
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
])
def test_ml_nms_rotated(self, device):
from mmcv.ops import nms_rotated
np_boxes = np.array(
[[6.0, 3.0, 8.0, 7.0, 0.5, 0.7], [3.0, 6.0, 9.0, 11.0, 0.6, 0.8],
Expand All @@ -24,8 +33,8 @@ def test_ml_nms_rotated(self):
dtype=np.float32)
np_expect_keep_inds = np.array([3, 1, 0], dtype=np.int64)

boxes = torch.from_numpy(np_boxes).cuda()
labels = torch.from_numpy(np_labels).cuda()
boxes = torch.from_numpy(np_boxes).to(device)
labels = torch.from_numpy(np_labels).to(device)

# test cw angle definition
dets, keep_inds = nms_rotated(boxes[:, :5], boxes[:, -1], 0.5, labels)
Expand All @@ -41,7 +50,17 @@ def test_ml_nms_rotated(self):
assert np.allclose(dets.cpu().numpy()[:, :5], np_expect_dets)
assert np.allclose(keep_inds.cpu().numpy(), np_expect_keep_inds)

def test_nms_rotated(self):
@pytest.mark.parametrize('device', [
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support')),
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
])
def test_nms_rotated(self, device):
from mmcv.ops import nms_rotated
np_boxes = np.array(
[[6.0, 3.0, 8.0, 7.0, 0.5, 0.7], [3.0, 6.0, 9.0, 11.0, 0.6, 0.8],
Expand All @@ -55,7 +74,7 @@ def test_nms_rotated(self):
dtype=np.float32)
np_expect_keep_inds = np.array([3, 1, 0], dtype=np.int64)

boxes = torch.from_numpy(np_boxes).cuda()
boxes = torch.from_numpy(np_boxes).to(device)

# test cw angle definition
dets, keep_inds = nms_rotated(boxes[:, :5], boxes[:, -1], 0.5)
Expand Down