Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add roipoint_pool3d op from mmdet3d #1358

Merged
merged 11 commits into from
Oct 15, 2021
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- MaskedConv
- NMS
- PSAMask
- RoIPointPool3d
- RoIPool
- RoIAlign
- SimpleRoIAlign
Expand Down
1 change: 1 addition & 0 deletions docs_zh_CN/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- MaskedConv
- NMS
- PSAMask
- RoIPointPool3d
- RoIPool
- RoIAlign
- SimpleRoIAlign
Expand Down
11 changes: 6 additions & 5 deletions mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from .roi_align import RoIAlign, roi_align
from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
from .roi_pool import RoIPool, roi_pool
from .roipoint_pool3d import RoIPointPool3d
from .saconv import SAConv2d
from .sync_bn import SyncBatchNorm
from .tin_shift import TINShift, tin_shift
Expand All @@ -51,9 +52,9 @@
'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand',
'MultiScaleDeformableAttention', 'BorderAlign', 'border_align',
'Correlation'
'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'RoIPointPool3d',
'nms_rotated', 'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU',
'fused_bias_leakyrelu', 'RoIAlignRotated', 'roi_align_rotated',
'pixel_group', 'contour_expand', 'MultiScaleDeformableAttention',
'BorderAlign', 'border_align', 'Correlation'
]
146 changes: 146 additions & 0 deletions mmcv/ops/csrc/common/cuda/roipoint_pool3d_cuda_kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef ROIPOINT_POOL3D_CUDA_KERNEL_CUH
#define ROIPOINT_POOL3D_CUDA_KERNEL_CUH

#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif

// Integer division of m by n, plus one when the remainder is positive (i.e. a
// round-up division for positive operands). Both arguments are evaluated
// twice, so they must be free of side effects.
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
DCNSW marked this conversation as resolved.
Show resolved Hide resolved

// Rotates the planar offset (shift_x, shift_y) by the angle -rz and writes the
// rotated components to local_x and local_y.
template <typename T>
__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz,
                                             T &local_x, T &local_y) {
  const T angle = -rz;
  const T cos_a = cos(angle);
  const T sin_a = sin(angle);
  // 2D rotation: [cos_a, -sin_a; sin_a, cos_a] applied to the offset.
  local_x = shift_x * cos_a + shift_y * (-sin_a);
  local_y = shift_x * sin_a + shift_y * cos_a;
}

// Tests whether a point lies strictly inside a rotated 3D box.
//
// param pt: (x, y, z)
// param box3d: (cx, cy, cz, dx, dy, dz, rz) in LiDAR coordinate, cz in the
// bottom center
// param local_x, local_y: outputs; the point's planar offset from the box
// centre rotated by -rz. They are only written when the point passes the
// height test.
// returns: 1 if the point is inside the box, 0 otherwise.
template <typename T>
__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x,
                                        T &local_y) {
  const T x = pt[0], y = pt[1], z = pt[2];
  const T cx = box3d[0], cy = box3d[1];
  const T dx = box3d[3], dy = box3d[4], dz = box3d[5], rz = box3d[6];

  // Half extents are computed in T: a bare `2.0` literal would promote float
  // inputs to double arithmetic on the device.
  const T one_half = static_cast<T>(0.5f);
  const T half_dx = dx * one_half;
  const T half_dy = dy * one_half;
  const T half_dz = dz * one_half;

  // shift to the center since cz in box3d is the bottom center
  const T cz = box3d[2] + half_dz;

  // Explicit two-sided comparison instead of fabsf(): fabsf() would truncate
  // double inputs to float before comparing.
  const T z_offset = z - cz;
  if (z_offset > half_dz || z_offset < -half_dz) return 0;

  lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);

  // Strict inequalities: points exactly on a side face count as outside.
  const bool inside = (local_x > -half_dx) && (local_x < half_dx) &&
                      (local_y > -half_dy) && (local_y < half_dy);
  return inside ? 1 : 0;
}

// Marks, for every (batch, point, box) triple, whether the point lies inside
// the box.
//
// Launch layout: blockIdx.x * blockDim.x + threadIdx.x selects the point,
// blockIdx.y the box and blockIdx.z the batch element.
template <typename T>
__global__ void assign_pts_to_box3d(int batch_size, int pts_num, int boxes_num,
                                    const T *xyz, const T *boxes3d,
                                    int *pts_assign) {
  // params xyz: (B, N, 3)
  // params boxes3d: (B, M, 7)
  // params pts_assign: (B, N, M): 1 if point n is inside box m, 0 otherwise
  int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int box_idx = blockIdx.y;
  int bs_idx = blockIdx.z;

  if (pt_idx >= pts_num || box_idx >= boxes_num || bs_idx >= batch_size) {
    return;
  }
  int assign_idx = bs_idx * pts_num * boxes_num + pt_idx * boxes_num + box_idx;

  int box_offset = bs_idx * boxes_num * 7 + box_idx * 7;
  int pt_offset = bs_idx * pts_num * 3 + pt_idx * 3;

  // local_x / local_y are required by the helper's signature; their values
  // are not used here.
  T local_x = 0, local_y = 0;
  int cur_in_flag = check_pt_in_box3d(xyz + pt_offset, boxes3d + box_offset,
                                      local_x, local_y);
  // Single store per element: the flag is always written, so no separate
  // zero-initialising store is needed.
  pts_assign[assign_idx] = cur_in_flag;
}

// Collects, for each (batch, box) pair, the indices of the points assigned to
// that box, keeping at most sampled_pts_num of them in ascending point order.
// A box with fewer assigned points has its list padded by cycling through the
// indices already collected; a box with no assigned points only gets its
// pooled_empty_flag entry set to 1 (the flag is never cleared here, so it is
// presumably zero-initialised by the caller).
//
// Launch layout: blockIdx.x * blockDim.x + threadIdx.x selects the box,
// blockIdx.y the batch element.
__global__ void get_pooled_idx(int batch_size, int pts_num, int boxes_num,
                               int sampled_pts_num, const int *pts_assign,
                               int *pts_idx, int *pooled_empty_flag) {
  // params pts_assign: (B, N, M), non-zero when point n belongs to box m
  // params pts_idx: (B, M, sampled_pts_num)
  // params pooled_empty_flag: (B, M)

  const int boxes_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (boxes_idx >= boxes_num) return;

  const int bs_idx = blockIdx.y;

  // Start of this box's slot list in pts_idx and of its column in pts_assign.
  const int out_base =
      bs_idx * boxes_num * sampled_pts_num + boxes_idx * sampled_pts_num;
  const int assign_base = bs_idx * pts_num * boxes_num + boxes_idx;

  int cnt = 0;
  for (int k = 0; k < pts_num; ++k) {
    if (!pts_assign[assign_base + k * boxes_num]) continue;
    // Stop at the first assigned point that no longer fits.
    if (cnt >= sampled_pts_num) break;
    pts_idx[out_base + cnt] = k;
    ++cnt;
  }

  if (cnt == 0) {
    pooled_empty_flag[bs_idx * boxes_num + boxes_idx] = 1;
    return;
  }

  // duplicate same points for sampling
  for (int k = cnt; k < sampled_pts_num; ++k) {
    pts_idx[out_base + k] = pts_idx[out_base + k % cnt];
  }
}

// Gathers the coordinates and features of the sampled points of every
// non-empty box into pooled_features. Boxes whose pooled_empty_flag entry is
// non-zero are skipped, leaving their output rows untouched.
//
// Launch layout: blockIdx.x * blockDim.x + threadIdx.x selects the sample slot,
// blockIdx.y the box and blockIdx.z the batch element.
template <typename T>
__global__ void roipoint_pool3d_forward(
    int batch_size, int pts_num, int boxes_num, int feature_in_len,
    int sampled_pts_num, const T *xyz, const int *pts_idx, const T *pts_feature,
    T *pooled_features, int *pooled_empty_flag) {
  // params xyz: (B, N, 3)
  // params pts_idx: (B, M, sampled_pts_num)
  // params pts_feature: (B, N, C)
  // params pooled_features: (B, M, sampled_pts_num, 3+C)
  // params pooled_empty_flag: (B, M)

  const int sample_pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int box_idx = blockIdx.y;
  const int bs_idx = blockIdx.z;

  const bool in_range = sample_pt_idx < sampled_pts_num &&
                        box_idx < boxes_num && bs_idx < batch_size;
  if (!in_range) return;

  if (pooled_empty_flag[bs_idx * boxes_num + box_idx]) return;

  // Flat index of this (batch, box, sample) slot; it addresses both pts_idx
  // and, scaled by the row width, pooled_features.
  const int slot = bs_idx * boxes_num * sampled_pts_num +
                   box_idx * sampled_pts_num + sample_pt_idx;
  const int src_pt_idx = pts_idx[slot];

  T *dst_row = pooled_features + slot * (3 + feature_in_len);
  const T *src_xyz = xyz + bs_idx * pts_num * 3 + src_pt_idx * 3;
  const T *src_feat = pts_feature + bs_idx * pts_num * feature_in_len +
                      src_pt_idx * feature_in_len;

  // The first three output channels are the point coordinates ...
  dst_row[0] = src_xyz[0];
  dst_row[1] = src_xyz[1];
  dst_row[2] = src_xyz[2];

  // ... followed by the point's feature_in_len feature channels.
  for (int c = 0; c < feature_in_len; ++c) {
    dst_row[3 + c] = src_feat[c];
  }
}

#endif // ROIPOINT_POOL3D_CUDA_KERNEL_CUH
60 changes: 60 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/roipoint_pool3d_cuda.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
Modified from
https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roipoint_pool3d/src/roipoint_pool3d_kernel.cu
DCNSW marked this conversation as resolved.
Show resolved Hide resolved
Point cloud feature pooling
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/

#include <math.h>
#include <stdio.h>

#include "pytorch_cuda_helper.hpp"
#include "roipoint_pool3d_cuda_kernel.cuh"

// Host-side launcher for RoI point pooling. Runs three kernels on the current
// CUDA stream of xyz's device:
//   1. assign_pts_to_box3d     - flags which points fall inside which boxes
//   2. get_pooled_idx          - picks sampled_pts_num point indices per box
//                                and flags boxes that contain no points
//   3. roipoint_pool3d_forward - gathers coordinates and features
// Results are written into pooled_features and pooled_empty_flag. The
// launches are asynchronous; only launch-configuration errors are reported
// here.
void RoIPointPool3dForwardCUDAKernelLauncher(
    int batch_size, int pts_num, int boxes_num, int feature_in_len,
    int sampled_pts_num, const Tensor xyz, const Tensor boxes3d,
    const Tensor pts_feature, Tensor pooled_features,
    Tensor pooled_empty_flag) {
  // batch_size and boxes_num are used as grid dimensions by every kernel
  // below, and a zero-sized grid is an invalid launch configuration. With no
  // batches or no boxes there is nothing to compute, so bail out early.
  if (batch_size == 0 || boxes_num == 0) {
    return;
  }

  // Select the device before allocating or launching anything.
  at::cuda::CUDAGuard device_guard(xyz.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  Tensor pts_assign = at::empty({batch_size, pts_num, boxes_num},
                                boxes3d.options().dtype(at::kInt));

  dim3 threads(THREADS_PER_BLOCK);

  // Skip the launch when there are no points: the grid's x dimension would be
  // zero, and get_pooled_idx does not read pts_assign in that case.
  if (pts_num > 0) {
    dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), boxes_num,
                batch_size);  // blockIdx.x(col), blockIdx.y(row)

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        xyz.scalar_type(), "assign_pts_to_box3d", [&] {
          assign_pts_to_box3d<scalar_t><<<blocks, threads, 0, stream>>>(
              batch_size, pts_num, boxes_num, xyz.data_ptr<scalar_t>(),
              boxes3d.data_ptr<scalar_t>(), pts_assign.data_ptr<int>());
        });
    AT_CUDA_CHECK(cudaGetLastError());
  }

  Tensor pts_idx = at::empty({batch_size, boxes_num, sampled_pts_num},
                             boxes3d.options().dtype(at::kInt));

  dim3 blocks2(DIVUP(boxes_num, THREADS_PER_BLOCK),
               batch_size);  // blockIdx.x(col), blockIdx.y(row)

  get_pooled_idx<<<blocks2, threads, 0, stream>>>(
      batch_size, pts_num, boxes_num, sampled_pts_num,
      pts_assign.data_ptr<int>(), pts_idx.data_ptr<int>(),
      pooled_empty_flag.data_ptr<int>());
  AT_CUDA_CHECK(cudaGetLastError());

  // Nothing to gather when no samples are requested; this also avoids a
  // zero-sized grid.
  if (sampled_pts_num > 0) {
    dim3 blocks_pool(DIVUP(sampled_pts_num, THREADS_PER_BLOCK), boxes_num,
                     batch_size);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        xyz.scalar_type(), "roipoint_pool3d_forward", [&] {
          roipoint_pool3d_forward<scalar_t>
              <<<blocks_pool, threads, 0, stream>>>(
                  batch_size, pts_num, boxes_num, feature_in_len,
                  sampled_pts_num, xyz.data_ptr<scalar_t>(),
                  pts_idx.data_ptr<int>(), pts_feature.data_ptr<scalar_t>(),
                  pooled_features.data_ptr<scalar_t>(),
                  pooled_empty_flag.data_ptr<int>());
        });
    AT_CUDA_CHECK(cudaGetLastError());
  }
}
7 changes: 7 additions & 0 deletions mmcv/ops/csrc/pytorch/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ void deform_roi_pool_backward(Tensor grad_output, Tensor input, Tensor rois,
int pooled_width, float spatial_scale,
int sampling_ratio, float gamma);

void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature,
Tensor pooled_features, Tensor pooled_empty_flag);

void sigmoid_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha);

Expand Down Expand Up @@ -296,6 +299,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("grad_offset"), py::arg("pooled_height"),
py::arg("pooled_width"), py::arg("spatial_scale"),
py::arg("sampling_ratio"), py::arg("gamma"));
m.def("roipoint_pool3d_forward", &roipoint_pool3d_forward,
"roipoint_pool3d_forward", py::arg("xyz"), py::arg("boxes3d"),
py::arg("pts_feature"), py::arg("pooled_features"),
py::arg("pooled_empty_flag"));
m.def("sigmoid_focal_loss_forward", &sigmoid_focal_loss_forward,
"sigmoid_focal_loss_forward ", py::arg("input"), py::arg("target"),
py::arg("weight"), py::arg("output"), py::arg("gamma"),
Expand Down
60 changes: 60 additions & 0 deletions mmcv/ops/csrc/pytorch/roipoint_pool3d.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
Modified from
https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/roipoint_pool3d/src/roipoint_pool3d_kernel.cu
DCNSW marked this conversation as resolved.
Show resolved Hide resolved
Point cloud feature pooling
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/

#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
// Implemented in the CUDA translation unit; launches the pooling kernels.
void RoIPointPool3dForwardCUDAKernelLauncher(
    int batch_size, int pts_num, int boxes_num, int feature_in_len,
    int sampled_pts_num, const Tensor xyz, const Tensor boxes3d,
    const Tensor pts_feature, Tensor pooled_features, Tensor pooled_empty_flag);

// Thin CUDA entry point: forwards every argument unchanged to the kernel
// launcher.
void roipoint_pool3d_forward_cuda(int batch_size, int pts_num, int boxes_num,
                                  int feature_in_len, int sampled_pts_num,
                                  const Tensor xyz, const Tensor boxes3d,
                                  const Tensor pts_feature,
                                  Tensor pooled_features,
                                  Tensor pooled_empty_flag) {
  RoIPointPool3dForwardCUDAKernelLauncher(
      batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, xyz,
      boxes3d, pts_feature, pooled_features, pooled_empty_flag);
}
#endif

// Python-facing entry point of the RoI point pooling op.
//
// Derives the problem sizes from the tensor shapes, validates that all
// tensors agree with each other, and dispatches to the CUDA implementation.
// Results are written in place into pooled_features and pooled_empty_flag.
// Raises an error for CPU tensors, for builds without CUDA support, and for
// inconsistent shapes.
void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature,
                             Tensor pooled_features, Tensor pooled_empty_flag) {
  // params xyz: (B, N, 3)
  // params boxes3d: (B, M, 7)
  // params pts_feature: (B, N, C)
  // params pooled_features: (B, M, sampled_pts_num, 3+C)
  // params pooled_empty_flag: (B, M)

  if (xyz.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
    CHECK_CUDA_INPUT(xyz);
    CHECK_CUDA_INPUT(boxes3d);
    CHECK_CUDA_INPUT(pts_feature);
    CHECK_CUDA_INPUT(pooled_features);
    CHECK_CUDA_INPUT(pooled_empty_flag);

    // The kernels index raw device pointers using the layouts documented
    // above, so a mismatching shape would read or write out of bounds.
    // Reject such inputs up front.
    if (xyz.dim() != 3 || xyz.size(2) != 3) {
      AT_ERROR("roipoint_pool3d: xyz must have shape (B, N, 3)");
    }
    if (boxes3d.dim() != 3 || boxes3d.size(2) != 7 ||
        boxes3d.size(0) != xyz.size(0)) {
      AT_ERROR("roipoint_pool3d: boxes3d must have shape (B, M, 7)");
    }
    if (pts_feature.dim() != 3 || pts_feature.size(0) != xyz.size(0) ||
        pts_feature.size(1) != xyz.size(1)) {
      AT_ERROR("roipoint_pool3d: pts_feature must have shape (B, N, C)");
    }
    if (pooled_features.dim() != 4 ||
        pooled_features.size(0) != xyz.size(0) ||
        pooled_features.size(1) != boxes3d.size(1) ||
        pooled_features.size(3) != 3 + pts_feature.size(2)) {
      AT_ERROR(
          "roipoint_pool3d: pooled_features must have shape "
          "(B, M, sampled_pts_num, 3 + C)");
    }
    if (pooled_empty_flag.dim() != 2 ||
        pooled_empty_flag.size(0) != xyz.size(0) ||
        pooled_empty_flag.size(1) != boxes3d.size(1)) {
      AT_ERROR("roipoint_pool3d: pooled_empty_flag must have shape (B, M)");
    }

    int batch_size = xyz.size(0);
    int pts_num = xyz.size(1);
    int boxes_num = boxes3d.size(1);
    int feature_in_len = pts_feature.size(2);
    int sampled_pts_num = pooled_features.size(2);

    roipoint_pool3d_forward_cuda(batch_size, pts_num, boxes_num, feature_in_len,
                                 sampled_pts_num, xyz, boxes3d, pts_feature,
                                 pooled_features, pooled_empty_flag);
#else
    AT_ERROR("roipoint_pool3d is not compiled with GPU support");
#endif
  } else {
    AT_ERROR("roipoint_pool3d is not implemented on CPU");
  }
}
Loading