diff --git a/docs/understand_mmcv/ops.md b/docs/understand_mmcv/ops.md index cd415790e5..df533c3686 100644 --- a/docs/understand_mmcv/ops.md +++ b/docs/understand_mmcv/ops.md @@ -19,6 +19,7 @@ We implement common CUDA ops used in detection, segmentation, etc. - MaskedConv - NMS - PSAMask +- RoIPointPool3d - RoIPool - RoIAlign - SimpleRoIAlign diff --git a/docs_zh_CN/understand_mmcv/ops.md b/docs_zh_CN/understand_mmcv/ops.md index 74a3871591..4a3d99f8c6 100644 --- a/docs_zh_CN/understand_mmcv/ops.md +++ b/docs_zh_CN/understand_mmcv/ops.md @@ -19,6 +19,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子 - MaskedConv - NMS - PSAMask +- RoIPointPool3d - RoIPool - RoIAlign - SimpleRoIAlign diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py index ada85defda..86da62d0aa 100644 --- a/mmcv/ops/__init__.py +++ b/mmcv/ops/__init__.py @@ -39,6 +39,7 @@ from .roi_align import RoIAlign, roi_align from .roi_align_rotated import RoIAlignRotated, roi_align_rotated from .roi_pool import RoIPool, roi_pool +from .roipoint_pool3d import RoIPointPool3d from .saconv import SAConv2d from .sync_bn import SyncBatchNorm from .three_interpolate import three_interpolate @@ -60,10 +61,10 @@ 'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask', 'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign', 'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk', - 'box_iou_rotated', 'nms_rotated', 'knn', 'ball_query', 'upfirdn2d', - 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu', 'RoIAlignRotated', - 'roi_align_rotated', 'pixel_group', 'contour_expand', 'three_nn', - 'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign', - 'border_align', 'gather_points', 'furthest_point_sample', + 'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query', + 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu', + 'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand', + 'three_nn', 'three_interpolate', 'MultiScaleDeformableAttention', + 'BorderAlign', 'border_align', 'gather_points', 'furthest_point_sample', 'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation' ] diff --git a/mmcv/ops/csrc/common/cuda/roipoint_pool3d_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/roipoint_pool3d_cuda_kernel.cuh new file mode 100644 index 0000000000..7597719e69 --- /dev/null +++ b/mmcv/ops/csrc/common/cuda/roipoint_pool3d_cuda_kernel.cuh @@ -0,0 +1,144 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef ROIPOINT_POOL3D_CUDA_KERNEL_CUH +#define ROIPOINT_POOL3D_CUDA_KERNEL_CUH + +#ifdef MMCV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else +#include "pytorch_cuda_helper.hpp" +#endif + +template +__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz, + T &local_x, T &local_y) { + T cosa = cos(-rz), sina = sin(-rz); + local_x = shift_x * cosa + shift_y * (-sina); + local_y = shift_x * sina + shift_y * cosa; +} + +template +__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x, + T &local_y) { + // param pt: (x, y, z) + // param box3d: (cx, cy, cz, dx, dy, dz, rz) in LiDAR coordinate, cz in the + // bottom center + T x = pt[0], y = pt[1], z = pt[2]; + T cx = box3d[0], cy = box3d[1], cz = box3d[2]; + T dx = box3d[3], dy = box3d[4], dz = box3d[5], rz = box3d[6]; + cz += dz / 2.0; // shift to the center since cz in box3d is the bottom center + + if (fabsf(z - cz) > dz / 2.0) return 0; + lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y); + T in_flag = (local_x > -dx / 2.0) & (local_x < dx / 2.0) & + (local_y > -dy / 2.0) & (local_y < dy / 2.0); + return in_flag; +} + +template +__global__ void assign_pts_to_box3d(int batch_size, int pts_num, int boxes_num, + const T *xyz, const T *boxes3d, + int *pts_assign) { + // params xyz: (B, N, 3) + // params boxes3d: (B, M, 7) + // params pts_assign: (B, N, M): idx of the corresponding box3d, -1 means + // background points + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + int box_idx = blockIdx.y; + int bs_idx = blockIdx.z; + + if (pt_idx >= pts_num || box_idx >= boxes_num || bs_idx >= batch_size) { + return; + } + int assign_idx = bs_idx * pts_num * boxes_num + pt_idx * boxes_num + box_idx; + pts_assign[assign_idx] = 0; + + int box_offset = bs_idx * boxes_num * 7 + box_idx * 7; + int pt_offset = bs_idx * pts_num * 3 + pt_idx * 3; + + T local_x = 0, local_y = 0; + int cur_in_flag = check_pt_in_box3d(xyz + pt_offset, boxes3d + box_offset, + local_x, local_y); + pts_assign[assign_idx] = cur_in_flag; +} + +__global__ void get_pooled_idx(int batch_size, int pts_num, int boxes_num, + int sampled_pts_num, const int *pts_assign, + int *pts_idx, int *pooled_empty_flag) { + // params xyz: (B, N, 3) + // params pts_feature: (B, N, C) + // params pts_assign: (B, N) + // params pts_idx: (B, M, 512) + // params pooled_empty_flag: (B, M) + + int boxes_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (boxes_idx >= boxes_num) { + return; + } + + int bs_idx = blockIdx.y; + + int cnt = 0; + for (int k = 0; k < pts_num; k++) { + if (pts_assign[bs_idx * pts_num * boxes_num + k * boxes_num + boxes_idx]) { + if (cnt < sampled_pts_num) { + pts_idx[bs_idx * boxes_num * sampled_pts_num + + boxes_idx * sampled_pts_num + cnt] = k; + cnt++; + } else + break; + } + } + + if (cnt == 0) { + pooled_empty_flag[bs_idx * boxes_num + boxes_idx] = 1; + } else if (cnt < sampled_pts_num) { + // duplicate same points for sampling + for (int k = cnt; k < sampled_pts_num; k++) { + int duplicate_idx = k % cnt; + int base_offset = + bs_idx * boxes_num * sampled_pts_num + boxes_idx * sampled_pts_num; + pts_idx[base_offset + k] = pts_idx[base_offset + duplicate_idx]; + } + } +} + +template +__global__ void roipoint_pool3d_forward( + int batch_size, int pts_num, int boxes_num, int feature_in_len, + int sampled_pts_num, const T *xyz, const int *pts_idx, const T *pts_feature, + T *pooled_features, int *pooled_empty_flag) { + // params xyz: (B, N, 3) + // params pts_idx: (B, M, 512) + // params pts_feature: (B, N, C) + // params pooled_features: (B, M, 512, 3+C) + // params pooled_empty_flag: (B, M) + + int sample_pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + int box_idx = blockIdx.y; + int bs_idx = blockIdx.z; + + if (sample_pt_idx >= sampled_pts_num || box_idx >= boxes_num || + bs_idx >= batch_size) { + return; + } + + if (pooled_empty_flag[bs_idx * boxes_num + box_idx]) { + return; + } + + int temp_idx = bs_idx * boxes_num * sampled_pts_num + + box_idx * sampled_pts_num + sample_pt_idx; + int src_pt_idx = pts_idx[temp_idx]; + int dst_feature_offset = temp_idx * (3 + feature_in_len); + + for (int j = 0; j < 3; j++) + pooled_features[dst_feature_offset + j] = + xyz[bs_idx * pts_num * 3 + src_pt_idx * 3 + j]; + + int src_feature_offset = + bs_idx * pts_num * feature_in_len + src_pt_idx * feature_in_len; + memcpy(pooled_features + dst_feature_offset + 3, + pts_feature + src_feature_offset, feature_in_len * sizeof(T)); +} + +#endif // ROIPOINT_POOL3D_CUDA_KERNEL_CUH diff --git a/mmcv/ops/csrc/pytorch/cuda/roipoint_pool3d_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/roipoint_pool3d_cuda.cu new file mode 100644 index 0000000000..49c003f909 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cuda/roipoint_pool3d_cuda.cu @@ -0,0 +1,60 @@ +/* +Modified from +https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/roipoint_pool3d/src/roipoint_pool3d_kernel.cu +Point cloud feature pooling +Written by Shaoshuai Shi +All Rights Reserved 2018. +*/ + +#include +#include + +#include "pytorch_cuda_helper.hpp" +#include "roipoint_pool3d_cuda_kernel.cuh" + +void RoIPointPool3dForwardCUDAKernelLauncher( + int batch_size, int pts_num, int boxes_num, int feature_in_len, + int sampled_pts_num, const Tensor xyz, const Tensor boxes3d, + const Tensor pts_feature, Tensor pooled_features, + Tensor pooled_empty_flag) { + Tensor pts_assign = at::empty({batch_size, pts_num, boxes_num}, + boxes3d.options().dtype(at::kInt)); + + at::cuda::CUDAGuard device_guard(xyz.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // blockIdx.x(col), blockIdx.y(row) + dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), boxes_num, batch_size); + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + xyz.scalar_type(), "assign_pts_to_box3d", [&] { + assign_pts_to_box3d<<>>( + batch_size, pts_num, boxes_num, xyz.data_ptr(), + boxes3d.data_ptr(), pts_assign.data_ptr()); + }); + + Tensor pts_idx = at::empty({batch_size, boxes_num, sampled_pts_num}, + boxes3d.options().dtype(at::kInt)); + + // blockIdx.x(col), blockIdx.y(row) + dim3 blocks2(DIVUP(boxes_num, THREADS_PER_BLOCK), batch_size); + + get_pooled_idx<<>>( + batch_size, pts_num, boxes_num, sampled_pts_num, + pts_assign.data_ptr(), pts_idx.data_ptr(), + pooled_empty_flag.data_ptr()); + + dim3 blocks_pool(DIVUP(sampled_pts_num, THREADS_PER_BLOCK), boxes_num, + batch_size); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + xyz.scalar_type(), "roipoint_pool3d_forward", [&] { + roipoint_pool3d_forward<<>>( + batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, + xyz.data_ptr(), pts_idx.data_ptr(), + pts_feature.data_ptr(), + pooled_features.data_ptr(), + pooled_empty_flag.data_ptr()); + }); +} diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index b35ccb3b76..385cf1a608 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -65,6 +65,9 @@ void deform_roi_pool_backward(Tensor grad_output, Tensor input, Tensor rois, int pooled_width, float spatial_scale, int sampling_ratio, float gamma); +void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature, + Tensor pooled_features, Tensor pooled_empty_flag); + void gather_points_forward(int b, int c, int n, int npoints, Tensor points_tensor, Tensor idx_tensor, Tensor out_tensor); @@ -364,6 +367,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("grad_offset"), py::arg("pooled_height"), py::arg("pooled_width"), py::arg("spatial_scale"), py::arg("sampling_ratio"), py::arg("gamma")); + m.def("roipoint_pool3d_forward", &roipoint_pool3d_forward, + "roipoint_pool3d_forward", py::arg("xyz"), py::arg("boxes3d"), + py::arg("pts_feature"), py::arg("pooled_features"), + py::arg("pooled_empty_flag")); m.def("sigmoid_focal_loss_forward", &sigmoid_focal_loss_forward, "sigmoid_focal_loss_forward ", py::arg("input"), py::arg("target"), py::arg("weight"), py::arg("output"), py::arg("gamma"), diff --git a/mmcv/ops/csrc/pytorch/roipoint_pool3d.cpp b/mmcv/ops/csrc/pytorch/roipoint_pool3d.cpp new file mode 100644 index 0000000000..e9b5054e70 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/roipoint_pool3d.cpp @@ -0,0 +1,60 @@ +/* +Modified from +https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/roipoint_pool3d/src/roipoint_pool3d.cpp +Point cloud feature pooling +Written by Shaoshuai Shi +All Rights Reserved 2018. +*/ + +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void RoIPointPool3dForwardCUDAKernelLauncher( + int batch_size, int pts_num, int boxes_num, int feature_in_len, + int sampled_pts_num, const Tensor xyz, const Tensor boxes3d, + const Tensor pts_feature, Tensor pooled_features, Tensor pooled_empty_flag); + +void roipoint_pool3d_forward_cuda(int batch_size, int pts_num, int boxes_num, + int feature_in_len, int sampled_pts_num, + const Tensor xyz, const Tensor boxes3d, + const Tensor pts_feature, + Tensor pooled_features, + Tensor pooled_empty_flag) { + RoIPointPool3dForwardCUDAKernelLauncher( + batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, xyz, + boxes3d, pts_feature, pooled_features, pooled_empty_flag); +}; +#endif + +void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature, + Tensor pooled_features, Tensor pooled_empty_flag) { + // params xyz: (B, N, 3) + // params boxes3d: (B, M, 7) + // params pts_feature: (B, N, C) + // params pooled_features: (B, M, 512, 3+C) + // params pooled_empty_flag: (B, M) + + if (xyz.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(xyz); + CHECK_CUDA_INPUT(boxes3d); + CHECK_CUDA_INPUT(pts_feature); + CHECK_CUDA_INPUT(pooled_features); + CHECK_CUDA_INPUT(pooled_empty_flag); + + int batch_size = xyz.size(0); + int pts_num = xyz.size(1); + int boxes_num = boxes3d.size(1); + int feature_in_len = pts_feature.size(2); + int sampled_pts_num = pooled_features.size(2); + + roipoint_pool3d_forward_cuda(batch_size, pts_num, boxes_num, feature_in_len, + sampled_pts_num, xyz, boxes3d, pts_feature, + pooled_features, pooled_empty_flag); +#else + AT_ERROR("roipoint_pool3d is not compiled with GPU support"); +#endif + } else { + AT_ERROR("roipoint_pool3d is not implemented on CPU"); + } +} diff --git a/mmcv/ops/roipoint_pool3d.py b/mmcv/ops/roipoint_pool3d.py new file mode 100644 index 0000000000..0a21412c07 --- /dev/null +++ b/mmcv/ops/roipoint_pool3d.py @@ -0,0 +1,77 @@ +from torch import nn as nn +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', ['roipoint_pool3d_forward']) + + +class RoIPointPool3d(nn.Module): + """Encode the geometry-specific features of each 3D proposal. + + Please refer to `Paper of PartA2 `_ + for more details. + + Args: + num_sampled_points (int, optional): Number of samples in each roi. + Default: 512. + """ + + def __init__(self, num_sampled_points=512): + super().__init__() + self.num_sampled_points = num_sampled_points + + def forward(self, points, point_features, boxes3d): + """ + Args: + points (torch.Tensor): Input points whose shape is (B, N, C). + point_features (torch.Tensor): Features of input points whose shape + is (B, N, C). + boxes3d (B, M, 7), Input bounding boxes whose shape is (B, M, 7). + + Returns: + pooled_features (torch.Tensor): The output pooled features whose + shape is (B, M, 512, 3 + C). + pooled_empty_flag (torch.Tensor): Empty flag whose shape is (B, M). + """ + return RoIPointPool3dFunction.apply(points, point_features, boxes3d, + self.num_sampled_points) + + +class RoIPointPool3dFunction(Function): + + @staticmethod + def forward(ctx, points, point_features, boxes3d, num_sampled_points=512): + """ + Args: + points (torch.Tensor): Input points whose shape is (B, N, C). + point_features (torch.Tensor): Features of input points whose shape + is (B, N, C). + boxes3d (B, M, 7), Input bounding boxes whose shape is (B, M, 7). + num_sampled_points (int, optional): The num of sampled points. + Default: 512. + + Returns: + pooled_features (torch.Tensor): The output pooled features whose + shape is (B, M, 512, 3 + C). + pooled_empty_flag (torch.Tensor): Empty flag whose shape is (B, M). + """ + assert len(points.shape) == 3 and points.shape[2] == 3 + batch_size, boxes_num, feature_len = points.shape[0], boxes3d.shape[ + 1], point_features.shape[2] + pooled_boxes3d = boxes3d.view(batch_size, -1, 7) + pooled_features = point_features.new_zeros( + (batch_size, boxes_num, num_sampled_points, 3 + feature_len)) + pooled_empty_flag = point_features.new_zeros( + (batch_size, boxes_num)).int() + + ext_module.roipoint_pool3d_forward(points.contiguous(), + pooled_boxes3d.contiguous(), + point_features.contiguous(), + pooled_features, pooled_empty_flag) + + return pooled_features, pooled_empty_flag + + @staticmethod + def backward(ctx, grad_out): + raise NotImplementedError diff --git a/tests/test_ops/test_roipoint_pool3d.py b/tests/test_ops/test_roipoint_pool3d.py new file mode 100644 index 0000000000..7db3885d71 --- /dev/null +++ b/tests/test_ops/test_roipoint_pool3d.py @@ -0,0 +1,35 @@ +import pytest +import torch + +from mmcv.ops import RoIPointPool3d + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_gather_points(): + feats = torch.tensor( + [[1, 2, 3.3], [1.2, 2.5, 3.0], [0.8, 2.1, 3.5], [1.6, 2.6, 3.6], + [0.8, 1.2, 3.9], [-9.2, 21.0, 18.2], [3.8, 7.9, 6.3], + [4.7, 3.5, -12.2], [3.8, 7.6, -2], [-10.6, -12.9, -20], [-16, -18, 9], + [-21.3, -52, -5], [0, 0, 0], [6, 7, 8], [-2, -3, -4]], + dtype=torch.float32).unsqueeze(0).cuda() + points = feats.clone() + rois = torch.tensor([[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.3], + [-10.0, 23.0, 16.0, 10, 20, 20, 0.5]]], + dtype=torch.float32).cuda() + + roipoint_pool3d = RoIPointPool3d(num_sampled_points=4) + roi_feat, empty_flag = roipoint_pool3d(feats, points, rois) + expected_roi_feat = torch.tensor([[[[1, 2, 3.3, 1, 2, 3.3], + [1.2, 2.5, 3, 1.2, 2.5, 3], + [0.8, 2.1, 3.5, 0.8, 2.1, 3.5], + [1.6, 2.6, 3.6, 1.6, 2.6, 3.6]], + [[-9.2, 21, 18.2, -9.2, 21, 18.2], + [-9.2, 21, 18.2, -9.2, 21, 18.2], + [-9.2, 21, 18.2, -9.2, 21, 18.2], + [-9.2, 21, 18.2, -9.2, 21, + 18.2]]]]).cuda() + expected_empty_flag = torch.tensor([[0, 0]]).int().cuda() + + assert torch.allclose(roi_feat, expected_roi_feat) + assert torch.allclose(empty_flag, expected_empty_flag)