forked from open-mmlab/mmcv
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Add support for Ascend devices with nms_rotated (open-mmlab…
…#2550) * [Feature]: add nms_rotated npu adaptater code * [BugFix]: modify param in nms_rotated_npu.cpp * [clean code]: nms_rotated_npu.cpp * [clean code]: nms_rotated_npu.cpp * [clean code]: nms_rotated_npu.cpp * [clean code]: nms_rotated.cpp * [Doc]: add nms_rotated op in supported op list at ops.md * [Test]: add nms_rotated unit_test * [Bug]: remove device parameter in test_batched_nms function
- Loading branch information
1 parent
7e8bec1
commit a081009
Showing
7 changed files
with
93 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
#include "pytorch_npu_helper.hpp" | ||
|
||
using namespace NPU_NAME_SPACE; | ||
|
||
Tensor nms_rotated_npu(const Tensor dets, const Tensor scores, | ||
const Tensor labels, const float iou_threshold) { | ||
auto originDtype = dets.scalar_type(); | ||
at::Tensor detsCast = dets; | ||
at::Tensor scoresCast = scores; | ||
if (originDtype != at::ScalarType::Float) { | ||
detsCast = NPUNativeFunctions::npu_dtype_cast(dets, at::kFloat); | ||
scoresCast = NPUNativeFunctions::npu_dtype_cast(scores, at::kFloat); | ||
} | ||
c10::SmallVector<int64_t, SIZE> selectedIndexSize = {dets.size(0)}; | ||
at::Tensor selectedBox = OpPreparation::ApplyTensor(dets); | ||
at::Tensor selectedIndex = OpPreparation::ApplyTensor( | ||
selectedIndexSize, dets.options().dtype(at::kInt), dets); | ||
|
||
c10::SmallVector<int64_t, N> output_sync_idx = {0, 1}; | ||
OpCommand cmd; | ||
cmd.Sync(output_sync_idx) | ||
.Name("RotatedNMS") | ||
.Input(detsCast) | ||
.Input(scoresCast) | ||
.Input(labels) | ||
.Output(selectedBox) | ||
.Output(selectedIndex) | ||
.Attr("iou_threshold", (float)iou_threshold) | ||
.Run(); | ||
selectedIndex = NPUNativeFunctions::npu_dtype_cast(selectedIndex, at::kLong); | ||
return selectedIndex; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters