From 38846d0ce1b2ac71272c38a26e1b37d7a27e92ba Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Thu, 21 Apr 2022 15:07:36 +0300 Subject: [PATCH 1/4] add iou3d --- docs/en/understand_mmcv/ops.md | 1 + docs/zh_cn/understand_mmcv/ops.md | 1 + mmcv/ops/__init__.py | 20 +- .../csrc/common/cuda/iou3d_cuda_kernel.cuh | 173 +++++++++--------- mmcv/ops/csrc/pytorch/cuda/cudabind.cpp | 101 +++++----- mmcv/ops/csrc/pytorch/cuda/iou3d_cuda.cu | 51 ++---- mmcv/ops/csrc/pytorch/iou3d.cpp | 66 +++---- mmcv/ops/csrc/pytorch/pybind.cpp | 26 +-- mmcv/ops/iou3d.py | 66 +++---- tests/test_ops/test_iou3d.py | 83 ++++++--- 10 files changed, 270 insertions(+), 318 deletions(-) diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index f100558fc4..e07df3e43b 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -24,6 +24,7 @@ We implement common CUDA ops used in detection, segmentation, etc. - MaskedConv - MinAreaPolygon - NMS +- NMS3D - PointsInPolygons - PSAMask - RiRoIAlignRotated diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index 776b05536f..199de199e0 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -23,6 +23,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子 - MaskedConv - MinAreaPolygon - NMS +- NMS3D - PointsInPolygons - PSAMask - RotatedFeatureAlign diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py index fd112f049d..2bbf9096f0 100644 --- a/mmcv/ops/__init__.py +++ b/mmcv/ops/__init__.py @@ -28,7 +28,7 @@ from .group_points import GroupAll, QueryAndGroup, grouping_operation from .info import (get_compiler_version, get_compiling_cuda_version, get_onnxruntime_op_path) -from .iou3d import boxes_iou_bev, nms_bev, nms_normal_bev +from .iou3d import boxes_iou3d, nms3d, nms3d_normal from .knn import knn from .masked_conv import MaskedConv2d, masked_conv2d from .min_area_polygons import min_area_polygons @@ -89,13 +89,13 @@ 'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign', 'border_align', 'gather_points', 'furthest_point_sample', 'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation', - 'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization', - 'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d', - 'SparseConv2d', 'SparseConv3d', 'SparseConvTranspose2d', - 'SparseConvTranspose3d', 'SparseInverseConv2d', 'SparseInverseConv3d', - 'SubMConv2d', 'SubMConv3d', 'SparseModule', 'SparseSequential', - 'SparseMaxPool2d', 'SparseMaxPool3d', 'SparseConvTensor', 'scatter_nd', - 'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all', - 'points_in_polygons', 'min_area_polygons', 'active_rotated_filter', - 'convex_iou', 'convex_giou', 'diff_iou_rotated_2d', 'diff_iou_rotated_3d' + 'boxes_iou3d', 'nms3d', 'nms3d_normal', 'Voxelization', 'voxelization', + 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d', 'SparseConv2d', + 'SparseConv3d', 'SparseConvTranspose2d', 'SparseConvTranspose3d', + 'SparseInverseConv2d', 'SparseInverseConv3d', 'SubMConv2d', 'SubMConv3d', + 'SparseModule', 'SparseSequential', 'SparseMaxPool2d', 'SparseMaxPool3d', + 'SparseConvTensor', 'scatter_nd', 'points_in_boxes_part', + 'points_in_boxes_cpu', 'points_in_boxes_all', 'points_in_polygons', + 'min_area_polygons', 'active_rotated_filter', 'convex_iou', 'convex_giou', + 'diff_iou_rotated_2d', 'diff_iou_rotated_3d' ] diff --git a/mmcv/ops/csrc/common/cuda/iou3d_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/iou3d_cuda_kernel.cuh index c08dffb3a3..cd733e770e 100644 --- a/mmcv/ops/csrc/common/cuda/iou3d_cuda_kernel.cuh +++ b/mmcv/ops/csrc/common/cuda/iou3d_cuda_kernel.cuh @@ -50,21 +50,17 @@ __device__ int check_rect_cross(const Point &p1, const Point &p2, } __device__ inline int check_in_box2d(const float *box, const Point &p) { - // params: box (5) [x1, y1, x2, y2, angle] - const float MARGIN = 1e-5; - - float center_x = (box[0] + box[2]) / 2; - float center_y = (box[1] + box[3]) / 2; - float angle_cos = cos(-box[4]), - angle_sin = - sin(-box[4]); // rotate the point in the opposite direction of box - float rot_x = - (p.x - center_x) * angle_cos - (p.y - center_y) * angle_sin + center_x; - float rot_y = - (p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos + center_y; - - return (rot_x > box[0] - MARGIN && rot_x < box[2] + MARGIN && - rot_y > box[1] - MARGIN && rot_y < box[3] + MARGIN); + // params: box (7) [x, y, z, dx, dy, dz, heading] + const float MARGIN = 1e-2; + + float center_x = box[0], center_y = box[1]; + // rotate the point in the opposite direction of box + float angle_cos = cos(-box[6]), angle_sin = sin(-box[6]); + float rot_x = (p.x - center_x) * angle_cos + (p.y - center_y) * (-angle_sin); + float rot_y = (p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos; + + return (fabs(rot_x) < box[3] / 2 + MARGIN && + fabs(rot_y) < box[4] / 2 + MARGIN); } __device__ inline int intersection(const Point &p1, const Point &p0, @@ -116,16 +112,19 @@ __device__ inline int point_cmp(const Point &a, const Point &b, } __device__ inline float box_overlap(const float *box_a, const float *box_b) { - // params: box_a (5) [x1, y1, x2, y2, angle] - // params: box_b (5) [x1, y1, x2, y2, angle] + // params box_a: [x, y, z, dx, dy, dz, heading] + // params box_b: [x, y, z, dx, dy, dz, heading] - float a_x1 = box_a[0], a_y1 = box_a[1], a_x2 = box_a[2], a_y2 = box_a[3], - a_angle = box_a[4]; - float b_x1 = box_b[0], b_y1 = box_b[1], b_x2 = box_b[2], b_y2 = box_b[3], - b_angle = box_b[4]; + float a_angle = box_a[6], b_angle = box_b[6]; + float a_dx_half = box_a[3] / 2, b_dx_half = box_b[3] / 2, + a_dy_half = box_a[4] / 2, b_dy_half = box_b[4] / 2; + float a_x1 = box_a[0] - a_dx_half, a_y1 = box_a[1] - a_dy_half; + float a_x2 = box_a[0] + a_dx_half, a_y2 = box_a[1] + a_dy_half; + float b_x1 = box_b[0] - b_dx_half, b_y1 = box_b[1] - b_dy_half; + float b_x2 = box_b[0] + b_dx_half, b_y2 = box_b[1] + b_dy_half; - Point center_a((a_x1 + a_x2) / 2, (a_y1 + a_y2) / 2); - Point center_b((b_x1 + b_x2) / 2, (b_y1 + b_y2) / 2); + Point center_a(box_a[0], box_a[1]); + Point center_b(box_b[0], box_b[1]); Point box_a_corners[5]; box_a_corners[0].set(a_x1, a_y1); @@ -209,50 +208,36 @@ __device__ inline float box_overlap(const float *box_a, const float *box_b) { } __device__ inline float iou_bev(const float *box_a, const float *box_b) { - // params: box_a (5) [x1, y1, x2, y2, angle] - // params: box_b (5) [x1, y1, x2, y2, angle] - float sa = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1]); - float sb = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]); + // params box_a: [x, y, z, dx, dy, dz, heading] + // params box_b: [x, y, z, dx, dy, dz, heading] + float sa = box_a[3] * box_a[4]; + float sb = box_b[3] * box_b[4]; float s_overlap = box_overlap(box_a, box_b); return s_overlap / fmaxf(sa + sb - s_overlap, EPS); } -__global__ void iou3d_boxes_overlap_bev_forward_cuda_kernel( - const int num_a, const float *boxes_a, const int num_b, - const float *boxes_b, float *ans_overlap) { - CUDA_2D_KERNEL_LOOP(b_idx, num_b, a_idx, num_a) { - if (a_idx >= num_a || b_idx >= num_b) { - return; - } - const float *cur_box_a = boxes_a + a_idx * 5; - const float *cur_box_b = boxes_b + b_idx * 5; - float s_overlap = box_overlap(cur_box_a, cur_box_b); - ans_overlap[a_idx * num_b + b_idx] = s_overlap; - } -} - -__global__ void iou3d_boxes_iou_bev_forward_cuda_kernel(const int num_a, - const float *boxes_a, - const int num_b, - const float *boxes_b, - float *ans_iou) { +__global__ void iou3d_boxes_iou3d_forward_cuda_kernel(const int num_a, + const float *boxes_a, + const int num_b, + const float *boxes_b, + float *ans_iou) { CUDA_2D_KERNEL_LOOP(b_idx, num_b, a_idx, num_a) { if (a_idx >= num_a || b_idx >= num_b) { return; } - const float *cur_box_a = boxes_a + a_idx * 5; - const float *cur_box_b = boxes_b + b_idx * 5; + const float *cur_box_a = boxes_a + a_idx * 7; + const float *cur_box_b = boxes_b + b_idx * 7; float cur_iou_bev = iou_bev(cur_box_a, cur_box_b); ans_iou[a_idx * num_b + b_idx] = cur_iou_bev; } } -__global__ void nms_forward_cuda_kernel(const int boxes_num, - const float nms_overlap_thresh, - const float *boxes, - unsigned long long *mask) { - // params: boxes (N, 5) [x1, y1, x2, y2, ry] +__global__ void iou3d_nms3d_forward_cuda_kernel(const int boxes_num, + const float nms_overlap_thresh, + const float *boxes, + unsigned long long *mask) { + // params: boxes (N, 7) [x, y, z, dx, dy, dz, heading] // params: mask (N, N/THREADS_PER_BLOCK_NMS) const int blocks = (boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS; @@ -264,25 +249,29 @@ __global__ void nms_forward_cuda_kernel(const int boxes_num, const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS, THREADS_PER_BLOCK_NMS); - __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5]; + __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 7]; if (threadIdx.x < col_size) { - block_boxes[threadIdx.x * 5 + 0] = - boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 0]; - block_boxes[threadIdx.x * 5 + 1] = - boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 1]; - block_boxes[threadIdx.x * 5 + 2] = - boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 2]; - block_boxes[threadIdx.x * 5 + 3] = - boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 3]; - block_boxes[threadIdx.x * 5 + 4] = - boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 4]; + block_boxes[threadIdx.x * 7 + 0] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 0]; + block_boxes[threadIdx.x * 7 + 1] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 1]; + block_boxes[threadIdx.x * 7 + 2] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 2]; + block_boxes[threadIdx.x * 7 + 3] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 3]; + block_boxes[threadIdx.x * 7 + 4] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 4]; + block_boxes[threadIdx.x * 7 + 5] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 5]; + block_boxes[threadIdx.x * 7 + 6] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 6]; } __syncthreads(); if (threadIdx.x < row_size) { const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x; - const float *cur_box = boxes + cur_box_idx * 5; + const float *cur_box = boxes + cur_box_idx * 7; int i = 0; unsigned long long t = 0; @@ -291,7 +280,7 @@ __global__ void nms_forward_cuda_kernel(const int boxes_num, start = threadIdx.x + 1; } for (i = start; i < col_size; i++) { - if (iou_bev(cur_box, block_boxes + i * 5) > nms_overlap_thresh) { + if (iou_bev(cur_box, block_boxes + i * 7) > nms_overlap_thresh) { t |= 1ULL << i; } } @@ -303,20 +292,24 @@ __global__ void nms_forward_cuda_kernel(const int boxes_num, } __device__ inline float iou_normal(float const *const a, float const *const b) { - float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]); - float top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]); + // params: a: [x, y, z, dx, dy, dz, heading] + // params: b: [x, y, z, dx, dy, dz, heading] + + float left = fmaxf(a[0] - a[3] / 2, b[0] - b[3] / 2), + right = fminf(a[0] + a[3] / 2, b[0] + b[3] / 2); + float top = fmaxf(a[1] - a[4] / 2, b[1] - b[4] / 2), + bottom = fminf(a[1] + a[4] / 2, b[1] + b[4] / 2); float width = fmaxf(right - left, 0.f), height = fmaxf(bottom - top, 0.f); float interS = width * height; - float Sa = (a[2] - a[0]) * (a[3] - a[1]); - float Sb = (b[2] - b[0]) * (b[3] - b[1]); + float Sa = a[3] * a[4]; + float Sb = b[3] * b[4]; return interS / fmaxf(Sa + Sb - interS, EPS); } -__global__ void nms_normal_forward_cuda_kernel(const int boxes_num, - const float nms_overlap_thresh, - const float *boxes, - unsigned long long *mask) { - // params: boxes (N, 5) [x1, y1, x2, y2, ry] +__global__ void iou3d_nms3d_normal_forward_cuda_kernel( + const int boxes_num, const float nms_overlap_thresh, const float *boxes, + unsigned long long *mask) { + // params: boxes (N, 7) [x, y, z, dx, dy, dz, heading] // params: mask (N, N/THREADS_PER_BLOCK_NMS) const int blocks = @@ -329,25 +322,29 @@ __global__ void nms_normal_forward_cuda_kernel(const int boxes_num, const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS, THREADS_PER_BLOCK_NMS); - __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5]; + __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 7]; if (threadIdx.x < col_size) { - block_boxes[threadIdx.x * 5 + 0] = - boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 0]; - block_boxes[threadIdx.x * 5 + 1] = - boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 1]; - block_boxes[threadIdx.x * 5 + 2] = - boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 2]; - block_boxes[threadIdx.x * 5 + 3] = - boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 3]; - block_boxes[threadIdx.x * 5 + 4] = - boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 4]; + block_boxes[threadIdx.x * 7 + 0] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 0]; + block_boxes[threadIdx.x * 7 + 1] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 1]; + block_boxes[threadIdx.x * 7 + 2] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 2]; + block_boxes[threadIdx.x * 7 + 3] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 3]; + block_boxes[threadIdx.x * 7 + 4] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 4]; + block_boxes[threadIdx.x * 7 + 5] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 5]; + block_boxes[threadIdx.x * 7 + 6] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 6]; } __syncthreads(); if (threadIdx.x < row_size) { const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x; - const float *cur_box = boxes + cur_box_idx * 5; + const float *cur_box = boxes + cur_box_idx * 7; int i = 0; unsigned long long t = 0; @@ -356,7 +353,7 @@ __global__ void nms_normal_forward_cuda_kernel(const int boxes_num, start = threadIdx.x + 1; } for (i = start; i < col_size; i++) { - if (iou_normal(cur_box, block_boxes + i * 5) > nms_overlap_thresh) { + if (iou_normal(cur_box, block_boxes + i * 7) > nms_overlap_thresh) { t |= 1ULL << i; } } diff --git a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp index 93b19d4b6b..e290ed891f 100644 --- a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp +++ b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp @@ -564,73 +564,58 @@ REGISTER_DEVICE_IMPL(group_points_forward_impl, CUDA, REGISTER_DEVICE_IMPL(group_points_backward_impl, CUDA, group_points_backward_cuda); -void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a, - const Tensor boxes_a, - const int num_b, - const Tensor boxes_b, - Tensor ans_overlap); - -void IoU3DBoxesIoUBevForwardCUDAKernelLauncher(const int num_a, - const Tensor boxes_a, - const int num_b, - const Tensor boxes_b, - Tensor ans_iou); - -void IoU3DNMSForwardCUDAKernelLauncher(const Tensor boxes, - unsigned long long* mask, int boxes_num, - float nms_overlap_thresh); - -void IoU3DNMSNormalForwardCUDAKernelLauncher(const Tensor boxes, - unsigned long long* mask, - int boxes_num, - float nms_overlap_thresh); - -void iou3d_boxes_overlap_bev_forward_cuda(const int num_a, const Tensor boxes_a, - const int num_b, const Tensor boxes_b, - Tensor ans_overlap) { - IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(num_a, boxes_a, num_b, boxes_b, - ans_overlap); +void IoU3DBoxesIoU3DForwardCUDAKernelLauncher(const int num_a, + const Tensor boxes_a, + const int num_b, + const Tensor boxes_b, + Tensor ans_iou); + +void IoU3DNMS3DForwardCUDAKernelLauncher(const Tensor boxes, + unsigned long long* mask, + int boxes_num, + float nms_overlap_thresh); + +void IoU3DNMS3DNormalForwardCUDAKernelLauncher(const Tensor boxes, + unsigned long long* mask, + int boxes_num, + float nms_overlap_thresh); + +void iou3d_boxes_iou3d_forward_cuda(const int num_a, const Tensor boxes_a, + const int num_b, const Tensor boxes_b, + Tensor ans_iou) { + IoU3DBoxesIoU3DForwardCUDAKernelLauncher(num_a, boxes_a, num_b, boxes_b, + ans_iou); }; -void iou3d_boxes_iou_bev_forward_cuda(const int num_a, const Tensor boxes_a, - const int num_b, const Tensor boxes_b, - Tensor ans_iou) { - IoU3DBoxesIoUBevForwardCUDAKernelLauncher(num_a, boxes_a, num_b, boxes_b, - ans_iou); +void iou3d_nms3d_forward_cuda(const Tensor boxes, unsigned long long* mask, + int boxes_num, float nms_overlap_thresh) { + IoU3DNMS3DForwardCUDAKernelLauncher(boxes, mask, boxes_num, + nms_overlap_thresh); }; -void iou3d_nms_forward_cuda(const Tensor boxes, unsigned long long* mask, - int boxes_num, float nms_overlap_thresh) { - IoU3DNMSForwardCUDAKernelLauncher(boxes, mask, boxes_num, nms_overlap_thresh); +void iou3d_nms3d_normal_forward_cuda(const Tensor boxes, + unsigned long long* mask, int boxes_num, + float nms_overlap_thresh) { + IoU3DNMS3DNormalForwardCUDAKernelLauncher(boxes, mask, boxes_num, + nms_overlap_thresh); }; -void iou3d_nms_normal_forward_cuda(const Tensor boxes, unsigned long long* mask, - int boxes_num, float nms_overlap_thresh) { - IoU3DNMSNormalForwardCUDAKernelLauncher(boxes, mask, boxes_num, - nms_overlap_thresh); -}; - -void iou3d_boxes_overlap_bev_forward_impl(const int num_a, const Tensor boxes_a, - const int num_b, const Tensor boxes_b, - Tensor ans_overlap); - -void iou3d_boxes_iou_bev_forward_impl(const int num_a, const Tensor boxes_a, - const int num_b, const Tensor boxes_b, - Tensor ans_iou); +void iou3d_boxes_iou3d_forward_impl(const int num_a, const Tensor boxes_a, + const int num_b, const Tensor boxes_b, + Tensor ans_iou); -void iou3d_nms_forward_impl(const Tensor boxes, unsigned long long* mask, - int boxes_num, float nms_overlap_thresh); +void iou3d_nms3d_forward_impl(const Tensor boxes, unsigned long long* mask, + int boxes_num, float nms_overlap_thresh); -void iou3d_nms_normal_forward_impl(const Tensor boxes, unsigned long long* mask, - int boxes_num, float nms_overlap_thresh); +void iou3d_nms3d_normal_forward_impl(const Tensor boxes, + unsigned long long* mask, int boxes_num, + float nms_overlap_thresh); -REGISTER_DEVICE_IMPL(iou3d_boxes_overlap_bev_forward_impl, CUDA, - iou3d_boxes_overlap_bev_forward_cuda); -REGISTER_DEVICE_IMPL(iou3d_boxes_iou_bev_forward_impl, CUDA, - iou3d_boxes_iou_bev_forward_cuda); -REGISTER_DEVICE_IMPL(iou3d_nms_forward_impl, CUDA, iou3d_nms_forward_cuda); -REGISTER_DEVICE_IMPL(iou3d_nms_normal_forward_impl, CUDA, - iou3d_nms_normal_forward_cuda); +REGISTER_DEVICE_IMPL(iou3d_boxes_iou3d_forward_impl, CUDA, + iou3d_boxes_iou3d_forward_cuda); +REGISTER_DEVICE_IMPL(iou3d_nms3d_forward_impl, CUDA, iou3d_nms3d_forward_cuda); +REGISTER_DEVICE_IMPL(iou3d_nms3d_normal_forward_impl, CUDA, + iou3d_nms3d_normal_forward_cuda); void KNNForwardCUDAKernelLauncher(int b, int n, int m, int nsample, const Tensor xyz, const Tensor new_xyz, diff --git a/mmcv/ops/csrc/pytorch/cuda/iou3d_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/iou3d_cuda.cu index bf930c79ac..f79c7f9582 100644 --- a/mmcv/ops/csrc/pytorch/cuda/iou3d_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/iou3d_cuda.cu @@ -12,11 +12,11 @@ All Rights Reserved 2019-2020. #include "iou3d_cuda_kernel.cuh" #include "pytorch_cuda_helper.hpp" -void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a, - const Tensor boxes_a, - const int num_b, - const Tensor boxes_b, - Tensor ans_overlap) { +void IoU3DBoxesIoU3DForwardCUDAKernelLauncher(const int num_a, + const Tensor boxes_a, + const int num_b, + const Tensor boxes_b, + Tensor ans_iou) { at::cuda::CUDAGuard device_guard(boxes_a.device()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -25,36 +25,17 @@ void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a, GET_BLOCKS(num_a, THREADS_PER_BLOCK_IOU3D)); dim3 threads(THREADS_PER_BLOCK_IOU3D, THREADS_PER_BLOCK_IOU3D); - iou3d_boxes_overlap_bev_forward_cuda_kernel<<>>( - num_a, boxes_a.data_ptr(), num_b, boxes_b.data_ptr(), - ans_overlap.data_ptr()); - - AT_CUDA_CHECK(cudaGetLastError()); -} - -void IoU3DBoxesIoUBevForwardCUDAKernelLauncher(const int num_a, - const Tensor boxes_a, - const int num_b, - const Tensor boxes_b, - Tensor ans_iou) { - at::cuda::CUDAGuard device_guard(boxes_a.device()); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - // blockIdx.x(col), blockIdx.y(row) - dim3 blocks(GET_BLOCKS(num_b, THREADS_PER_BLOCK_IOU3D), - GET_BLOCKS(num_a, THREADS_PER_BLOCK_IOU3D)); - dim3 threads(THREADS_PER_BLOCK_IOU3D, THREADS_PER_BLOCK_IOU3D); - - iou3d_boxes_iou_bev_forward_cuda_kernel<<>>( + iou3d_boxes_iou3d_forward_cuda_kernel<<>>( num_a, boxes_a.data_ptr(), num_b, boxes_b.data_ptr(), ans_iou.data_ptr()); AT_CUDA_CHECK(cudaGetLastError()); } -void IoU3DNMSForwardCUDAKernelLauncher(const Tensor boxes, - unsigned long long *mask, int boxes_num, - float nms_overlap_thresh) { +void IoU3DNMS3DForwardCUDAKernelLauncher(const Tensor boxes, + unsigned long long *mask, + int boxes_num, + float nms_overlap_thresh) { at::cuda::CUDAGuard device_guard(boxes.device()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -62,16 +43,16 @@ void IoU3DNMSForwardCUDAKernelLauncher(const Tensor boxes, GET_BLOCKS(boxes_num, THREADS_PER_BLOCK_NMS)); dim3 threads(THREADS_PER_BLOCK_NMS); - nms_forward_cuda_kernel<<>>( + iou3d_nms3d_forward_cuda_kernel<<>>( boxes_num, nms_overlap_thresh, boxes.data_ptr(), mask); AT_CUDA_CHECK(cudaGetLastError()); } -void IoU3DNMSNormalForwardCUDAKernelLauncher(const Tensor boxes, - unsigned long long *mask, - int boxes_num, - float nms_overlap_thresh) { +void IoU3DNMS3DNormalForwardCUDAKernelLauncher(const Tensor boxes, + unsigned long long *mask, + int boxes_num, + float nms_overlap_thresh) { at::cuda::CUDAGuard device_guard(boxes.device()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -79,7 +60,7 @@ void IoU3DNMSNormalForwardCUDAKernelLauncher(const Tensor boxes, GET_BLOCKS(boxes_num, THREADS_PER_BLOCK_NMS)); dim3 threads(THREADS_PER_BLOCK_NMS); - nms_normal_forward_cuda_kernel<<>>( + iou3d_nms3d_normal_forward_cuda_kernel<<>>( boxes_num, nms_overlap_thresh, boxes.data_ptr(), mask); AT_CUDA_CHECK(cudaGetLastError()); diff --git a/mmcv/ops/csrc/pytorch/iou3d.cpp b/mmcv/ops/csrc/pytorch/iou3d.cpp index 8af4f05daa..813ffff0fd 100644 --- a/mmcv/ops/csrc/pytorch/iou3d.cpp +++ b/mmcv/ops/csrc/pytorch/iou3d.cpp @@ -12,59 +12,39 @@ All Rights Reserved 2019-2020. const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8; -void iou3d_boxes_overlap_bev_forward_impl(const int num_a, const Tensor boxes_a, - const int num_b, const Tensor boxes_b, - Tensor ans_overlap) { - DISPATCH_DEVICE_IMPL(iou3d_boxes_overlap_bev_forward_impl, num_a, boxes_a, - num_b, boxes_b, ans_overlap); -} - -void iou3d_boxes_iou_bev_forward_impl(const int num_a, const Tensor boxes_a, - const int num_b, const Tensor boxes_b, - Tensor ans_iou) { - DISPATCH_DEVICE_IMPL(iou3d_boxes_iou_bev_forward_impl, num_a, boxes_a, num_b, +void iou3d_boxes_iou3d_forward_impl(const int num_a, const Tensor boxes_a, + const int num_b, const Tensor boxes_b, + Tensor ans_iou) { + DISPATCH_DEVICE_IMPL(iou3d_boxes_iou3d_forward_impl, num_a, boxes_a, num_b, boxes_b, ans_iou); } -void iou3d_nms_forward_impl(const Tensor boxes, unsigned long long *mask, - int boxes_num, float nms_overlap_thresh) { - DISPATCH_DEVICE_IMPL(iou3d_nms_forward_impl, boxes, mask, boxes_num, +void iou3d_nms3d_forward_impl(const Tensor boxes, unsigned long long *mask, + int boxes_num, float nms_overlap_thresh) { + DISPATCH_DEVICE_IMPL(iou3d_nms3d_forward_impl, boxes, mask, boxes_num, nms_overlap_thresh); } -void iou3d_nms_normal_forward_impl(const Tensor boxes, unsigned long long *mask, - int boxes_num, float nms_overlap_thresh) { - DISPATCH_DEVICE_IMPL(iou3d_nms_normal_forward_impl, boxes, mask, boxes_num, +void iou3d_nms3d_normal_forward_impl(const Tensor boxes, + unsigned long long *mask, int boxes_num, + float nms_overlap_thresh) { + DISPATCH_DEVICE_IMPL(iou3d_nms3d_normal_forward_impl, boxes, mask, boxes_num, nms_overlap_thresh); } -void iou3d_boxes_overlap_bev_forward(Tensor boxes_a, Tensor boxes_b, - Tensor ans_overlap) { - // params boxes_a: (N, 5) [x1, y1, x2, y2, ry] - // params boxes_b: (M, 5) - // params ans_overlap: (N, M) - - int num_a = boxes_a.size(0); - int num_b = boxes_b.size(0); - - iou3d_boxes_overlap_bev_forward_impl(num_a, boxes_a, num_b, boxes_b, - ans_overlap); -} - -void iou3d_boxes_iou_bev_forward(Tensor boxes_a, Tensor boxes_b, - Tensor ans_iou) { - // params boxes_a: (N, 5) [x1, y1, x2, y2, ry] +void iou3d_boxes_iou3d_forward(Tensor boxes_a, Tensor boxes_b, Tensor ans_iou) { + // params boxes: (N, 7) [x, y, z, dx, dy, dz, heading] // params boxes_b: (M, 5) // params ans_overlap: (N, M) int num_a = boxes_a.size(0); int num_b = boxes_b.size(0); - iou3d_boxes_iou_bev_forward_impl(num_a, boxes_a, num_b, boxes_b, ans_iou); + iou3d_boxes_iou3d_forward_impl(num_a, boxes_a, num_b, boxes_b, ans_iou); } -void iou3d_nms_forward(Tensor boxes, Tensor keep, Tensor keep_num, - float nms_overlap_thresh) { - // params boxes: (N, 5) [x1, y1, x2, y2, ry] +void iou3d_nms3d_forward(Tensor boxes, Tensor keep, Tensor keep_num, + float nms_overlap_thresh) { + // params boxes: (N, 7) [x, y, z, dx, dy, dz, heading] // params keep: (N) CHECK_CONTIGUOUS(boxes); CHECK_CONTIGUOUS(keep); @@ -80,7 +60,7 @@ void iou3d_nms_forward(Tensor boxes, Tensor keep, Tensor keep_num, at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong)); unsigned long long *mask_data = (unsigned long long *)mask.data_ptr(); - iou3d_nms_forward_impl(boxes, mask_data, boxes_num, nms_overlap_thresh); + iou3d_nms3d_forward_impl(boxes, mask_data, boxes_num, nms_overlap_thresh); at::Tensor mask_cpu = mask.to(at::kCPU); unsigned long long *mask_host = @@ -106,9 +86,9 @@ void iou3d_nms_forward(Tensor boxes, Tensor keep, Tensor keep_num, } } -void iou3d_nms_normal_forward(Tensor boxes, Tensor keep, Tensor keep_num, - float nms_overlap_thresh) { - // params boxes: (N, 5) [x1, y1, x2, y2, ry] +void iou3d_nms3d_normal_forward(Tensor boxes, Tensor keep, Tensor keep_num, + float nms_overlap_thresh) { + // params boxes: (N, 7) [x, y, z, dx, dy, dz, heading] // params keep: (N) CHECK_CONTIGUOUS(boxes); @@ -125,8 +105,8 @@ void iou3d_nms_normal_forward(Tensor boxes, Tensor keep, Tensor keep_num, at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong)); unsigned long long *mask_data = (unsigned long long *)mask.data_ptr(); - iou3d_nms_normal_forward_impl(boxes, mask_data, boxes_num, - nms_overlap_thresh); + iou3d_nms3d_normal_forward_impl(boxes, mask_data, boxes_num, + nms_overlap_thresh); at::Tensor mask_cpu = mask.to(at::kCPU); unsigned long long *mask_host = diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index b53ef3fb10..7909aac197 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -115,17 +115,14 @@ void bbox_overlaps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, void knn_forward(Tensor xyz_tensor, Tensor new_xyz_tensor, Tensor idx_tensor, Tensor dist2_tensor, int b, int n, int m, int nsample); -void iou3d_boxes_overlap_bev_forward(Tensor boxes_a, Tensor boxes_b, - Tensor ans_overlap); -void iou3d_boxes_iou_bev_forward(Tensor boxes_a, Tensor boxes_b, - Tensor ans_iou); +void iou3d_boxes_iou3d_forward(Tensor boxes_a, Tensor boxes_b, Tensor ans_iou); -void iou3d_nms_forward(Tensor boxes, Tensor keep, Tensor keep_num, - float nms_overlap_thresh); +void iou3d_nms3d_forward(Tensor boxes, Tensor keep, Tensor keep_num, + float nms_overlap_thresh); -void iou3d_nms_normal_forward(Tensor boxes, Tensor keep, Tensor keep_num, - float nms_overlap_thresh); +void iou3d_nms3d_normal_forward(Tensor boxes, Tensor keep, Tensor keep_num, + float nms_overlap_thresh); void furthest_point_sampling_forward(Tensor points_tensor, Tensor temp_tensor, Tensor idx_tensor, int b, int n, int m); @@ -535,17 +532,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("m"), py::arg("nsample"), py::arg("xyz_tensor"), py::arg("new_xyz_tensor"), py::arg("idx_tensor"), py::arg("dist2_tensor")); - m.def("iou3d_boxes_overlap_bev_forward", &iou3d_boxes_overlap_bev_forward, - "iou3d_boxes_overlap_bev_forward", py::arg("boxes_a"), - py::arg("boxes_b"), py::arg("ans_overlap")); - m.def("iou3d_boxes_iou_bev_forward", &iou3d_boxes_iou_bev_forward, - "iou3d_boxes_iou_bev_forward", py::arg("boxes_a"), py::arg("boxes_b"), + m.def("iou3d_boxes_iou3d_forward", &iou3d_boxes_iou3d_forward, + "iou3d_boxes_iou3d_forward", py::arg("boxes_a"), py::arg("boxes_b"), py::arg("ans_iou")); - m.def("iou3d_nms_forward", &iou3d_nms_forward, "iou3d_nms_forward", + m.def("iou3d_nms3d_forward", &iou3d_nms3d_forward, "iou3d_nms3d_forward", py::arg("boxes"), py::arg("keep"), py::arg("num_out"), py::arg("nms_overlap_thresh")); - m.def("iou3d_nms_normal_forward", &iou3d_nms_normal_forward, - "iou3d_nms_normal_forward", py::arg("boxes"), py::arg("keep"), + m.def("iou3d_nms3d_normal_forward", &iou3d_nms3d_normal_forward, + "iou3d_nms3d_normal_forward", py::arg("boxes"), py::arg("keep"), py::arg("num_out"), py::arg("nms_overlap_thresh")); m.def("furthest_point_sampling_forward", &furthest_point_sampling_forward, "furthest_point_sampling_forward", py::arg("points_tensor"), diff --git a/mmcv/ops/iou3d.py b/mmcv/ops/iou3d.py index 8a5550658f..81d541bf85 100644 --- a/mmcv/ops/iou3d.py +++ b/mmcv/ops/iou3d.py @@ -4,17 +4,17 @@ from ..utils import ext_loader ext_module = ext_loader.load_ext('_ext', [ - 'iou3d_boxes_iou_bev_forward', 'iou3d_nms_forward', - 'iou3d_nms_normal_forward' + 'iou3d_boxes_iou3d_forward', 'iou3d_nms3d_forward', + 'iou3d_nms3d_normal_forward' ]) -def boxes_iou_bev(boxes_a, boxes_b): - """Calculate boxes IoU in the Bird's Eye View. +def boxes_iou3d(boxes_a, boxes_b): + """Calculate boxes 3D IoU. Args: - boxes_a (torch.Tensor): Input boxes a with shape (M, 5). - boxes_b (torch.Tensor): Input boxes b with shape (N, 5). + boxes_a (torch.Tensor): Input boxes a with shape (M, 7). + boxes_b (torch.Tensor): Input boxes b with shape (N, 7). Returns: torch.Tensor: IoU result with shape (M, N). @@ -22,68 +22,56 @@ def boxes_iou_bev(boxes_a, boxes_b): ans_iou = boxes_a.new_zeros( torch.Size((boxes_a.shape[0], boxes_b.shape[0]))) - ext_module.iou3d_boxes_iou_bev_forward(boxes_a.contiguous(), - boxes_b.contiguous(), ans_iou) + ext_module.iou3d_boxes_iou3d_forward(boxes_a.contiguous(), + boxes_b.contiguous(), ans_iou) return ans_iou -def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None): - """NMS function GPU implementation (for BEV boxes). The overlap of two - boxes for IoU calculation is defined as the exact overlapping area of the - two boxes. In this function, one can also set ``pre_max_size`` and - ``post_max_size``. +def nms3d(boxes, scores, iou_threshold): + """3D NMS function GPU implementation (for BEV boxes). Args: - boxes (torch.Tensor): Input boxes with the shape of [N, 5] - ([x1, y1, x2, y2, ry]). - scores (torch.Tensor): Scores of boxes with the shape of [N]. - thresh (float): Overlap threshold of NMS. - pre_max_size (int, optional): Max size of boxes before NMS. - Default: None. - post_max_size (int, optional): Max size of boxes after NMS. - Default: None. + boxes (torch.Tensor): Input boxes with the shape of (N, 7) + ([x, y, z, dx, dy, dz, heading]). + scores (torch.Tensor): Scores of boxes with the shape of (N). + iou_threshold (float): Overlap threshold of NMS. Returns: torch.Tensor: Indexes after NMS. """ - assert boxes.size(1) == 5, 'Input boxes shape should be [N, 5]' + assert boxes.size(1) == 7, 'Input boxes shape should be (N, 7)' order = scores.sort(0, descending=True)[1] - - if pre_max_size is not None: - order = order[:pre_max_size] boxes = boxes[order].contiguous() keep = torch.zeros(boxes.size(0), dtype=torch.long) num_out = torch.zeros(size=(), dtype=torch.long) - ext_module.iou3d_nms_forward( - boxes, keep, num_out, nms_overlap_thresh=thresh) + ext_module.iou3d_nms3d_forward( + boxes, keep, num_out, nms_overlap_thresh=iou_threshold) keep = order[keep[:num_out].cuda(boxes.device)].contiguous() - if post_max_size is not None: - keep = keep[:post_max_size] return keep -def nms_normal_bev(boxes, scores, thresh): - """Normal NMS function GPU implementation (for BEV boxes). The overlap of - two boxes for IoU calculation is defined as the exact overlapping area of - the two boxes WITH their yaw angle set to 0. +def nms3d_normal(boxes, scores, iou_threshold): + """Normal 3D NMS function GPU implementation. The overlap of two boxes for + IoU calculation is defined as the exact overlapping area of the two boxes + WITH their yaw angle set to 0. Args: - boxes (torch.Tensor): Input boxes with shape (N, 5). + boxes (torch.Tensor): Input boxes with shape (N, 7). + ([x, y, z, dx, dy, dz, heading]). scores (torch.Tensor): Scores of predicted boxes with shape (N). - thresh (float): Overlap threshold of NMS. + iou_threshold (float): Overlap threshold of NMS. Returns: torch.Tensor: Remaining indices with scores in descending order. """ - assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]' + assert boxes.shape[1] == 7, 'Input boxes shape should be (N, 7)' order = scores.sort(0, descending=True)[1] - boxes = boxes[order].contiguous() keep = torch.zeros(boxes.size(0), dtype=torch.long) num_out = torch.zeros(size=(), dtype=torch.long) - ext_module.iou3d_nms_normal_forward( - boxes, keep, num_out, nms_overlap_thresh=thresh) + ext_module.iou3d_nms3d_normal_forward( + boxes, keep, num_out, nms_overlap_thresh=iou_threshold) return order[keep[:num_out].cuda(boxes.device)].contiguous() diff --git a/tests/test_ops/test_iou3d.py b/tests/test_ops/test_iou3d.py index 66cc44b947..a5ad3328d7 100644 --- a/tests/test_ops/test_iou3d.py +++ b/tests/test_ops/test_iou3d.py @@ -3,59 +3,84 @@ import pytest import torch -from mmcv.ops import boxes_iou_bev, nms_bev, nms_normal_bev +from mmcv.ops import boxes_iou3d, nms3d, nms3d_normal @pytest.mark.skipif( not torch.cuda.is_available(), reason='requires CUDA support') -def test_boxes_iou_bev(): - np_boxes1 = np.asarray( - [[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6], - [7.0, 7.0, 8.0, 8.0, 0.4]], - dtype=np.float32) - np_boxes2 = np.asarray( - [[0.0, 2.0, 2.0, 5.0, 0.3], [2.0, 1.0, 3.0, 3.0, 0.5], - [5.0, 5.0, 6.0, 7.0, 0.4]], - dtype=np.float32) - np_expect_ious = np.asarray( - [[0.2621, 0.2948, 0.0000], [0.0549, 0.1587, 0.0000], - [0.0000, 0.0000, 0.0000]], - dtype=np.float32) +def test_boxes_iou3d(): + np_boxes1 = np.asarray([[1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 0.0], + [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 0.0], + [3.0, 3.0, 3.0, 3.0, 2.0, 2.0, 0.0]], + dtype=np.float32) + np_boxes2 = np.asarray([[1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 0.0], + [1.0, 1.0, 1.0, 2.0, 2.0, 2.0, np.pi / 2], + [1.0, 1.0, 1.0, 2.0, 2.0, 2.0, np.pi / 4]], + dtype=np.float32) + np_expect_ious = np.asarray([[1.0, 1.0, 1.0 / 2**0.5], + [1.0 / 7, 1.0 / 7, 1.0 / 7], [0.0, 0.0, 0.0]], + dtype=np.float32) boxes1 = torch.from_numpy(np_boxes1).cuda() boxes2 = torch.from_numpy(np_boxes2).cuda() - ious = boxes_iou_bev(boxes1, boxes2) + ious = boxes_iou3d(boxes1, boxes2) assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4) @pytest.mark.skipif( not torch.cuda.is_available(), reason='requires CUDA support') -def test_nms_bev(): - np_boxes = np.array( - [[6.0, 3.0, 8.0, 7.0, 2.0], [3.0, 6.0, 9.0, 11.0, 1.0], - [3.0, 7.0, 10.0, 12.0, 1.0], [1.0, 4.0, 13.0, 7.0, 3.0]], - dtype=np.float32) - np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32) +def test_nms3d(): + # test for 5 boxes + np_boxes = np.asarray([[1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 0.0], + [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 0.0], + [3.0, 3.0, 3.0, 3.0, 2.0, 2.0, 0.3], + [3.0, 3.0, 3.0, 3.0, 2.0, 2.0, 0.0], + [3.0, 3.2, 3.2, 3.0, 2.0, 2.0, 0.3]], + dtype=np.float32) + np_scores = np.array([0.6, 0.9, 0.1, 0.2, 0.15], dtype=np.float32) np_inds = np.array([1, 0, 3]) boxes = torch.from_numpy(np_boxes) scores = torch.from_numpy(np_scores) - inds = nms_bev(boxes.cuda(), scores.cuda(), thresh=0.3) + inds = nms3d(boxes.cuda(), scores.cuda(), iou_threshold=0.3) assert np.allclose(inds.cpu().numpy(), np_inds) + # test for many boxes + np.random.seed(42) + np_boxes = np.random.rand(555, 7).astype(np.float32) + np_scores = np.random.rand(555).astype(np.float32) + boxes = torch.from_numpy(np_boxes) + scores = torch.from_numpy(np_scores) + inds = nms3d(boxes.cuda(), scores.cuda(), iou_threshold=0.3) + + assert len(inds.cpu().numpy()) == 176 + @pytest.mark.skipif( not torch.cuda.is_available(), reason='requires CUDA support') -def test_nms_normal_bev(): - np_boxes = np.array( - [[6.0, 3.0, 8.0, 7.0, 2.0], [3.0, 6.0, 9.0, 11.0, 1.0], - [3.0, 7.0, 10.0, 12.0, 1.0], [1.0, 4.0, 13.0, 7.0, 3.0]], - dtype=np.float32) - np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32) +def test_nms3d_normal(): + # test for 5 boxes + np_boxes = np.asarray([[1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 0.0], + [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 0.0], + [3.0, 3.0, 3.0, 3.0, 2.0, 2.0, 0.3], + [3.0, 3.0, 3.0, 3.0, 2.0, 2.0, 0.0], + [3.0, 3.2, 3.2, 3.0, 2.0, 2.0, 0.3]], + dtype=np.float32) + np_scores = np.array([0.6, 0.9, 0.1, 0.2, 0.15], dtype=np.float32) np_inds = np.array([1, 0, 3]) boxes = torch.from_numpy(np_boxes) scores = torch.from_numpy(np_scores) - inds = nms_normal_bev(boxes.cuda(), scores.cuda(), thresh=0.3) + inds = nms3d_normal(boxes.cuda(), scores.cuda(), iou_threshold=0.3) assert np.allclose(inds.cpu().numpy(), np_inds) + + # test for many boxes + np.random.seed(42) + np_boxes = np.random.rand(555, 7).astype(np.float32) + np_scores = np.random.rand(555).astype(np.float32) + boxes = torch.from_numpy(np_boxes) + scores = torch.from_numpy(np_scores) + inds = nms3d_normal(boxes.cuda(), scores.cuda(), iou_threshold=0.3) + + assert len(inds.cpu().numpy()) == 148 From 02997ae20659332b9e007dc4326249af65d4a303 Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Sat, 23 Apr 2022 14:32:04 +0300 Subject: [PATCH 2/4] revert deprecated python function --- mmcv/ops/__init__.py | 3 +- mmcv/ops/iou3d.py | 114 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+), 1 deletion(-) diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py index 2bbf9096f0..90d52dd66a 100644 --- a/mmcv/ops/__init__.py +++ b/mmcv/ops/__init__.py @@ -28,7 +28,8 @@ from .group_points import GroupAll, QueryAndGroup, grouping_operation from .info import (get_compiler_version, get_compiling_cuda_version, get_onnxruntime_op_path) -from .iou3d import boxes_iou3d, nms3d, nms3d_normal +from .iou3d import (boxes_iou3d, boxes_iou_bev, nms3d, nms3d_normal, nms_bev, + nms_normal_bev) from .knn import knn from .masked_conv import MaskedConv2d, masked_conv2d from .min_area_polygons import min_area_polygons diff --git a/mmcv/ops/iou3d.py b/mmcv/ops/iou3d.py index 81d541bf85..563bbbd4ce 100644 --- a/mmcv/ops/iou3d.py +++ b/mmcv/ops/iou3d.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +import warnings + import torch from ..utils import ext_loader @@ -75,3 +77,115 @@ def nms3d_normal(boxes, scores, iou_threshold): ext_module.iou3d_nms3d_normal_forward( boxes, keep, num_out, nms_overlap_thresh=iou_threshold) return order[keep[:num_out].cuda(boxes.device)].contiguous() + + +def _xyxyr2xyzwlhr(boxes): + """Convert [x1, y1, x2, y2, heading] box to [x, y, 0, dx, dy, 1, heading] + box. + + Args: + box (torch.Tensor): Input boxes with shape (N, 5). + + Returns: + torch.Tensor: Converted boxes with shape (N, 7). + """ + warnings.warn( + 'This function is deprecated and will be removed in the future.', + DeprecationWarning) + return torch.stack( + ((boxes[:, 0] + boxes[:, 2]) / 2, (boxes[:, 1] + boxes[:, 3]) / 2, + torch.zeros_like(boxes[:, 0]), boxes[:, 2] - boxes[:, 0], + boxes[:, 3] - boxes[:, 1], torch.ones_like(boxes[:, 0]), boxes[:, 4]), + dim=-1) + + +def boxes_iou_bev(boxes_a, boxes_b): + """Calculate boxes IoU in the Bird's Eye View. + + Args: + boxes_a (torch.Tensor): Input boxes a with shape (M, 5). + boxes_b (torch.Tensor): Input boxes b with shape (N, 5). + + Returns: + torch.Tensor: IoU result with shape (M, N). + """ + warnings.warn( + '`iou3d.boxes_iou_bev` is deprecated and will be removed in' + ' the future. Please, use `box_iou_rotated.box_iou_rotated`.', + DeprecationWarning) + ans_iou = boxes_a.new_zeros( + torch.Size((boxes_a.shape[0], boxes_b.shape[0]))) + boxes_a = _xyxyr2xyzwlhr(boxes_a) + boxes_b = _xyxyr2xyzwlhr(boxes_b) + + ext_module.iou3d_boxes_iou3d_forward(boxes_a.contiguous(), + boxes_b.contiguous(), ans_iou) + + return ans_iou + + +def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None): + """NMS function GPU implementation (for BEV boxes). + + The overlap of two + boxes for IoU calculation is defined as the exact overlapping area of the + two boxes. In this function, one can also set ``pre_max_size`` and + ``post_max_size``. + Args: + boxes (torch.Tensor): Input boxes with the shape of [N, 5] + ([x1, y1, x2, y2, ry]). + scores (torch.Tensor): Scores of boxes with the shape of [N]. + thresh (float): Overlap threshold of NMS. + pre_max_size (int, optional): Max size of boxes before NMS. + Default: None. + post_max_size (int, optional): Max size of boxes after NMS. + Default: None. + Returns: + torch.Tensor: Indexes after NMS. + """ + warnings.warn( + '`iou3d.nms_bev` is deprecated and will be removed in' + ' the future. Please, use `nms.nms_rotated`.', DeprecationWarning) + assert boxes.size(1) == 5, 'Input boxes shape should be [N, 5]' + order = scores.sort(0, descending=True)[1] + + if pre_max_size is not None: + order = order[:pre_max_size] + boxes = _xyxyr2xyzwlhr(boxes)[order].contiguous() + + keep = torch.zeros(boxes.size(0), dtype=torch.long) + num_out = torch.zeros(size=(), dtype=torch.long) + ext_module.iou3d_nms3d_forward( + boxes, keep, num_out, nms_overlap_thresh=thresh) + keep = order[keep[:num_out].cuda(boxes.device)].contiguous() + if post_max_size is not None: + keep = keep[:post_max_size] + return keep + + +def nms_normal_bev(boxes, scores, thresh): + """Normal NMS function GPU implementation (for BEV boxes). + + The overlap of + two boxes for IoU calculation is defined as the exact overlapping area of + the two boxes WITH their yaw angle set to 0. + Args: + boxes (torch.Tensor): Input boxes with shape (N, 5). + scores (torch.Tensor): Scores of predicted boxes with shape (N). + thresh (float): Overlap threshold of NMS. + Returns: + torch.Tensor: Remaining indices with scores in descending order. + """ + warnings.warn( + '`iou3d.nms_normal_bev` is deprecated and will be removed in' + ' the future. Please, use `nms.nms`.', DeprecationWarning) + assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]' + order = scores.sort(0, descending=True)[1] + + boxes = _xyxyr2xyzwlhr(boxes)[order].contiguous() + + keep = torch.zeros(boxes.size(0), dtype=torch.long) + num_out = torch.zeros(size=(), dtype=torch.long) + ext_module.iou3d_nms3d_normal_forward( + boxes, keep, num_out, nms_overlap_thresh=thresh) + return order[keep[:num_out].cuda(boxes.device)].contiguous() From dfc23faec80071ba9326cbccbad477780d74941b Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Sat, 23 Apr 2022 14:53:00 +0300 Subject: [PATCH 3/4] fix lint --- mmcv/ops/__init__.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) mode change 100644 => 100755 mmcv/ops/__init__.py diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py old mode 100644 new mode 100755 index 90d52dd66a..487c6df9bf --- a/mmcv/ops/__init__.py +++ b/mmcv/ops/__init__.py @@ -90,11 +90,12 @@ 'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign', 'border_align', 'gather_points', 'furthest_point_sample', 'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation', - 'boxes_iou3d', 'nms3d', 'nms3d_normal', 'Voxelization', 'voxelization', - 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d', 'SparseConv2d', - 'SparseConv3d', 'SparseConvTranspose2d', 'SparseConvTranspose3d', - 'SparseInverseConv2d', 'SparseInverseConv3d', 'SubMConv2d', 'SubMConv3d', - 'SparseModule', 'SparseSequential', 'SparseMaxPool2d', 'SparseMaxPool3d', + 'boxes_iou3d', 'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'nms3d', + 'nms3d_normal', 'Voxelization', 'voxelization', 'dynamic_scatter', + 'DynamicScatter', 'RoIAwarePool3d', 'SparseConv2d', 'SparseConv3d', + 'SparseConvTranspose2d', 'SparseConvTranspose3d', 'SparseInverseConv2d', + 'SparseInverseConv3d', 'SubMConv2d', 'SubMConv3d', 'SparseModule', + 'SparseSequential', 'SparseMaxPool2d', 'SparseMaxPool3d', 'SparseConvTensor', 'scatter_nd', 'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all', 'points_in_polygons', 'min_area_polygons', 'active_rotated_filter', 'convex_iou', 'convex_giou', From 985d5bffabac34317f6e1323abe753e35b9a8fda Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Thu, 12 May 2022 11:57:43 +0300 Subject: [PATCH 4/4] replace 3d iou/nms calls for bev iou/nms --- mmcv/ops/iou3d.py | 62 +++++++++++++++++++++-------------------------- 1 file changed, 27 insertions(+), 35 deletions(-) mode change 100644 => 100755 mmcv/ops/iou3d.py diff --git a/mmcv/ops/iou3d.py b/mmcv/ops/iou3d.py old mode 100644 new mode 100755 index 563bbbd4ce..4a120df4f7 --- a/mmcv/ops/iou3d.py +++ b/mmcv/ops/iou3d.py @@ -79,9 +79,8 @@ def nms3d_normal(boxes, scores, iou_threshold): return order[keep[:num_out].cuda(boxes.device)].contiguous() -def _xyxyr2xyzwlhr(boxes): - """Convert [x1, y1, x2, y2, heading] box to [x, y, 0, dx, dy, 1, heading] - box. +def _xyxyr2xywhr(boxes): + """Convert [x1, y1, x2, y2, heading] box to [x, y, dx, dy, heading] box. Args: box (torch.Tensor): Input boxes with shape (N, 5). @@ -94,8 +93,7 @@ def _xyxyr2xyzwlhr(boxes): DeprecationWarning) return torch.stack( ((boxes[:, 0] + boxes[:, 2]) / 2, (boxes[:, 1] + boxes[:, 3]) / 2, - torch.zeros_like(boxes[:, 0]), boxes[:, 2] - boxes[:, 0], - boxes[:, 3] - boxes[:, 1], torch.ones_like(boxes[:, 0]), boxes[:, 4]), + boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1], boxes[:, 4]), dim=-1) @@ -103,25 +101,22 @@ def boxes_iou_bev(boxes_a, boxes_b): """Calculate boxes IoU in the Bird's Eye View. Args: - boxes_a (torch.Tensor): Input boxes a with shape (M, 5). - boxes_b (torch.Tensor): Input boxes b with shape (N, 5). + boxes_a (torch.Tensor): Input boxes a with shape (M, 5) + ([x1, y1, x2, y2, ry]). + boxes_b (torch.Tensor): Input boxes b with shape (N, 5) + ([x1, y1, x2, y2, ry]). Returns: torch.Tensor: IoU result with shape (M, N). """ + from .box_iou_rotated import box_iou_rotated + warnings.warn( '`iou3d.boxes_iou_bev` is deprecated and will be removed in' ' the future. Please, use `box_iou_rotated.box_iou_rotated`.', DeprecationWarning) - ans_iou = boxes_a.new_zeros( - torch.Size((boxes_a.shape[0], boxes_b.shape[0]))) - boxes_a = _xyxyr2xyzwlhr(boxes_a) - boxes_b = _xyxyr2xyzwlhr(boxes_b) - - ext_module.iou3d_boxes_iou3d_forward(boxes_a.contiguous(), - boxes_b.contiguous(), ans_iou) - return ans_iou + return box_iou_rotated(_xyxyr2xywhr(boxes_a), _xyxyr2xywhr(boxes_b)) def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None): @@ -132,9 +127,9 @@ def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None): two boxes. In this function, one can also set ``pre_max_size`` and ``post_max_size``. Args: - boxes (torch.Tensor): Input boxes with the shape of [N, 5] + boxes (torch.Tensor): Input boxes with the shape of (N, 5) ([x1, y1, x2, y2, ry]). - scores (torch.Tensor): Scores of boxes with the shape of [N]. + scores (torch.Tensor): Scores of boxes with the shape of (N,). thresh (float): Overlap threshold of NMS. pre_max_size (int, optional): Max size of boxes before NMS. Default: None. @@ -143,21 +138,22 @@ def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None): Returns: torch.Tensor: Indexes after NMS. """ + from .nms import nms_rotated + warnings.warn( '`iou3d.nms_bev` is deprecated and will be removed in' ' the future. Please, use `nms.nms_rotated`.', DeprecationWarning) - assert boxes.size(1) == 5, 'Input boxes shape should be [N, 5]' + assert boxes.size(1) == 5, 'Input boxes shape should be (N, 5)' order = scores.sort(0, descending=True)[1] if pre_max_size is not None: order = order[:pre_max_size] - boxes = _xyxyr2xyzwlhr(boxes)[order].contiguous() + boxes = _xyxyr2xywhr(boxes)[order] + scores = scores[order] + + keep = nms_rotated(boxes, scores, thresh)[1] + keep = order[keep] - keep = torch.zeros(boxes.size(0), dtype=torch.long) - num_out = torch.zeros(size=(), dtype=torch.long) - ext_module.iou3d_nms3d_forward( - boxes, keep, num_out, nms_overlap_thresh=thresh) - keep = order[keep[:num_out].cuda(boxes.device)].contiguous() if post_max_size is not None: keep = keep[:post_max_size] return keep @@ -170,22 +166,18 @@ def nms_normal_bev(boxes, scores, thresh): two boxes for IoU calculation is defined as the exact overlapping area of the two boxes WITH their yaw angle set to 0. Args: - boxes (torch.Tensor): Input boxes with shape (N, 5). - scores (torch.Tensor): Scores of predicted boxes with shape (N). + boxes (torch.Tensor): Input boxes with shape (N, 5) + ([x1, y1, x2, y2, ry]). + scores (torch.Tensor): Scores of predicted boxes with shape (N,). thresh (float): Overlap threshold of NMS. Returns: torch.Tensor: Remaining indices with scores in descending order. """ + from .nms import nms + warnings.warn( '`iou3d.nms_normal_bev` is deprecated and will be removed in' ' the future. Please, use `nms.nms`.', DeprecationWarning) - assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]' - order = scores.sort(0, descending=True)[1] - - boxes = _xyxyr2xyzwlhr(boxes)[order].contiguous() + assert boxes.shape[1] == 5, 'Input boxes shape should be (N, 5)' - keep = torch.zeros(boxes.size(0), dtype=torch.long) - num_out = torch.zeros(size=(), dtype=torch.long) - ext_module.iou3d_nms3d_normal_forward( - boxes, keep, num_out, nms_overlap_thresh=thresh) - return order[keep[:num_out].cuda(boxes.device)].contiguous() + return nms(boxes[:, :-1], scores, thresh)[1]