Skip to content

Commit

Permalink
[Feature] Add roipooling cuda ops (#843)
Browse files Browse the repository at this point in the history
* [Refactor] Main code modification for coordinate system refactor (#677)

* [Enhance] Add script for data update (#774)

* Fixed wrong config paths and fixed a bug in test

* Fixed metafile

* Coord sys refactor (main code)

* Update test_waymo_dataset.py

* Manually resolve conflict

* Removed unused lines and fixed imports

* remove coord2box and box2coord

* update dir_limit_offset

* Some minor improvements

* Removed some \s in comments

* Revert a change

* Change Box3DMode to Coord3DMode where points are converted

* Fix points_in_bbox function

* Fix Imvoxelnet config

* Revert adding a line

* Fix rotation bug when batch size is 0

* Keep sign of dir_scores as before

* Fix several comments

* Add a comment

* Fix docstring

* Add data update scripts

* Fix comments

* fix import

* add roipooling cuda ops

* add roi extractor

* add test_roi_extractor unittest

* Modify setup.py to install roipooling ops

* modify docstring

* remove enlarge bbox in roipoint pooling

* add_roipooling_ops

* modify docstring

Co-authored-by: Yezhen Cong <52420115+THU17cyz@users.noreply.github.com>
Co-authored-by: THU17cyz <congyezhen71@hotmail.com>
  • Loading branch information
3 people authored and Tai-Wang committed Sep 24, 2021
1 parent 2936b3b commit 4f36084
Show file tree
Hide file tree
Showing 8 changed files with 407 additions and 2 deletions.
6 changes: 5 additions & 1 deletion mmdet3d/models/roi_heads/roi_extractors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.roi_heads.roi_extractors import SingleRoIExtractor
from .single_roiaware_extractor import Single3DRoIAwareExtractor
from .single_roipoint_extractor import Single3DRoIPointExtractor

__all__ = ['SingleRoIExtractor', 'Single3DRoIAwareExtractor']
__all__ = [
'SingleRoIExtractor', 'Single3DRoIAwareExtractor',
'Single3DRoIPointExtractor'
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import torch
from torch import nn as nn

from mmdet3d import ops
from mmdet3d.core.bbox.structures import rotation_3d_in_axis
from mmdet.models.builder import ROI_EXTRACTORS


@ROI_EXTRACTORS.register_module()
class Single3DRoIPointExtractor(nn.Module):
"""Point-wise roi-aware Extractor.
Extract Point-wise roi features.
Args:
roi_layer (dict): The config of roi layer.
"""

def __init__(self, roi_layer=None):
super(Single3DRoIPointExtractor, self).__init__()
self.roi_layer = self.build_roi_layers(roi_layer)

def build_roi_layers(self, layer_cfg):
"""Build roi layers using `layer_cfg`"""
cfg = layer_cfg.copy()
layer_type = cfg.pop('type')
assert hasattr(ops, layer_type)
layer_cls = getattr(ops, layer_type)
roi_layers = layer_cls(**cfg)
return roi_layers

def forward(self, feats, coordinate, batch_inds, rois):
"""Extract point-wise roi features.
Args:
feats (torch.FloatTensor): Point-wise features with
shape (batch, npoints, channels) for pooling.
coordinate (torch.FloatTensor): Coordinate of each point.
batch_inds (torch.LongTensor): Indicate the batch of each point.
rois (torch.FloatTensor): Roi boxes with batch indices.
Returns:
torch.FloatTensor: Pooled features
"""
rois = rois[..., 1:]
rois = rois.view(batch_inds, -1, rois.shape[-1])
with torch.no_grad():
pooled_roi_feat, pooled_empty_flag = self.roi_layer(
coordinate, feats, rois)

# canonical transformation
roi_center = rois[:, :, 0:3]
pooled_roi_feat[:, :, :, 0:3] -= roi_center.unsqueeze(dim=2)
pooled_roi_feat = pooled_roi_feat.view(-1,
pooled_roi_feat.shape[-2],
pooled_roi_feat.shape[-1])
pooled_roi_feat[:, :, 0:3] = rotation_3d_in_axis(
pooled_roi_feat[:, :, 0:3],
-(rois.view(-1, rois.shape[-1])[:, 6]),
axis=2)
pooled_roi_feat[pooled_empty_flag.view(-1) > 0] = 0

return pooled_roi_feat
3 changes: 3 additions & 0 deletions mmdet3d/ops/roipoint_pool3d/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .roipoint_pool3d import RoIPointPool3d

__all__ = ['RoIPointPool3d']
71 changes: 71 additions & 0 deletions mmdet3d/ops/roipoint_pool3d/roipoint_pool3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from torch import nn as nn
from torch.autograd import Function

from . import roipoint_pool3d_ext


class RoIPointPool3d(nn.Module):

def __init__(self, num_sampled_points=512):
super().__init__()
"""
Args:
num_sampled_points (int): Number of samples in each roi
"""
self.num_sampled_points = num_sampled_points

def forward(self, points, point_features, boxes3d):
"""
Args:
points (torch.Tensor): Input points whose shape is BxNx3
point_features: (B, N, C)
boxes3d: (B, M, 7), [x, y, z, dx, dy, dz, heading]
Returns:
torch.Tensor: (B, M, 512, 3 + C) pooled_features
torch.Tensor: (B, M) pooled_empty_flag
"""
return RoIPointPool3dFunction.apply(points, point_features, boxes3d,
self.num_sampled_points)


class RoIPointPool3dFunction(Function):

@staticmethod
def forward(ctx, points, point_features, boxes3d, num_sampled_points=512):
"""
Args:
points (torch.Tensor): Input points whose shape is (B, N, 3)
point_features (torch.Tensor): Input points features shape is \
(B, N, C)
boxes3d (torch.Tensor): Input bounding boxes whose shape is \
(B, M, 7)
num_sampled_points (int): the num of sampled points
Returns:
torch.Tensor: (B, M, 512, 3 + C) pooled_features
torch.Tensor: (B, M) pooled_empty_flag
"""
assert points.shape.__len__() == 3 and points.shape[2] == 3
batch_size, boxes_num, feature_len = points.shape[0], boxes3d.shape[
1], point_features.shape[2]
pooled_boxes3d = boxes3d.view(batch_size, -1, 7)
pooled_features = point_features.new_zeros(
(batch_size, boxes_num, num_sampled_points, 3 + feature_len))
pooled_empty_flag = point_features.new_zeros(
(batch_size, boxes_num)).int()

roipoint_pool3d_ext.forward(points.contiguous(),
pooled_boxes3d.contiguous(),
point_features.contiguous(),
pooled_features, pooled_empty_flag)

return pooled_features, pooled_empty_flag

@staticmethod
def backward(ctx, grad_out):
raise NotImplementedError


if __name__ == '__main__':
pass
66 changes: 66 additions & 0 deletions mmdet3d/ops/roipoint_pool3d/src/roipoint_pool3d.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
Modified for
https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/roipoint_pool3d/src/roipoint_pool3d_kernel.cu
Point cloud feature pooling
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/
#include <torch/serialize/tensor.h>
#include <torch/extension.h>

#define CHECK_CUDA(x) do { \
if (!x.type().is_cuda()) { \
fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
exit(-1); \
} \
} while (0)
#define CHECK_CONTIGUOUS(x) do { \
if (!x.is_contiguous()) { \
fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \
exit(-1); \
} \
} while (0)
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)


void roipool3dLauncher(int batch_size, int pts_num, int boxes_num, int feature_in_len, int sampled_pts_num,
const float *xyz, const float *boxes3d, const float *pts_feature, float *pooled_features, int *pooled_empty_flag);


int roipool3d_gpu(at::Tensor xyz, at::Tensor boxes3d, at::Tensor pts_feature, at::Tensor pooled_features, at::Tensor pooled_empty_flag){
// params xyz: (B, N, 3)
// params boxes3d: (B, M, 7)
// params pts_feature: (B, N, C)
// params pooled_features: (B, M, 512, 3+C)
// params pooled_empty_flag: (B, M)
CHECK_INPUT(xyz);
CHECK_INPUT(boxes3d);
CHECK_INPUT(pts_feature);
CHECK_INPUT(pooled_features);
CHECK_INPUT(pooled_empty_flag);

int batch_size = xyz.size(0);
int pts_num = xyz.size(1);
int boxes_num = boxes3d.size(1);
int feature_in_len = pts_feature.size(2);
int sampled_pts_num = pooled_features.size(2);


const float * xyz_data = xyz.data<float>();
const float * boxes3d_data = boxes3d.data<float>();
const float * pts_feature_data = pts_feature.data<float>();
float * pooled_features_data = pooled_features.data<float>();
int * pooled_empty_flag_data = pooled_empty_flag.data<int>();

roipool3dLauncher(batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num,
xyz_data, boxes3d_data, pts_feature_data, pooled_features_data, pooled_empty_flag_data);



return 1;
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &roipool3d_gpu, "roipool3d forward (CUDA)");
}
168 changes: 168 additions & 0 deletions mmdet3d/ops/roipoint_pool3d/src/roipoint_pool3d_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
Modified from
https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roipoint_pool3d/src/roipoint_pool3d_kernel.cu
Point cloud feature pooling
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/

#include <math.h>
#include <stdio.h>

#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
// #define DEBUG

__device__ inline void lidar_to_local_coords(float shift_x, float shift_y,
float rz, float &local_x,
float &local_y) {
float cosa = cos(-rz), sina = sin(-rz);
local_x = shift_x * cosa + shift_y * (-sina);
local_y = shift_x * sina + shift_y * cosa;
}

__device__ inline int check_pt_in_box3d(const float *pt, const float *box3d,
float &local_x, float &local_y) {
// param pt: (x, y, z)
// param box3d: (cx, cy, cz, dx, dy, dz, rz) in LiDAR coordinate, cz in the
// bottom center
float x = pt[0], y = pt[1], z = pt[2];
float cx = box3d[0], cy = box3d[1], cz = box3d[2];
float dx = box3d[3], dy = box3d[4], dz = box3d[5], rz = box3d[6];
cz += dz / 2.0; // shift to the center since cz in box3d is the bottom center

if (fabsf(z - cz) > dz / 2.0) return 0;
lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);
float in_flag = (local_x > -dx / 2.0) & (local_x < dx / 2.0) &
(local_y > -dy / 2.0) & (local_y < dy / 2.0);
return in_flag;
}

__global__ void assign_pts_to_box3d(int batch_size, int pts_num, int boxes_num, const float *xyz, const float *boxes3d, int *pts_assign){
// params xyz: (B, N, 3)
// params boxes3d: (B, M, 7)
// params pts_assign: (B, N, M): idx of the corresponding box3d, -1 means background points
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
int box_idx = blockIdx.y;
int bs_idx = blockIdx.z;

if (pt_idx >= pts_num || box_idx >= boxes_num || bs_idx >= batch_size){
return;
}
int assign_idx = bs_idx * pts_num * boxes_num + pt_idx * boxes_num + box_idx;
pts_assign[assign_idx] = 0;

int box_offset = bs_idx * boxes_num * 7 + box_idx * 7;
int pt_offset = bs_idx * pts_num * 3 + pt_idx * 3;


float local_x = 0, local_y = 0;
int cur_in_flag = check_pt_in_box3d(xyz + pt_offset, boxes3d + box_offset, local_x, local_y);
pts_assign[assign_idx] = cur_in_flag;
// printf("bs=%d, pt=%d, in=%d\n", bs_idx, pt_idx, pts_assign[bs_idx * pts_num + pt_idx]);
}


__global__ void get_pooled_idx(int batch_size, int pts_num, int boxes_num, int sampled_pts_num,
const int *pts_assign, int *pts_idx, int *pooled_empty_flag){
// params xyz: (B, N, 3)
// params pts_feature: (B, N, C)
// params pts_assign: (B, N)
// params pts_idx: (B, M, 512)
// params pooled_empty_flag: (B, M)

int boxes_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (boxes_idx >= boxes_num){
return;
}

int bs_idx = blockIdx.y;

int cnt = 0;
for (int k = 0; k < pts_num; k++){
if (pts_assign[bs_idx * pts_num * boxes_num + k * boxes_num + boxes_idx]){
if (cnt < sampled_pts_num){
pts_idx[bs_idx * boxes_num * sampled_pts_num + boxes_idx * sampled_pts_num + cnt] = k;
cnt++;
}
else break;
}
}

if (cnt == 0){
pooled_empty_flag[bs_idx * boxes_num + boxes_idx] = 1;
}
else if (cnt < sampled_pts_num){
// duplicate same points for sampling
for (int k = cnt; k < sampled_pts_num; k++){
int duplicate_idx = k % cnt;
int base_offset = bs_idx * boxes_num * sampled_pts_num + boxes_idx * sampled_pts_num;
pts_idx[base_offset + k] = pts_idx[base_offset + duplicate_idx];
}
}
}


__global__ void roipool3d_forward(int batch_size, int pts_num, int boxes_num, int feature_in_len, int sampled_pts_num,
const float *xyz, const int *pts_idx, const float *pts_feature,
float *pooled_features, int *pooled_empty_flag){
// params xyz: (B, N, 3)
// params pts_idx: (B, M, 512)
// params pts_feature: (B, N, C)
// params pooled_features: (B, M, 512, 3+C)
// params pooled_empty_flag: (B, M)

int sample_pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
int box_idx = blockIdx.y;
int bs_idx = blockIdx.z;

if (sample_pt_idx >= sampled_pts_num || box_idx >= boxes_num || bs_idx >= batch_size){
return;
}

if (pooled_empty_flag[bs_idx * boxes_num + box_idx]){
return;
}

int temp_idx = bs_idx * boxes_num * sampled_pts_num + box_idx * sampled_pts_num + sample_pt_idx;
int src_pt_idx = pts_idx[temp_idx];
int dst_feature_offset = temp_idx * (3 + feature_in_len);

for (int j = 0; j < 3; j++)
pooled_features[dst_feature_offset + j] = xyz[bs_idx * pts_num * 3 + src_pt_idx * 3 + j];

int src_feature_offset = bs_idx * pts_num * feature_in_len + src_pt_idx * feature_in_len;
for (int j = 0; j < feature_in_len; j++)
pooled_features[dst_feature_offset + 3 + j] = pts_feature[src_feature_offset + j];
}


void roipool3dLauncher(int batch_size, int pts_num, int boxes_num, int feature_in_len, int sampled_pts_num,
const float *xyz, const float *boxes3d, const float *pts_feature, float *pooled_features, int *pooled_empty_flag){

// printf("batch_size=%d, pts_num=%d, boxes_num=%d\n", batch_size, pts_num, boxes_num);
int *pts_assign = NULL;
cudaMalloc(&pts_assign, batch_size * pts_num * boxes_num * sizeof(int)); // (batch_size, N, M)
// cudaMemset(&pts_assign, -1, batch_size * pts_num * boxes_num * sizeof(int));

dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), boxes_num, batch_size); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);
assign_pts_to_box3d<<<blocks, threads>>>(batch_size, pts_num, boxes_num, xyz, boxes3d, pts_assign);

int *pts_idx = NULL;
cudaMalloc(&pts_idx, batch_size * boxes_num * sampled_pts_num * sizeof(int)); // (batch_size, M, sampled_pts_num)

dim3 blocks2(DIVUP(boxes_num, THREADS_PER_BLOCK), batch_size); // blockIdx.x(col), blockIdx.y(row)
get_pooled_idx<<<blocks2, threads>>>(batch_size, pts_num, boxes_num, sampled_pts_num, pts_assign, pts_idx, pooled_empty_flag);

dim3 blocks_pool(DIVUP(sampled_pts_num, THREADS_PER_BLOCK), boxes_num, batch_size);
roipool3d_forward<<<blocks_pool, threads>>>(batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num,
xyz, pts_idx, pts_feature, pooled_features, pooled_empty_flag);

cudaFree(pts_assign);
cudaFree(pts_idx);

#ifdef DEBUG
cudaDeviceSynchronize(); // for using printf in kernel function
#endif
}
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,11 @@ def add_mim_extention():
'src/roiaware_pool3d_kernel.cu',
'src/points_in_boxes_cuda.cu',
]),
make_cuda_ext(
name='roipoint_pool3d_ext',
module='mmdet3d.ops.roipoint_pool3d',
sources=['src/roipoint_pool3d.cpp'],
sources_cuda=['src/roipoint_pool3d_kernel.cu']),
make_cuda_ext(
name='ball_query_ext',
module='mmdet3d.ops.ball_query',
Expand Down
Loading

0 comments on commit 4f36084

Please sign in to comment.