Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add support of points_in_polyogns for Ascend device #2864

Merged
merged 1 commit into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ We implement common ops used in detection, segmentation, etc.
| NMSQuadri | √ | √ | | | |
| PixelGroup | √ | | | | |
| PointsInBoxes | √ | √ | | | |
| PointsInPolygons | | √ | | | |
| PointsInPolygons | | √ | | | |
| PSAMask | √ | √ | √ | | √ |
| RotatedFeatureAlign | √ | √ | √ | | |
| RoIPointPool3d | | √ | √ | | |
Expand Down
27 changes: 27 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/points_in_polygons_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

constexpr int32_t MAX_POLYGONS_BATCH = 2800;

void points_in_polygons_npu(const Tensor points, Tensor polygons, Tensor output,
const int rows, const int cols) {
TORCH_CHECK(
(polygons.sizes()[0] <= MAX_POLYGONS_BATCH),
"The batch of polygons tensor must be less than MAX_POLYGONS_BATCH");
at::Tensor trans_polygons = polygons.transpose(0, 1);
OpCommand cmd;
at::Tensor new_trans_polygons = NpuUtils::format_contiguous(trans_polygons);
cmd.Name("PointsInPolygons")
.Input(points, (string) "points")
.Input(new_trans_polygons, (string) "polygons")
.Output(output)
.Run();
}

void points_in_polygons_forward_impl(const Tensor points, Tensor polygons,
Tensor output, const int rows,
const int cols);

REGISTER_NPU_IMPL(points_in_polygons_forward_impl, points_in_polygons_npu);
7 changes: 5 additions & 2 deletions mmcv/ops/points_in_polygons.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,11 @@ def points_in_polygons(points: Tensor, polygons: Tensor) -> Tensor:
assert polygons.shape[1] == 8, \
'polygons dimension should be 8, ' \
f'but got unexpected shape {polygons.shape[1]}'
output = torch.full([points.shape[0], polygons.shape[0]],
0.).cuda().float()
output = torch.zeros(
points.shape[0],
polygons.shape[0],
dtype=torch.float32,
device=points.device)
ext_module.points_in_polygons_forward(points.contiguous(),
polygons.contiguous(), output)
return output
27 changes: 18 additions & 9 deletions tests/test_ops/test_points_in_polygons.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,29 @@
import torch

from mmcv.ops import points_in_polygons
from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE


@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_points_in_polygons():
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
def test_points_in_polygons(device):
points = np.array([[300., 300.], [400., 400.], [100., 100], [300, 250],
[100, 0]])
polygons = np.array([[200., 200., 400., 400., 500., 200., 400., 100.],
[400., 400., 500., 500., 600., 300., 500., 200.],
[300., 300., 600., 700., 700., 700., 700., 100.]])
expected_output = np.array([[0., 0., 0.], [0., 0., 1.], [0., 0., 0.],
[1., 0., 0.], [0., 0., 0.]])
points = torch.from_numpy(points).cuda().float()
polygons = torch.from_numpy(polygons).cuda().float()
expected_output = torch.from_numpy(expected_output).cuda().float()
assert torch.allclose(
points_in_polygons(points, polygons), expected_output, 1e-3)
[1., 0., 0.], [0., 0., 0.]]).astype(np.float32)
points = torch.tensor(points, dtype=torch.float32, device=device)
polygons = torch.tensor(polygons, dtype=torch.float32, device=device)
assert np.allclose(
points_in_polygons(points, polygons).cpu().numpy(), expected_output,
1e-3)