Skip to content

Commit

Permalink
[Feature] Add roipoint_pool3d op from mmdet3d (#1358)
Browse files Browse the repository at this point in the history
* add ops (roipoint_pool3d) in mmdet3d

* refactor code

* fix typo

* add unit test

* refactor code

* refactor code

* refactor code

* fix typo
  • Loading branch information
DCNSW authored Oct 15, 2021
1 parent e99bfce commit 5c25ae1
Show file tree
Hide file tree
Showing 9 changed files with 391 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- MaskedConv
- NMS
- PSAMask
- RoIPointPool3d
- RoIPool
- RoIAlign
- SimpleRoIAlign
Expand Down
1 change: 1 addition & 0 deletions docs_zh_CN/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- MaskedConv
- NMS
- PSAMask
- RoIPointPool3d
- RoIPool
- RoIAlign
- SimpleRoIAlign
Expand Down
11 changes: 6 additions & 5 deletions mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .roi_align import RoIAlign, roi_align
from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
from .roi_pool import RoIPool, roi_pool
from .roipoint_pool3d import RoIPointPool3d
from .saconv import SAConv2d
from .sync_bn import SyncBatchNorm
from .three_interpolate import three_interpolate
Expand All @@ -60,10 +61,10 @@
'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk',
'box_iou_rotated', 'nms_rotated', 'knn', 'ball_query', 'upfirdn2d',
'FusedBiasLeakyReLU', 'fused_bias_leakyrelu', 'RoIAlignRotated',
'roi_align_rotated', 'pixel_group', 'contour_expand', 'three_nn',
'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign',
'border_align', 'gather_points', 'furthest_point_sample',
'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query',
'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand',
'three_nn', 'three_interpolate', 'MultiScaleDeformableAttention',
'BorderAlign', 'border_align', 'gather_points', 'furthest_point_sample',
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation'
]
144 changes: 144 additions & 0 deletions mmcv/ops/csrc/common/cuda/roipoint_pool3d_cuda_kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef ROIPOINT_POOL3D_CUDA_KERNEL_CUH
#define ROIPOINT_POOL3D_CUDA_KERNEL_CUH

#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif

// Rotates the planar offset (shift_x, shift_y) by the angle -rz about the
// origin and stores the rotated offset in local_x / local_y.
//
// shift_x, shift_y: offset to rotate.
// rz:               rotation angle; the offset is rotated by its negation.
// local_x, local_y: outputs receiving the rotated offset.
template <typename T>
__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz,
                                             T &local_x, T &local_y) {
  const T angle = -rz;
  const T cos_a = cos(angle);
  const T sin_a = sin(angle);
  local_x = shift_x * cos_a + shift_y * (-sin_a);
  local_y = shift_x * sin_a + shift_y * cos_a;
}

// Tests whether a point lies inside a 3D box rotated about the z axis.
//
// pt:      (x, y, z)
// box3d:   (cx, cy, cz, dx, dy, dz, rz) in LiDAR coordinates; cz is the height
//          of the bottom face centre, dx/dy/dz are the full extents and rz is
//          the rotation angle.
// local_x, local_y: outputs receiving the point's planar offset from the box
//          centre, rotated by -rz. They are only written when the point passes
//          the height test.
//
// Returns 1 when the point is inside the box and 0 otherwise. The height test
// accepts points lying exactly on the top/bottom faces, while the planar test
// is strict (points exactly on a side face are rejected).
//
// All arithmetic is done in T: no double literals (which would promote float
// math to double) and no fabsf (which would narrow a double height difference
// to float).
template <typename T>
__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x,
                                        T &local_y) {
  const T x = pt[0], y = pt[1], z = pt[2];
  const T cx = box3d[0], cy = box3d[1];
  const T dx = box3d[3], dy = box3d[4], dz = box3d[5], rz = box3d[6];

  const T half = static_cast<T>(0.5f);
  const T half_dx = dx * half;
  const T half_dy = dy * half;
  const T half_dz = dz * half;

  // cz in box3d is the bottom centre; shift it to the geometric centre.
  const T cz = box3d[2] + half_dz;

  // Equivalent to |z - cz| > dz / 2, evaluated without leaving type T.
  const T z_off = z - cz;
  if (z_off > half_dz || z_off < -half_dz) return 0;

  lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);

  const bool inside = (local_x > -half_dx) && (local_x < half_dx) &&
                      (local_y > -half_dy) && (local_y < half_dy);
  return inside ? 1 : 0;
}

// Computes, for every (batch, point, box) triple, whether the point lies
// inside the box.
//
// Thread mapping: blockIdx.x * blockDim.x + threadIdx.x selects the point,
// blockIdx.y selects the box and blockIdx.z selects the batch element.
//
// params xyz:        (B, N, 3)
// params boxes3d:    (B, M, 7)
// params pts_assign: (B, N, M), output; 1 if point n is inside box m of the
//                    same batch element, 0 otherwise
template <typename T>
__global__ void assign_pts_to_box3d(int batch_size, int pts_num, int boxes_num,
                                    const T *xyz, const T *boxes3d,
                                    int *pts_assign) {
  const int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int box_idx = blockIdx.y;
  const int bs_idx = blockIdx.z;

  if (pt_idx >= pts_num || box_idx >= boxes_num || bs_idx >= batch_size) {
    return;
  }

  const int box_offset = bs_idx * boxes_num * 7 + box_idx * 7;
  const int pt_offset = bs_idx * pts_num * 3 + pt_idx * 3;
  const int assign_idx =
      bs_idx * pts_num * boxes_num + pt_idx * boxes_num + box_idx;

  // The local coordinates are required by the helper but not used here.
  T local_x = 0, local_y = 0;

  // Single store of the final flag; every thread owns exactly one element.
  pts_assign[assign_idx] = check_pt_in_box3d(
      xyz + pt_offset, boxes3d + box_offset, local_x, local_y);
}

// Selects, for every box, the indices of the points that will be pooled.
//
// Thread mapping: blockIdx.x * blockDim.x + threadIdx.x selects the box and
// blockIdx.y selects the batch element; one thread handles one box.
//
// params pts_assign:        (B, N, M), non-zero where point n is inside box m
// params pts_idx:           (B, M, sampled_pts_num), output point indices
// params pooled_empty_flag: (B, M), set to 1 for boxes containing no point
//
// The first sampled_pts_num in-box points (in increasing point index) are
// taken. If a box holds fewer points, the selected indices are repeated
// cyclically to fill the remaining slots. pooled_empty_flag is only ever
// written with 1 here and is never cleared, so the caller is expected to pass
// it zero-initialised; pts_idx is left untouched for empty boxes.
__global__ void get_pooled_idx(int batch_size, int pts_num, int boxes_num,
                               int sampled_pts_num, const int *pts_assign,
                               int *pts_idx, int *pooled_empty_flag) {
  const int boxes_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int bs_idx = blockIdx.y;
  if (boxes_idx >= boxes_num || bs_idx >= batch_size) {
    return;
  }

  // Column of the assignment matrix belonging to this box; consecutive points
  // are boxes_num elements apart.
  const int *box_assign = pts_assign + bs_idx * pts_num * boxes_num + boxes_idx;
  // Output slots of this box.
  int *box_pts_idx = pts_idx + bs_idx * boxes_num * sampled_pts_num +
                     boxes_idx * sampled_pts_num;

  // Stop scanning as soon as every slot is filled.
  int cnt = 0;
  for (int k = 0; k < pts_num && cnt < sampled_pts_num; k++) {
    if (box_assign[k * boxes_num]) {
      box_pts_idx[cnt] = k;
      cnt++;
    }
  }

  if (cnt == 0) {
    pooled_empty_flag[bs_idx * boxes_num + boxes_idx] = 1;
    return;
  }

  // duplicate same points for sampling
  for (int k = cnt; k < sampled_pts_num; k++) {
    box_pts_idx[k] = box_pts_idx[k % cnt];
  }
}

// Gathers the coordinates and features of the selected points of every box.
//
// Thread mapping: blockIdx.x * blockDim.x + threadIdx.x selects the sample
// slot, blockIdx.y selects the box and blockIdx.z selects the batch element.
//
// params xyz:               (B, N, 3)
// params pts_idx:           (B, M, sampled_pts_num)
// params pts_feature:       (B, N, C) with C == feature_in_len
// params pooled_features:   (B, M, sampled_pts_num, 3 + C), output
// params pooled_empty_flag: (B, M); boxes with a non-zero flag are skipped and
//                           their output rows are left untouched
template <typename T>
__global__ void roipoint_pool3d_forward(
    int batch_size, int pts_num, int boxes_num, int feature_in_len,
    int sampled_pts_num, const T *xyz, const int *pts_idx, const T *pts_feature,
    T *pooled_features, int *pooled_empty_flag) {
  const int sample_pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int box_idx = blockIdx.y;
  const int bs_idx = blockIdx.z;

  const bool in_range = sample_pt_idx < sampled_pts_num &&
                        box_idx < boxes_num && bs_idx < batch_size;
  if (!in_range) {
    return;
  }

  // Nothing to gather for boxes that contain no point.
  if (pooled_empty_flag[bs_idx * boxes_num + box_idx] != 0) {
    return;
  }

  const int slot = bs_idx * boxes_num * sampled_pts_num +
                   box_idx * sampled_pts_num + sample_pt_idx;
  const int src_pt_idx = pts_idx[slot];

  T *dst = pooled_features + slot * (3 + feature_in_len);
  const T *src_xyz = xyz + bs_idx * pts_num * 3 + src_pt_idx * 3;
  const T *src_feature = pts_feature + bs_idx * pts_num * feature_in_len +
                         src_pt_idx * feature_in_len;

  // The first three channels hold the point coordinates ...
  dst[0] = src_xyz[0];
  dst[1] = src_xyz[1];
  dst[2] = src_xyz[2];

  // ... followed by the point's feature vector.
  for (int c = 0; c < feature_in_len; c++) {
    dst[3 + c] = src_feature[c];
  }
}

#endif // ROIPOINT_POOL3D_CUDA_KERNEL_CUH
60 changes: 60 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/roipoint_pool3d_cuda.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
Modified from
https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/roipoint_pool3d/src/roipoint_pool3d_kernel.cu
Point cloud feature pooling
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/

#include <math.h>
#include <stdio.h>

#include "pytorch_cuda_helper.hpp"
#include "roipoint_pool3d_cuda_kernel.cuh"

// Launches the three kernels of the RoI point pooling forward pass on the
// current CUDA stream of xyz's device:
//   1. assign_pts_to_box3d      marks which points lie inside which boxes,
//   2. get_pooled_idx           picks sampled_pts_num point indices per box
//                               and flags boxes without points,
//   3. roipoint_pool3d_forward  gathers coordinates and features of the
//                               selected points into pooled_features.
//
// batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num are the
// sizes passed through to the kernels. pooled_features and pooled_empty_flag
// are outputs; the kernels only write the entries they need, so both are
// expected to arrive zero-initialised.
//
// Every launch is followed by an error check so that a bad launch
// configuration is reported here instead of lingering as a pending CUDA error.
// Kernels whose grid would have a zero dimension are skipped, since launching
// them is an error and they would have no work to do.
void RoIPointPool3dForwardCUDAKernelLauncher(
    int batch_size, int pts_num, int boxes_num, int feature_in_len,
    int sampled_pts_num, const Tensor xyz, const Tensor boxes3d,
    const Tensor pts_feature, Tensor pooled_features,
    Tensor pooled_empty_flag) {
  // Without any (batch, box) pair there is no output element to produce.
  if (batch_size == 0 || boxes_num == 0) {
    return;
  }

  at::cuda::CUDAGuard device_guard(xyz.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 threads(THREADS_PER_BLOCK);

  // Scratch buffer: (B, N, M) in-box flags.
  Tensor pts_assign = at::empty({batch_size, pts_num, boxes_num},
                                boxes3d.options().dtype(at::kInt));

  if (pts_num > 0) {
    // blockIdx.x covers points, blockIdx.y the box, blockIdx.z the batch
    dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), boxes_num, batch_size);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        xyz.scalar_type(), "assign_pts_to_box3d", [&] {
          assign_pts_to_box3d<scalar_t><<<blocks, threads, 0, stream>>>(
              batch_size, pts_num, boxes_num, xyz.data_ptr<scalar_t>(),
              boxes3d.data_ptr<scalar_t>(), pts_assign.data_ptr<int>());
        });
    AT_CUDA_CHECK(cudaGetLastError());
  }

  // Scratch buffer: (B, M, sampled_pts_num) selected point indices.
  Tensor pts_idx = at::empty({batch_size, boxes_num, sampled_pts_num},
                             boxes3d.options().dtype(at::kInt));

  // blockIdx.x covers boxes, blockIdx.y the batch. This kernel also runs when
  // pts_num or sampled_pts_num is zero so that every box gets flagged empty.
  dim3 blocks2(DIVUP(boxes_num, THREADS_PER_BLOCK), batch_size);

  get_pooled_idx<<<blocks2, threads, 0, stream>>>(
      batch_size, pts_num, boxes_num, sampled_pts_num,
      pts_assign.data_ptr<int>(), pts_idx.data_ptr<int>(),
      pooled_empty_flag.data_ptr<int>());
  AT_CUDA_CHECK(cudaGetLastError());

  if (sampled_pts_num > 0) {
    // blockIdx.x covers sample slots, blockIdx.y the box, blockIdx.z the batch
    dim3 blocks_pool(DIVUP(sampled_pts_num, THREADS_PER_BLOCK), boxes_num,
                     batch_size);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        xyz.scalar_type(), "roipoint_pool3d_forward", [&] {
          roipoint_pool3d_forward<scalar_t><<<blocks_pool, threads, 0, stream>>>(
              batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num,
              xyz.data_ptr<scalar_t>(), pts_idx.data_ptr<int>(),
              pts_feature.data_ptr<scalar_t>(),
              pooled_features.data_ptr<scalar_t>(),
              pooled_empty_flag.data_ptr<int>());
        });
    AT_CUDA_CHECK(cudaGetLastError());
  }
}
7 changes: 7 additions & 0 deletions mmcv/ops/csrc/pytorch/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ void deform_roi_pool_backward(Tensor grad_output, Tensor input, Tensor rois,
int pooled_width, float spatial_scale,
int sampling_ratio, float gamma);

void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature,
Tensor pooled_features, Tensor pooled_empty_flag);

void gather_points_forward(int b, int c, int n, int npoints,
Tensor points_tensor, Tensor idx_tensor,
Tensor out_tensor);
Expand Down Expand Up @@ -364,6 +367,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("grad_offset"), py::arg("pooled_height"),
py::arg("pooled_width"), py::arg("spatial_scale"),
py::arg("sampling_ratio"), py::arg("gamma"));
m.def("roipoint_pool3d_forward", &roipoint_pool3d_forward,
"roipoint_pool3d_forward", py::arg("xyz"), py::arg("boxes3d"),
py::arg("pts_feature"), py::arg("pooled_features"),
py::arg("pooled_empty_flag"));
m.def("sigmoid_focal_loss_forward", &sigmoid_focal_loss_forward,
"sigmoid_focal_loss_forward ", py::arg("input"), py::arg("target"),
py::arg("weight"), py::arg("output"), py::arg("gamma"),
Expand Down
60 changes: 60 additions & 0 deletions mmcv/ops/csrc/pytorch/roipoint_pool3d.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
Modified from
https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/roipoint_pool3d/src/roipoint_pool3d.cpp
Point cloud feature pooling
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/

#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
void RoIPointPool3dForwardCUDAKernelLauncher(
int batch_size, int pts_num, int boxes_num, int feature_in_len,
int sampled_pts_num, const Tensor xyz, const Tensor boxes3d,
const Tensor pts_feature, Tensor pooled_features, Tensor pooled_empty_flag);

// CUDA entry point of the RoI point pooling forward pass: forwards all
// arguments unchanged to RoIPointPool3dForwardCUDAKernelLauncher.
void roipoint_pool3d_forward_cuda(int batch_size, int pts_num, int boxes_num,
                                  int feature_in_len, int sampled_pts_num,
                                  const Tensor xyz, const Tensor boxes3d,
                                  const Tensor pts_feature,
                                  Tensor pooled_features,
                                  Tensor pooled_empty_flag) {
  RoIPointPool3dForwardCUDAKernelLauncher(
      batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, xyz,
      boxes3d, pts_feature, pooled_features, pooled_empty_flag);
}
#endif

// Forward pass of RoI point pooling: for every box, gathers the coordinates
// and features of points lying inside it.
//
// params xyz:               (B, N, 3)
// params boxes3d:           (B, M, 7)
// params pts_feature:       (B, N, C)
// params pooled_features:   (B, M, S, 3+C), output; S is the number of points
//                           sampled per box
// params pooled_empty_flag: (B, M), output
//
// Only CUDA tensors are supported. An error is raised for CPU tensors, for
// builds without CUDA support, and for tensors whose shapes disagree with the
// layout above (the kernels index raw pointers with these sizes, so a mismatch
// would read or write out of bounds).
void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature,
                             Tensor pooled_features, Tensor pooled_empty_flag) {
  if (xyz.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
    CHECK_CUDA_INPUT(xyz);
    CHECK_CUDA_INPUT(boxes3d);
    CHECK_CUDA_INPUT(pts_feature);
    CHECK_CUDA_INPUT(pooled_features);
    CHECK_CUDA_INPUT(pooled_empty_flag);

    // Validate ranks before reading any size so that size() cannot be called
    // with an out-of-range dimension.
    TORCH_CHECK(xyz.dim() == 3 && xyz.size(2) == 3,
                "xyz must have shape (B, N, 3)");
    TORCH_CHECK(boxes3d.dim() == 3 && boxes3d.size(2) == 7,
                "boxes3d must have shape (B, M, 7)");
    TORCH_CHECK(pts_feature.dim() == 3,
                "pts_feature must have shape (B, N, C)");
    TORCH_CHECK(pooled_features.dim() == 4,
                "pooled_features must have shape (B, M, S, 3+C)");
    TORCH_CHECK(pooled_empty_flag.dim() == 2,
                "pooled_empty_flag must have shape (B, M)");

    const int batch_size = static_cast<int>(xyz.size(0));
    const int pts_num = static_cast<int>(xyz.size(1));
    const int boxes_num = static_cast<int>(boxes3d.size(1));
    const int feature_in_len = static_cast<int>(pts_feature.size(2));
    const int sampled_pts_num = static_cast<int>(pooled_features.size(2));

    // Cross-check the sizes shared between tensors.
    TORCH_CHECK(boxes3d.size(0) == batch_size,
                "boxes3d and xyz must have the same batch size");
    TORCH_CHECK(
        pts_feature.size(0) == batch_size && pts_feature.size(1) == pts_num,
        "pts_feature must have the same batch size and point count as xyz");
    TORCH_CHECK(pooled_features.size(0) == batch_size &&
                    pooled_features.size(1) == boxes_num &&
                    pooled_features.size(3) == 3 + feature_in_len,
                "pooled_features must have shape (B, M, S, 3+C)");
    TORCH_CHECK(pooled_empty_flag.size(0) == batch_size &&
                    pooled_empty_flag.size(1) == boxes_num,
                "pooled_empty_flag must have shape (B, M)");

    roipoint_pool3d_forward_cuda(batch_size, pts_num, boxes_num, feature_in_len,
                                 sampled_pts_num, xyz, boxes3d, pts_feature,
                                 pooled_features, pooled_empty_flag);
#else
    AT_ERROR("roipoint_pool3d is not compiled with GPU support");
#endif
  } else {
    AT_ERROR("roipoint_pool3d is not implemented on CPU");
  }
}
Loading

0 comments on commit 5c25ae1

Please sign in to comment.