Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Replace BEV IoU with 3D IoU #1902

Merged
merged 4 commits into from
May 25, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- MaskedConv
- MinAreaPolygon
- NMS
- NMS3D
- PointsInPolygons
- PSAMask
- RiRoIAlignRotated
Expand Down
1 change: 1 addition & 0 deletions docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- MaskedConv
- MinAreaPolygon
- NMS
- NMS3D
- PointsInPolygons
- PSAMask
- RotatedFeatureAlign
Expand Down
20 changes: 10 additions & 10 deletions mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from .group_points import GroupAll, QueryAndGroup, grouping_operation
from .info import (get_compiler_version, get_compiling_cuda_version,
get_onnxruntime_op_path)
from .iou3d import boxes_iou_bev, nms_bev, nms_normal_bev
from .iou3d import boxes_iou3d, nms3d, nms3d_normal
from .knn import knn
from .masked_conv import MaskedConv2d, masked_conv2d
from .min_area_polygons import min_area_polygons
Expand Down Expand Up @@ -89,13 +89,13 @@
'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign',
'border_align', 'gather_points', 'furthest_point_sample',
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
'SparseConv2d', 'SparseConv3d', 'SparseConvTranspose2d',
'SparseConvTranspose3d', 'SparseInverseConv2d', 'SparseInverseConv3d',
'SubMConv2d', 'SubMConv3d', 'SparseModule', 'SparseSequential',
'SparseMaxPool2d', 'SparseMaxPool3d', 'SparseConvTensor', 'scatter_nd',
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all',
'points_in_polygons', 'min_area_polygons', 'active_rotated_filter',
'convex_iou', 'convex_giou', 'diff_iou_rotated_2d', 'diff_iou_rotated_3d'
'boxes_iou3d', 'nms3d', 'nms3d_normal', 'Voxelization', 'voxelization',
'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d', 'SparseConv2d',
'SparseConv3d', 'SparseConvTranspose2d', 'SparseConvTranspose3d',
'SparseInverseConv2d', 'SparseInverseConv3d', 'SubMConv2d', 'SubMConv3d',
'SparseModule', 'SparseSequential', 'SparseMaxPool2d', 'SparseMaxPool3d',
'SparseConvTensor', 'scatter_nd', 'points_in_boxes_part',
'points_in_boxes_cpu', 'points_in_boxes_all', 'points_in_polygons',
'min_area_polygons', 'active_rotated_filter', 'convex_iou', 'convex_giou',
'diff_iou_rotated_2d', 'diff_iou_rotated_3d'
]
173 changes: 85 additions & 88 deletions mmcv/ops/csrc/common/cuda/iou3d_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,17 @@ __device__ int check_rect_cross(const Point &p1, const Point &p2,
}

__device__ inline int check_in_box2d(const float *box, const Point &p) {
// params: box (5) [x1, y1, x2, y2, angle]
const float MARGIN = 1e-5;

float center_x = (box[0] + box[2]) / 2;
float center_y = (box[1] + box[3]) / 2;
float angle_cos = cos(-box[4]),
angle_sin =
sin(-box[4]); // rotate the point in the opposite direction of box
float rot_x =
(p.x - center_x) * angle_cos - (p.y - center_y) * angle_sin + center_x;
float rot_y =
(p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos + center_y;

return (rot_x > box[0] - MARGIN && rot_x < box[2] + MARGIN &&
rot_y > box[1] - MARGIN && rot_y < box[3] + MARGIN);
// params: box (7) [x, y, z, dx, dy, dz, heading]
const float MARGIN = 1e-2;

float center_x = box[0], center_y = box[1];
// rotate the point in the opposite direction of box
float angle_cos = cos(-box[6]), angle_sin = sin(-box[6]);
float rot_x = (p.x - center_x) * angle_cos + (p.y - center_y) * (-angle_sin);
float rot_y = (p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos;

return (fabs(rot_x) < box[3] / 2 + MARGIN &&
fabs(rot_y) < box[4] / 2 + MARGIN);
}

__device__ inline int intersection(const Point &p1, const Point &p0,
Expand Down Expand Up @@ -116,16 +112,19 @@ __device__ inline int point_cmp(const Point &a, const Point &b,
}

__device__ inline float box_overlap(const float *box_a, const float *box_b) {
// params: box_a (5) [x1, y1, x2, y2, angle]
// params: box_b (5) [x1, y1, x2, y2, angle]
// params box_a: [x, y, z, dx, dy, dz, heading]
// params box_b: [x, y, z, dx, dy, dz, heading]

float a_x1 = box_a[0], a_y1 = box_a[1], a_x2 = box_a[2], a_y2 = box_a[3],
a_angle = box_a[4];
float b_x1 = box_b[0], b_y1 = box_b[1], b_x2 = box_b[2], b_y2 = box_b[3],
b_angle = box_b[4];
float a_angle = box_a[6], b_angle = box_b[6];
float a_dx_half = box_a[3] / 2, b_dx_half = box_b[3] / 2,
a_dy_half = box_a[4] / 2, b_dy_half = box_b[4] / 2;
float a_x1 = box_a[0] - a_dx_half, a_y1 = box_a[1] - a_dy_half;
float a_x2 = box_a[0] + a_dx_half, a_y2 = box_a[1] + a_dy_half;
float b_x1 = box_b[0] - b_dx_half, b_y1 = box_b[1] - b_dy_half;
float b_x2 = box_b[0] + b_dx_half, b_y2 = box_b[1] + b_dy_half;

Point center_a((a_x1 + a_x2) / 2, (a_y1 + a_y2) / 2);
Point center_b((b_x1 + b_x2) / 2, (b_y1 + b_y2) / 2);
Point center_a(box_a[0], box_a[1]);
Point center_b(box_b[0], box_b[1]);

Point box_a_corners[5];
box_a_corners[0].set(a_x1, a_y1);
Expand Down Expand Up @@ -209,50 +208,36 @@ __device__ inline float box_overlap(const float *box_a, const float *box_b) {
}

__device__ inline float iou_bev(const float *box_a, const float *box_b) {
// params: box_a (5) [x1, y1, x2, y2, angle]
// params: box_b (5) [x1, y1, x2, y2, angle]
float sa = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1]);
float sb = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]);
// params box_a: [x, y, z, dx, dy, dz, heading]
// params box_b: [x, y, z, dx, dy, dz, heading]
float sa = box_a[3] * box_a[4];
float sb = box_b[3] * box_b[4];
float s_overlap = box_overlap(box_a, box_b);
return s_overlap / fmaxf(sa + sb - s_overlap, EPS);
}

__global__ void iou3d_boxes_overlap_bev_forward_cuda_kernel(
const int num_a, const float *boxes_a, const int num_b,
const float *boxes_b, float *ans_overlap) {
CUDA_2D_KERNEL_LOOP(b_idx, num_b, a_idx, num_a) {
if (a_idx >= num_a || b_idx >= num_b) {
return;
}
const float *cur_box_a = boxes_a + a_idx * 5;
const float *cur_box_b = boxes_b + b_idx * 5;
float s_overlap = box_overlap(cur_box_a, cur_box_b);
ans_overlap[a_idx * num_b + b_idx] = s_overlap;
}
}

__global__ void iou3d_boxes_iou_bev_forward_cuda_kernel(const int num_a,
const float *boxes_a,
const int num_b,
const float *boxes_b,
float *ans_iou) {
__global__ void iou3d_boxes_iou3d_forward_cuda_kernel(const int num_a,
const float *boxes_a,
const int num_b,
const float *boxes_b,
float *ans_iou) {
CUDA_2D_KERNEL_LOOP(b_idx, num_b, a_idx, num_a) {
if (a_idx >= num_a || b_idx >= num_b) {
return;
}

const float *cur_box_a = boxes_a + a_idx * 5;
const float *cur_box_b = boxes_b + b_idx * 5;
const float *cur_box_a = boxes_a + a_idx * 7;
const float *cur_box_b = boxes_b + b_idx * 7;
float cur_iou_bev = iou_bev(cur_box_a, cur_box_b);
ans_iou[a_idx * num_b + b_idx] = cur_iou_bev;
}
}

__global__ void nms_forward_cuda_kernel(const int boxes_num,
const float nms_overlap_thresh,
const float *boxes,
unsigned long long *mask) {
// params: boxes (N, 5) [x1, y1, x2, y2, ry]
__global__ void iou3d_nms3d_forward_cuda_kernel(const int boxes_num,
const float nms_overlap_thresh,
const float *boxes,
unsigned long long *mask) {
// params: boxes (N, 7) [x, y, z, dx, dy, dz, heading]
// params: mask (N, N/THREADS_PER_BLOCK_NMS)
const int blocks =
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
Expand All @@ -264,25 +249,29 @@ __global__ void nms_forward_cuda_kernel(const int boxes_num,
const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS,
THREADS_PER_BLOCK_NMS);

__shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5];
__shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 7];

if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 5 + 0] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 0];
block_boxes[threadIdx.x * 5 + 1] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 1];
block_boxes[threadIdx.x * 5 + 2] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 2];
block_boxes[threadIdx.x * 5 + 3] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 3];
block_boxes[threadIdx.x * 5 + 4] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 4];
block_boxes[threadIdx.x * 7 + 0] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 0];
block_boxes[threadIdx.x * 7 + 1] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 1];
block_boxes[threadIdx.x * 7 + 2] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 2];
block_boxes[threadIdx.x * 7 + 3] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 3];
block_boxes[threadIdx.x * 7 + 4] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 4];
block_boxes[threadIdx.x * 7 + 5] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 5];
block_boxes[threadIdx.x * 7 + 6] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 6];
}
__syncthreads();

if (threadIdx.x < row_size) {
const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
const float *cur_box = boxes + cur_box_idx * 5;
const float *cur_box = boxes + cur_box_idx * 7;

int i = 0;
unsigned long long t = 0;
Expand All @@ -291,7 +280,7 @@ __global__ void nms_forward_cuda_kernel(const int boxes_num,
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
if (iou_bev(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
if (iou_bev(cur_box, block_boxes + i * 7) > nms_overlap_thresh) {
t |= 1ULL << i;
}
}
Expand All @@ -303,20 +292,24 @@ __global__ void nms_forward_cuda_kernel(const int boxes_num,
}

__device__ inline float iou_normal(float const *const a, float const *const b) {
float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]);
float top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]);
// params: a: [x, y, z, dx, dy, dz, heading]
// params: b: [x, y, z, dx, dy, dz, heading]

float left = fmaxf(a[0] - a[3] / 2, b[0] - b[3] / 2),
right = fminf(a[0] + a[3] / 2, b[0] + b[3] / 2);
float top = fmaxf(a[1] - a[4] / 2, b[1] - b[4] / 2),
bottom = fminf(a[1] + a[4] / 2, b[1] + b[4] / 2);
float width = fmaxf(right - left, 0.f), height = fmaxf(bottom - top, 0.f);
float interS = width * height;
float Sa = (a[2] - a[0]) * (a[3] - a[1]);
float Sb = (b[2] - b[0]) * (b[3] - b[1]);
float Sa = a[3] * a[4];
float Sb = b[3] * b[4];
return interS / fmaxf(Sa + Sb - interS, EPS);
}

__global__ void nms_normal_forward_cuda_kernel(const int boxes_num,
const float nms_overlap_thresh,
const float *boxes,
unsigned long long *mask) {
// params: boxes (N, 5) [x1, y1, x2, y2, ry]
__global__ void iou3d_nms3d_normal_forward_cuda_kernel(
const int boxes_num, const float nms_overlap_thresh, const float *boxes,
unsigned long long *mask) {
// params: boxes (N, 7) [x, y, z, dx, dy, dz, heading]
// params: mask (N, N/THREADS_PER_BLOCK_NMS)

const int blocks =
Expand All @@ -329,25 +322,29 @@ __global__ void nms_normal_forward_cuda_kernel(const int boxes_num,
const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS,
THREADS_PER_BLOCK_NMS);

__shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5];
__shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 7];

if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 5 + 0] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 0];
block_boxes[threadIdx.x * 5 + 1] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 1];
block_boxes[threadIdx.x * 5 + 2] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 2];
block_boxes[threadIdx.x * 5 + 3] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 3];
block_boxes[threadIdx.x * 5 + 4] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 4];
block_boxes[threadIdx.x * 7 + 0] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 0];
block_boxes[threadIdx.x * 7 + 1] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 1];
block_boxes[threadIdx.x * 7 + 2] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 2];
block_boxes[threadIdx.x * 7 + 3] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 3];
block_boxes[threadIdx.x * 7 + 4] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 4];
block_boxes[threadIdx.x * 7 + 5] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 5];
block_boxes[threadIdx.x * 7 + 6] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 6];
}
__syncthreads();

if (threadIdx.x < row_size) {
const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
const float *cur_box = boxes + cur_box_idx * 5;
const float *cur_box = boxes + cur_box_idx * 7;

int i = 0;
unsigned long long t = 0;
Expand All @@ -356,7 +353,7 @@ __global__ void nms_normal_forward_cuda_kernel(const int boxes_num,
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
if (iou_normal(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
if (iou_normal(cur_box, block_boxes + i * 7) > nms_overlap_thresh) {
t |= 1ULL << i;
}
}
Expand Down
Loading