Skip to content

Commit

Permalink
add a new aicore operator 'points_in_polyogns'
Browse files Browse the repository at this point in the history
  • Loading branch information
long11111111 committed Jun 26, 2023
1 parent 4f51a93 commit f4037c3
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ We implement common ops used in detection, segmentation, etc.
| NMSQuadri ||| | | |
| PixelGroup || | | | |
| PointsInBoxes ||| | | |
| PointsInPolygons | || | | |
| PointsInPolygons | || | | |
| PSAMask |||| ||
| RotatedFeatureAlign |||| | |
| RoIPointPool3d | ||| | |
Expand Down
25 changes: 25 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/points_in_polygons_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

constexpr int32_t MAX_POLYGONS_BATCH = 2800;


void points_in_polygons_npu(const Tensor points, Tensor polygons, Tensor output, const int rows,
const int cols) {
TORCH_CHECK((polygons.sizes()[0] <= MAX_POLYGONS_BATCH),
"The batch of polygons tensor must be less than MAX_POLYGONS_BATCH");
at::Tensor trans_polygons = polygons.transpose(0, 1);
OpCommand cmd;
at::Tensor new_trans_polygons = NpuUtils::format_contiguous(trans_polygons);
cmd.Name("PointsInPolygons")
.Input(points, (string)"points")
.Input(new_trans_polygons, (string)"polygons")
.Output(output)
.Run();
}

void points_in_polygons_forward_impl(const Tensor points, Tensor polygons, Tensor output, const int rows, const int cols);

REGISTER_NPU_IMPL(points_in_polygons_forward_impl, points_in_polygons_npu);
14 changes: 10 additions & 4 deletions mmcv/ops/points_in_polygons.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,14 @@ def points_in_polygons(points: Tensor, polygons: Tensor) -> Tensor:
assert polygons.shape[1] == 8, \
'polygons dimension should be 8, ' \
f'but got unexpected shape {polygons.shape[1]}'
output = torch.full([points.shape[0], polygons.shape[0]],
0.).cuda().float()
ext_module.points_in_polygons_forward(points.contiguous(),
polygons.contiguous(), output)
if points.device.type == "npu":
output = torch.full([points.shape[0], polygons.shape[0]],
0.).to(torch.float32).npu()
ext_module.points_in_polygons_forward(points.contiguous(),
polygons.contiguous(), output)
else:
output = torch.full([points.shape[0], polygons.shape[0]],
0.).cuda().float()
ext_module.points_in_polygons_forward(points.contiguous(),
polygons.contiguous(), output)
return output
31 changes: 31 additions & 0 deletions tests/test_ops/test_points_in_polygons.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch

from mmcv.ops import points_in_polygons
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE


@pytest.mark.skipif(
Expand All @@ -21,3 +22,33 @@ def test_points_in_polygons():
expected_output = torch.from_numpy(expected_output).cuda().float()
assert torch.allclose(
points_in_polygons(points, polygons), expected_output, 1e-3)

@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
def test_points_in_polygons_device(device):
points = np.array([[300., 300.], [400., 400.], [100., 100], [300, 250],
[100, 0]])
polygons = np.array([[200., 200., 400., 400., 500., 200., 400., 100.],
[400., 400., 500., 500., 600., 300., 500., 200.],
[300., 300., 600., 700., 700., 700., 700., 100.]])
expected_output = np.array([[0., 0., 0.], [0., 0., 1.], [0., 0., 0.],
[1., 0., 0.], [0., 0., 0.]])
points = torch.from_numpy(points).to(torch.float32)
polygons = torch.from_numpy(polygons).to(torch.float32)
points = points.to(device)
polygons = polygons.to(device)
expected_output = torch.from_numpy(expected_output).to(torch.float32)
res = points_in_polygons(points, polygons).cpu()
assert np.allclose(res, expected_output, atol=1e-4)

0 comments on commit f4037c3

Please sign in to comment.