From a5470c3e7ef7f244871a1aaa1aa1c114b8f4dd0b Mon Sep 17 00:00:00 2001 From: "hudingchang.vendor" Date: Mon, 27 Sep 2021 03:07:26 +0800 Subject: [PATCH 01/12] add ops (voxel) in mmdet3d --- docs/understand_mmcv/ops.md | 2 + docs_zh_CN/understand_mmcv/ops.md | 2 + mmcv/ops/__init__.py | 5 +- .../cuda/scatter_points_cuda_kernel.cuh | 171 +++++++++++++++ .../common/cuda/voxelization_cuda_kernel.cuh | 175 ++++++++++++++++ .../csrc/pytorch/cuda/scatter_points_cuda.cu | 128 ++++++++++++ .../csrc/pytorch/cuda/voxelization_cuda.cu | 194 ++++++++++++++++++ mmcv/ops/csrc/pytorch/pybind.cpp | 40 ++++ mmcv/ops/csrc/pytorch/scatter_points.cpp | 81 ++++++++ mmcv/ops/csrc/pytorch/voxelization.cpp | 85 ++++++++ mmcv/ops/csrc/pytorch/voxelization_cpu.cpp | 157 ++++++++++++++ mmcv/ops/scatter_points.py | 110 ++++++++++ mmcv/ops/voxelization.py | 126 ++++++++++++ tests/test_ops/test_scatter_points.py | 93 +++++++++ 14 files changed, 1368 insertions(+), 1 deletion(-) create mode 100644 mmcv/ops/csrc/common/cuda/scatter_points_cuda_kernel.cuh create mode 100644 mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh create mode 100644 mmcv/ops/csrc/pytorch/cuda/scatter_points_cuda.cu create mode 100644 mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu create mode 100644 mmcv/ops/csrc/pytorch/scatter_points.cpp create mode 100644 mmcv/ops/csrc/pytorch/voxelization.cpp create mode 100644 mmcv/ops/csrc/pytorch/voxelization_cpu.cpp create mode 100644 mmcv/ops/scatter_points.py create mode 100644 mmcv/ops/voxelization.py create mode 100644 tests/test_ops/test_scatter_points.py diff --git a/docs/understand_mmcv/ops.md b/docs/understand_mmcv/ops.md index e04f32c7cd..1d866ea128 100644 --- a/docs/understand_mmcv/ops.md +++ b/docs/understand_mmcv/ops.md @@ -10,6 +10,7 @@ We implement common CUDA ops used in detection, segmentation, etc. - CornerPool - Deformable Convolution v1/v2 - Deformable RoIPool +- DynamicScatter - GeneralizedAttention - MaskedConv - NMS @@ -21,5 +22,6 @@ We implement common CUDA ops used in detection, segmentation, etc. 
- SoftmaxFocalLoss - SoftNMS - Synchronized BatchNorm +- Voxelization - Weight standardization - Correlation diff --git a/docs_zh_CN/understand_mmcv/ops.md b/docs_zh_CN/understand_mmcv/ops.md index 0d7c17fd7f..e114ffdcb9 100644 --- a/docs_zh_CN/understand_mmcv/ops.md +++ b/docs_zh_CN/understand_mmcv/ops.md @@ -10,6 +10,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子 - CornerPool - Deformable Convolution v1/v2 - Deformable RoIPool +- DynamicScatter - GeneralizedAttention - MaskedConv - NMS @@ -21,5 +22,6 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子 - SoftmaxFocalLoss - SoftNMS - Synchronized BatchNorm +- Voxelization - Weight standardization - Correlation diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py index 3a4a838e7f..d2f2db528c 100644 --- a/mmcv/ops/__init__.py +++ b/mmcv/ops/__init__.py @@ -34,9 +34,11 @@ from .roi_align_rotated import RoIAlignRotated, roi_align_rotated from .roi_pool import RoIPool, roi_pool from .saconv import SAConv2d +from .scatter_points import DynamicScatter, dynamic_scatter from .sync_bn import SyncBatchNorm from .tin_shift import TINShift, tin_shift from .upfirdn2d import upfirdn2d +from .voxelize import Voxelization, voxelization __all__ = [ 'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe', @@ -54,6 +56,7 @@ 'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated', 'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu', 'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand', - 'MultiScaleDeformableAttention', 'BorderAlign', 'border_align', + 'MultiScaleDeformableAttention', 'Voxelization', 'voxelization', + 'dynamic_scatter', 'DynamicScatter', 'BorderAlign', 'border_align', 'Correlation' ] diff --git a/mmcv/ops/csrc/common/cuda/scatter_points_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/scatter_points_cuda_kernel.cuh new file mode 100644 index 0000000000..7deade0547 --- /dev/null +++ b/mmcv/ops/csrc/common/cuda/scatter_points_cuda_kernel.cuh @@ -0,0 +1,171 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef SCATTER_POINTS_CUDA_KERNEL_CUH +#define SCATTER_POINTS_CUDA_KERNEL_CUH + +#ifdef MMCV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else +#include "pytorch_cuda_helper.hpp" +#endif + +typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t; + +__device__ __forceinline__ static void reduceMax(float *address, float val) { + int *address_as_i = reinterpret_cast(address); + int old = *address_as_i, assumed; + do { + assumed = old; + old = atomicCAS(address_as_i, assumed, + __float_as_int(fmaxf(val, __int_as_float(assumed)))); + } while (assumed != old || __int_as_float(old) < val); +} + +__device__ __forceinline__ static void reduceMax(double *address, double val) { + unsigned long long *address_as_ull = + reinterpret_cast(address); + unsigned long long old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS( + address_as_ull, assumed, + __double_as_longlong(fmax(val, __longlong_as_double(assumed)))); + } while (assumed != old || __longlong_as_double(old) < val); +} + +// get rid of meaningless warnings when compiling host code +#ifdef __CUDA_ARCH__ +__device__ __forceinline__ static void reduceAdd(float *address, float val) { +#if (__CUDA_ARCH__ < 200) +#warning \ + "compute capability lower than 2.x. 
fall back to use CAS version of atomicAdd for float32" + int *address_as_i = reinterpret_cast(address); + int old = *address_as_i, assumed; + do { + assumed = old; + old = atomicCAS(address_as_i, assumed, + __float_as_int(val + __int_as_float(assumed))); + } while (assumed != old); +#else + atomicAdd(address, val); +#endif +} + +__device__ __forceinline__ static void reduceAdd(double *address, double val) { +#if (__CUDA_ARCH__ < 600) +#warning \ + "compute capability lower than 6.x. fall back to use CAS version of atomicAdd for float64" + unsigned long long *address_as_ull = + reinterpret_cast(address); + unsigned long long old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __double_as_longlong(val + __longlong_as_double(assumed))); + } while (assumed != old); +#else + atomicAdd(address, val); +#endif +} +#endif + +template +__global__ void feats_reduce_kernel( + const T *feats, const int32_t *coors_map, + T *reduced_feats, // shall be 0 at initialization + const int num_input, const int num_feats, const reduce_t reduce_type) { + for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < num_input; + x += gridDim.x * blockDim.x) { + int32_t reduce_to = coors_map[x]; + if (reduce_to == -1) continue; + + const T *feats_offset = feats + x * num_feats; + T *reduced_feats_offset = reduced_feats + reduce_to * num_feats; + if (reduce_type == reduce_t::MAX) { + for (int i = 0; i < num_feats; i++) { + reduceMax(&reduced_feats_offset[i], feats_offset[i]); + } + } else { + for (int i = 0; i < num_feats; i++) { + reduceAdd(&reduced_feats_offset[i], feats_offset[i]); + } + } + } +} + +template +__global__ void add_reduce_traceback_grad_kernel( + T *grad_feats, const T *grad_reduced_feats, const int32_t *coors_map, + const int32_t *reduce_count, const int num_input, const int num_feats, + const reduce_t reduce_type) { + for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < num_input; + x += gridDim.x * blockDim.x) { + int32_t reduce_to = coors_map[x]; + if (reduce_to == -1) { + continue; + } + + const int input_offset = x * num_feats; + T *grad_feats_offset = grad_feats + input_offset; + const int reduced_offset = reduce_to * num_feats; + const T *grad_reduced_feats_offset = grad_reduced_feats + reduced_offset; + + if (reduce_type == reduce_t::SUM) { + for (int i = 0; i < num_feats; i++) { + grad_feats_offset[i] = grad_reduced_feats_offset[i]; + } + } else if (reduce_type == reduce_t::MEAN) { + for (int i = 0; i < num_feats; i++) { + grad_feats_offset[i] = grad_reduced_feats_offset[i] / + static_cast(reduce_count[reduce_to]); + } + } + } +} + +template +__global__ void max_reduce_traceback_scatter_idx_kernel( + const T *feats, const T *reduced_feats, int32_t *reduce_from, + const int32_t *coors_map, const int num_input, const int num_feats) { + for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < num_input; + x += gridDim.x * blockDim.x) { + int32_t reduce_to = coors_map[x]; + + const int input_offset = x * num_feats; + const T *feats_offset = feats + input_offset; + + if (reduce_to == -1) { + continue; + } + + const int reduced_offset = reduce_to * num_feats; + const T *reduced_feats_offset = reduced_feats + reduced_offset; + int32_t *reduce_from_offset = reduce_from + reduced_offset; + + for (int i = 0; i < num_feats; i++) { + if (feats_offset[i] == reduced_feats_offset[i]) { + atomicMin(&reduce_from_offset[i], static_cast(x)); + } + } + } +} + +template +__global__ void max_reduce_scatter_grad_kernel(T *grad_feats, + const T *grad_reduced_feats, + 
const int32_t *reduce_from, + const int num_reduced, + const int num_feats) { + for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < num_reduced; + x += gridDim.x * blockDim.x) { + const int reduced_offset = x * num_feats; + const int32_t *scatter_to_offset = reduce_from + reduced_offset; + const T *grad_reduced_feats_offset = grad_reduced_feats + reduced_offset; + + for (int i = 0; i < num_feats; i++) { + grad_feats[scatter_to_offset[i] * num_feats + i] = + grad_reduced_feats_offset[i]; + } + } +} + +#endif // SCATTER_POINTS_CUDA_KERNEL_CUH diff --git a/mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh new file mode 100644 index 0000000000..638f327057 --- /dev/null +++ b/mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh @@ -0,0 +1,175 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef VOXELIZATION_CUDA_KERNEL_CUH +#define VOXELIZATION_CUDA_KERNEL_CUH + +#ifdef MMCV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else +#include "pytorch_cuda_helper.hpp" +#endif + +typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t; + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ + i += blockDim.x * gridDim.x) + +template +__global__ void dynamic_voxelize_kernel( + const T* points, T_int* coors, const float voxel_x, const float voxel_y, + const float voxel_z, const float coors_x_min, const float coors_y_min, + const float coors_z_min, const float coors_x_max, const float coors_y_max, + const float coors_z_max, const int grid_x, const int grid_y, + const int grid_z, const int num_points, const int num_features, + const int NDim) { + // const int index = blockIdx.x * threadsPerBlock + threadIdx.x; + CUDA_1D_KERNEL_LOOP(index, num_points) { + // To save some computation + auto points_offset = points + index * num_features; + auto coors_offset = coors + index * NDim; + int c_x = floor((points_offset[0] - coors_x_min) / voxel_x); + if (c_x < 0 || c_x >= grid_x) { + coors_offset[0] = -1; + return; + } + + int c_y = floor((points_offset[1] - coors_y_min) / voxel_y); + if (c_y < 0 || c_y >= grid_y) { + coors_offset[0] = -1; + coors_offset[1] = -1; + return; + } + + int c_z = floor((points_offset[2] - coors_z_min) / voxel_z); + if (c_z < 0 || c_z >= grid_z) { + coors_offset[0] = -1; + coors_offset[1] = -1; + coors_offset[2] = -1; + } else { + coors_offset[0] = c_z; + coors_offset[1] = c_y; + coors_offset[2] = c_x; + } + } +} + +template +__global__ void assign_point_to_voxel(const int nthreads, const T* points, + T_int* point_to_voxelidx, + T_int* coor_to_voxelidx, T* voxels, + const int max_points, + const int num_features, + const int num_points, const int NDim) { + CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) { + // const int index = blockIdx.x * threadsPerBlock + threadIdx.x; + int index = thread_idx / num_features; + + int num = point_to_voxelidx[index]; + int voxelidx = coor_to_voxelidx[index]; + if (num > -1 && voxelidx > -1) { + auto voxels_offset = + voxels + voxelidx * max_points * num_features + num * num_features; + + int k = thread_idx % num_features; + voxels_offset[k] = points[thread_idx]; + } + } +} + +template +__global__ void assign_voxel_coors(const int nthreads, T_int* coor, + T_int* point_to_voxelidx, + T_int* coor_to_voxelidx, T_int* voxel_coors, + const int num_points, const int NDim) { + CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) { + // const int index = blockIdx.x * threadsPerBlock + threadIdx.x; + // if (index >= num_points) return; + int index = thread_idx / 
NDim; + int num = point_to_voxelidx[index]; + int voxelidx = coor_to_voxelidx[index]; + if (num == 0 && voxelidx > -1) { + auto coors_offset = voxel_coors + voxelidx * NDim; + int k = thread_idx % NDim; + coors_offset[k] = coor[thread_idx]; + } + } +} + +template +__global__ void point_to_voxelidx_kernel(const T_int* coor, + T_int* point_to_voxelidx, + T_int* point_to_pointidx, + const int max_points, + const int max_voxels, + const int num_points, const int NDim) { + CUDA_1D_KERNEL_LOOP(index, num_points) { + auto coor_offset = coor + index * NDim; + // skip invalid points + if ((index >= num_points) || (coor_offset[0] == -1)) return; + + int num = 0; + int coor_x = coor_offset[0]; + int coor_y = coor_offset[1]; + int coor_z = coor_offset[2]; + // only calculate the coors before this coor[index] + for (int i = 0; i < index; ++i) { + auto prev_coor = coor + i * NDim; + if (prev_coor[0] == -1) continue; + + // Find all previous points that have the same coors + // if find the same coor, record it + if ((prev_coor[0] == coor_x) && (prev_coor[1] == coor_y) && + (prev_coor[2] == coor_z)) { + num++; + if (num == 1) { + // point to the same coor that first show up + point_to_pointidx[index] = i; + } else if (num >= max_points) { + // out of boundary + return; + } + } + } + if (num == 0) { + point_to_pointidx[index] = index; + } + if (num < max_points) { + point_to_voxelidx[index] = num; + } + } +} + +template +__global__ void determin_voxel_num( + // const T_int* coor, + T_int* num_points_per_voxel, T_int* point_to_voxelidx, + T_int* point_to_pointidx, T_int* coor_to_voxelidx, T_int* voxel_num, + const int max_points, const int max_voxels, const int num_points) { + // only calculate the coors before this coor[index] + for (int i = 0; i < num_points; ++i) { + // if (coor[i][0] == -1) + // continue; + int point_pos_in_voxel = point_to_voxelidx[i]; + // record voxel + if (point_pos_in_voxel == -1) { + // out of max_points or invalid point + continue; + } else if (point_pos_in_voxel == 0) { + // record new voxel + int voxelidx = voxel_num[0]; + if (voxel_num[0] >= max_voxels) continue; + voxel_num[0] += 1; + coor_to_voxelidx[i] = voxelidx; + num_points_per_voxel[voxelidx] = 1; + } else { + int point_idx = point_to_pointidx[i]; + int voxelidx = coor_to_voxelidx[point_idx]; + if (voxelidx != -1) { + coor_to_voxelidx[i] = voxelidx; + num_points_per_voxel[voxelidx] += 1; + } + } + } +} + +#endif // VOXELIZATION_CUDA_KERNEL_CUH diff --git a/mmcv/ops/csrc/pytorch/cuda/scatter_points_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/scatter_points_cuda.cu new file mode 100644 index 0000000000..3710bb7a6e --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cuda/scatter_points_cuda.cu @@ -0,0 +1,128 @@ +#include +#include + +#include "pytorch_cuda_helper.hpp" +#include "scatter_points_cuda_kernel.cuh" + +std::vector DynamicPointToVoxelForwardCUDAKernelLauncher( + const at::Tensor &feats, const at::Tensor &coors, + const reduce_t reduce_type) { + CHECK_CUDA_INPUT(feats); + CHECK_CUDA_INPUT(coors); + + const int num_input = feats.size(0); + const int num_feats = feats.size(1); + + if (num_input == 0) + return {feats.clone().detach(), coors.clone().detach(), + coors.new_empty({0}, torch::kInt32), + coors.new_empty({0}, torch::kInt32)}; + + at::Tensor out_coors; + at::Tensor coors_map; + at::Tensor reduce_count; + + auto coors_clean = coors.masked_fill(coors.lt(0).any(-1, true), -1); + + std::tie(out_coors, coors_map, reduce_count) = + at::unique_dim(coors_clean, 0, true, true, true); + + // the first element of out_coors is always 
(-1,-1,-1) and should be removed + out_coors = out_coors.slice(0, 1); + reduce_count = reduce_count.slice(0, 1).to(torch::kInt32); + coors_map = coors_map.to(torch::kInt32) - 1; + + auto reduced_feats = + at::empty({out_coors.size(0), num_feats}, feats.options()); + + AT_DISPATCH_FLOATING_TYPES( + feats.scalar_type(), "feats_reduce_kernel", ([&] { + if (reduce_type == reduce_t::MAX) + reduced_feats.fill_(-std::numeric_limits::infinity()); + else + reduced_feats.fill_(static_cast(0)); + + dim3 blocks(std::min(at::cuda::ATenCeilDiv(num_input, threadsPerBlock), + maxGridDim)); + dim3 threads(threadsPerBlock); + feats_reduce_kernel<<>>( + feats.data_ptr(), coors_map.data_ptr(), + reduced_feats.data_ptr(), num_input, num_feats, + reduce_type); + if (reduce_type == reduce_t::MEAN) + reduced_feats /= reduce_count.unsqueeze(-1).to(reduced_feats.dtype()); + })); + + AT_CUDA_CHECK(cudaGetLastError()); + + return {reduced_feats, out_coors, coors_map, reduce_count}; +} + +void DynamicPointToVoxelBackwardCUDAKernelLauncher( + at::Tensor &grad_feats, const at::Tensor &grad_reduced_feats, + const at::Tensor &feats, const at::Tensor &reduced_feats, + const at::Tensor &coors_map, const at::Tensor &reduce_count, + const reduce_t reduce_type) { + CHECK_CUDA_INPUT(grad_feats); + CHECK_CUDA_INPUT(grad_reduced_feats); + CHECK_CUDA_INPUT(feats); + CHECK_CUDA_INPUT(reduced_feats); + CHECK_CUDA_INPUT(coors_map); + CHECK_CUDA_INPUT(reduce_count); + + const int num_input = feats.size(0); + const int num_reduced = reduced_feats.size(0); + const int num_feats = feats.size(1); + + grad_feats.fill_(0); + // copy voxel grad to points + + if (num_input == 0 || num_reduced == 0) return; + + if (reduce_type == reduce_t::MEAN || reduce_type == reduce_t::SUM) { + AT_DISPATCH_FLOATING_TYPES( + grad_reduced_feats.scalar_type(), "add_reduce_traceback_grad_kernel", + ([&] { + dim3 blocks(std::min( + at::cuda::ATenCeilDiv(num_input, threadsPerBlock), maxGridDim)); + dim3 threads(threadsPerBlock); + add_reduce_traceback_grad_kernel<<>>( + grad_feats.data_ptr(), + grad_reduced_feats.data_ptr(), + coors_map.data_ptr(), reduce_count.data_ptr(), + num_input, num_feats, reduce_type); + })); + + AT_CUDA_CHECK(cudaGetLastError()); + } else { + auto reduce_from = at::full({num_reduced, num_feats}, num_input, + coors_map.options().dtype(torch::kInt32)); + AT_DISPATCH_FLOATING_TYPES( + grad_reduced_feats.scalar_type(), + "max_reduce_traceback_scatter_idx_kernel", ([&] { + dim3 blocks(std::min( + at::cuda::ATenCeilDiv(num_input, threadsPerBlock), maxGridDim)); + dim3 threads(threadsPerBlock); + max_reduce_traceback_scatter_idx_kernel<<>>( + feats.data_ptr(), reduced_feats.data_ptr(), + reduce_from.data_ptr(), coors_map.data_ptr(), + num_input, num_feats); + })); + + AT_CUDA_CHECK(cudaGetLastError()); + + AT_DISPATCH_FLOATING_TYPES( + grad_reduced_feats.scalar_type(), + "max_reduce_traceback_scatter_idx_kernel", ([&] { + dim3 blocks(std::min( + at::cuda::ATenCeilDiv(num_reduced, threadsPerBlock), maxGridDim)); + dim3 threads(threadsPerBlock); + max_reduce_scatter_grad_kernel<<>>( + grad_feats.data_ptr(), + grad_reduced_feats.data_ptr(), + reduce_from.data_ptr(), num_reduced, num_feats); + })); + + AT_CUDA_CHECK(cudaGetLastError()); + } +} diff --git a/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu new file mode 100644 index 0000000000..a3edb8f71e --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu @@ -0,0 +1,194 @@ +#include +#include + +#include "pytorch_cuda_helper.hpp" 
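+// Worked example of the grid derivation used by both launchers in this file
+// (illustrative values only): with voxel_size = {0.32, 0.32, 6} and
+// coors_range = {-74.88, -74.88, -2, 74.88, 74.88, 4}, each axis resolves to
+//   grid_x = round((74.88 - -74.88) / 0.32) = 468
+//   grid_y = round((74.88 - -74.88) / 0.32) = 468
+//   grid_z = round((4 - -2) / 6)            = 1
+// i.e. every in-range point is binned into a 468 x 468 x 1 voxel grid.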
+#include "voxelization_cuda_kernel.cuh" + +int HardVoxelizeForwardCUDAKernelLauncher( + const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors, + at::Tensor& num_points_per_voxel, const std::vector voxel_size, + const std::vector coors_range, const int max_points, + const int max_voxels, const int NDim = 3) { + // current version tooks about 0.04s for one frame on cpu + // check device + CHECK_CUDA_INPUT(points); + + at::cuda::CUDAGuard device_guard(points.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + const int num_points = points.size(0); + const int num_features = points.size(1); + + const float voxel_x = voxel_size[0]; + const float voxel_y = voxel_size[1]; + const float voxel_z = voxel_size[2]; + const float coors_x_min = coors_range[0]; + const float coors_y_min = coors_range[1]; + const float coors_z_min = coors_range[2]; + const float coors_x_max = coors_range[3]; + const float coors_y_max = coors_range[4]; + const float coors_z_max = coors_range[5]; + + const int grid_x = round((coors_x_max - coors_x_min) / voxel_x); + const int grid_y = round((coors_y_max - coors_y_min) / voxel_y); + const int grid_z = round((coors_z_max - coors_z_min) / voxel_z); + + // map points to voxel coors + at::Tensor temp_coors = + at::zeros({num_points, NDim}, points.options().dtype(at::kInt)); + + dim3 grid(std::min(at::cuda::ATenCeilDiv(num_points, 512), 4096)); + dim3 block(512); + + // 1. link point to corresponding voxel coors + AT_DISPATCH_ALL_TYPES( + points.scalar_type(), "hard_voxelize_kernel", ([&] { + dynamic_voxelize_kernel<<>>( + points.contiguous().data_ptr(), + temp_coors.contiguous().data_ptr(), voxel_x, voxel_y, voxel_z, + coors_x_min, coors_y_min, coors_z_min, coors_x_max, coors_y_max, + coors_z_max, grid_x, grid_y, grid_z, num_points, num_features, + NDim); + })); + + cudaDeviceSynchronize(); + AT_CUDA_CHECK(cudaGetLastError()); + + // 2. map point to the idx of the corresponding voxel, find duplicate coor + // create some temporary variables + auto point_to_pointidx = -at::ones( + { + num_points, + }, + points.options().dtype(at::kInt)); + auto point_to_voxelidx = -at::ones( + { + num_points, + }, + points.options().dtype(at::kInt)); + + dim3 map_grid(std::min(at::cuda::ATenCeilDiv(num_points, 512), 4096)); + dim3 map_block(512); + + AT_DISPATCH_ALL_TYPES( + temp_coors.scalar_type(), "determin_duplicate", ([&] { + point_to_voxelidx_kernel<<>>( + temp_coors.contiguous().data_ptr(), + point_to_voxelidx.contiguous().data_ptr(), + point_to_pointidx.contiguous().data_ptr(), max_points, + max_voxels, num_points, NDim); + })); + + cudaDeviceSynchronize(); + AT_CUDA_CHECK(cudaGetLastError()); + + // 3. determin voxel num and voxel's coor index + // make the logic in the CUDA device could accelerate about 10 times + auto coor_to_voxelidx = -at::ones( + { + num_points, + }, + points.options().dtype(at::kInt)); + auto voxel_num = at::zeros( + { + 1, + }, + points.options().dtype(at::kInt)); // must be zero from the begining + + AT_DISPATCH_ALL_TYPES(temp_coors.scalar_type(), "determin_duplicate", ([&] { + determin_voxel_num<<<1, 1, 0, stream>>>( + num_points_per_voxel.contiguous().data_ptr(), + point_to_voxelidx.contiguous().data_ptr(), + point_to_pointidx.contiguous().data_ptr(), + coor_to_voxelidx.contiguous().data_ptr(), + voxel_num.contiguous().data_ptr(), + max_points, max_voxels, num_points); + })); + + cudaDeviceSynchronize(); + AT_CUDA_CHECK(cudaGetLastError()); + + // 4. 
copy point features to voxels + // Step 4 & 5 could be parallel + auto pts_output_size = num_points * num_features; + dim3 cp_grid(std::min(at::cuda::ATenCeilDiv(pts_output_size, 512), 4096)); + dim3 cp_block(512); + AT_DISPATCH_ALL_TYPES( + points.scalar_type(), "assign_point_to_voxel", ([&] { + assign_point_to_voxel<<>>( + pts_output_size, points.contiguous().data_ptr(), + point_to_voxelidx.contiguous().data_ptr(), + coor_to_voxelidx.contiguous().data_ptr(), + voxels.contiguous().data_ptr(), max_points, num_features, + num_points, NDim); + })); + // cudaDeviceSynchronize(); + // AT_CUDA_CHECK(cudaGetLastError()); + + // 5. copy coors of each voxels + auto coors_output_size = num_points * NDim; + dim3 coors_cp_grid( + std::min(at::cuda::ATenCeilDiv(coors_output_size, 512), 4096)); + dim3 coors_cp_block(512); + AT_DISPATCH_ALL_TYPES( + points.scalar_type(), "assign_point_to_voxel", ([&] { + assign_voxel_coors + <<>>( + coors_output_size, temp_coors.contiguous().data_ptr(), + point_to_voxelidx.contiguous().data_ptr(), + coor_to_voxelidx.contiguous().data_ptr(), + coors.contiguous().data_ptr(), num_points, NDim); + })); + + cudaDeviceSynchronize(); + AT_CUDA_CHECK(cudaGetLastError()); + + auto voxel_num_cpu = voxel_num.to(at::kCPU); + int voxel_num_int = voxel_num_cpu.data_ptr()[0]; + + return voxel_num_int; +} + +void DynamicVoxelizeForwardCUDAKernelLauncher( + const at::Tensor& points, at::Tensor& coors, + const std::vector voxel_size, const std::vector coors_range, + const int NDim = 3) { + // current version tooks about 0.04s for one frame on cpu + // check device + CHECK_CUDA_INPUT(points); + + at::cuda::CUDAGuard device_guard(points.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + const int num_points = points.size(0); + const int num_features = points.size(1); + + const float voxel_x = voxel_size[0]; + const float voxel_y = voxel_size[1]; + const float voxel_z = voxel_size[2]; + const float coors_x_min = coors_range[0]; + const float coors_y_min = coors_range[1]; + const float coors_z_min = coors_range[2]; + const float coors_x_max = coors_range[3]; + const float coors_y_max = coors_range[4]; + const float coors_z_max = coors_range[5]; + + const int grid_x = round((coors_x_max - coors_x_min) / voxel_x); + const int grid_y = round((coors_y_max - coors_y_min) / voxel_y); + const int grid_z = round((coors_z_max - coors_z_min) / voxel_z); + + const int col_blocks = at::cuda::ATenCeilDiv(num_points, threadsPerBlock); + dim3 blocks(col_blocks); + dim3 threads(threadsPerBlock); + + AT_DISPATCH_ALL_TYPES(points.scalar_type(), "dynamic_voxelize_kernel", [&] { + dynamic_voxelize_kernel<<>>( + points.contiguous().data_ptr(), + coors.contiguous().data_ptr(), voxel_x, voxel_y, voxel_z, + coors_x_min, coors_y_min, coors_z_min, coors_x_max, coors_y_max, + coors_z_max, grid_x, grid_y, grid_z, num_points, num_features, NDim); + }); + + cudaDeviceSynchronize(); + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index ee3d4e3bd8..8fc2a6e9da 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -217,6 +217,30 @@ void roi_align_rotated_backward(Tensor grad_output, Tensor rois, int pooled_width, float spatial_scale, int sample_num, bool aligned, bool clockwise); +std::vector dynamic_point_to_voxel_forward( + const torch::Tensor &feats, const torch::Tensor &coors, + const std::string &reduce_type); + +void dynamic_point_to_voxel_backward(torch::Tensor &grad_feats, + const 
torch::Tensor &grad_reduced_feats, + const torch::Tensor &feats, + const torch::Tensor &reduced_feats, + const torch::Tensor &coors_idx, + const torch::Tensor &reduce_count, + const std::string &reduce_type); + +int hard_voxelize_forward(const at::Tensor &points, at::Tensor &voxels, + at::Tensor &coors, at::Tensor &num_points_per_voxel, + const std::vector voxel_size, + const std::vector coors_range, + const int max_points, const int max_voxels, + const int NDim); + +void dynamic_voxelize_forward(const at::Tensor &points, at::Tensor &coors, + const std::vector voxel_size, + const std::vector coors_range, + const int NDim); + void border_align_forward(const Tensor &input, const Tensor &boxes, Tensor output, Tensor argmax_idx, const int pool_size); @@ -444,6 +468,22 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("grad_input"), py::arg("pooled_height"), py::arg("pooled_width"), py::arg("spatial_scale"), py::arg("sample_num"), py::arg("aligned"), py::arg("clockwise")); + m.def("dynamic_point_to_voxel_forward", &dynamic_point_to_voxel_forward, + "dynamic_point_to_voxel_forward", py::arg("feats"), py::arg("coors"), + py::arg("reduce_type")); + m.def("dynamic_point_to_voxel_backward", &dynamic_point_to_voxel_backward, + "dynamic_point_to_voxel_backward", py::arg("grad_feats"), + py::arg("grad_reduced_feats"), py::arg("feats"), + py::arg("reduced_feats"), py::arg("coors_idx"), py::arg("reduce_count"), + py::arg("reduce_type")); + m.def("hard_voxelize_forward", &hard_voxelize_forward, + "hard_voxelize_forward", py::arg("points"), py::arg("voxels"), + py::arg("coors"), py::arg("num_points_per_voxel"), + py::arg("voxel_size"), py::arg("coors_range"), py::arg("max_points"), + py::arg("max_voxels"), py::arg("NDim")); + m.def("dynamic_voxelize_forward", &dynamic_voxelize_forward, + "dynamic_voxelize_forward", py::arg("points"), py::arg("coors"), + py::arg("voxel_size"), py::arg("coors_range"), py::arg("NDim")); m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "forward function of multi-scale deformable attention", py::arg("value"), py::arg("value_spatial_shapes"), diff --git a/mmcv/ops/csrc/pytorch/scatter_points.cpp b/mmcv/ops/csrc/pytorch/scatter_points.cpp new file mode 100644 index 0000000000..81aceb7ad4 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/scatter_points.cpp @@ -0,0 +1,81 @@ +#include "pytorch_cpp_helper.hpp" + +typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t; + +#ifdef MMCV_WITH_CUDA +std::vector DynamicPointToVoxelForwardCUDAKernelLauncher( + const torch::Tensor &feats, const torch::Tensor &coors, + const reduce_t reduce_type); + +std::vector dynamic_point_to_voxel_forward_cuda( + const torch::Tensor &feats, const torch::Tensor &coors, + const reduce_t reduce_type) { + return DynamicPointToVoxelForwardCUDAKernelLauncher(feats, coors, + reduce_type); +}; + +void DynamicPointToVoxelBackwardCUDAKernelLauncher( + torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats, + const torch::Tensor &feats, const torch::Tensor &reduced_feats, + const torch::Tensor &coors_idx, const torch::Tensor &reduce_count, + const reduce_t reduce_type); + +void dynamic_point_to_voxel_backward_cuda( + torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats, + const torch::Tensor &feats, const torch::Tensor &reduced_feats, + const torch::Tensor &coors_idx, const torch::Tensor &reduce_count, + const reduce_t reduce_type){DynamicPointToVoxelBackwardCUDAKernelLauncher( + grad_feats, grad_reduced_feats, feats, reduced_feats, coors_idx, + reduce_count, reduce_type)}; +#endif + 
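+// Shape sketch for the forward contract (as returned by the CUDA launcher
+// above): given feats [N, C] and integer coors [N, 3], the result vector is
+//   {reduced_feats [M, C], out_coors [M, 3], coors_map [N], reduce_count [M]}
+// where M is the number of distinct in-range voxels and coors_map[i] == -1
+// marks a point whose coordinates contained a negative (out-of-range) entry.
+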
+std::vector dynamic_point_to_voxel_forward_cpu( + const at::Tensor &points, const at::Tensor &voxel_mapping, + const std::vector voxel_size, const std::vector coors_range); + +inline reduce_t convert_reduce_type(const std::string &reduce_type) { + if (reduce_type == "max") + return reduce_t::MAX; + else if (reduce_type == "sum") + return reduce_t::SUM; + else if (reduce_type == "mean") + return reduce_t::MEAN; + else + TORCH_CHECK(false, "do not support reduce type " + reduce_type) + return reduce_t::SUM; +} + +inline std::vector dynamic_point_to_voxel_forward( + const torch::Tensor &feats, const torch::Tensor &coors, + const std::string &reduce_type) { + if (feats.device().is_cuda()) { +#ifdef WITH_CUDA + return dynamic_point_to_voxel_forward_cuda( + feats, coors, convert_reduce_type(reduce_type)); +#else + AT_ERROR("dynamic_point_to_voxel is not compiled with GPU support"); +#endif + } else { + AT_ERROR("dynamic_point_to_voxel is not implemented on CPU"); + return std::vector(); + } +} + +inline void dynamic_point_to_voxel_backward( + torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats, + const torch::Tensor &feats, const torch::Tensor &reduced_feats, + const torch::Tensor &coors_idx, const torch::Tensor &reduce_count, + const std::string &reduce_type) { + if (grad_feats.device().is_cuda()) { +#ifdef WITH_CUDA + dynamic_point_to_voxel_backward_cuda(grad_feats, grad_reduced_feats, feats, + reduced_feats, coors_idx, reduce_count, + convert_reduce_type(reduce_type)); + return; +#else + AT_ERROR("dynamic_point_to_voxel is not compiled with GPU support"); +#endif + } else { + AT_ERROR("dynamic_point_to_voxel is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/pytorch/voxelization.cpp b/mmcv/ops/csrc/pytorch/voxelization.cpp new file mode 100644 index 0000000000..66d71d6de8 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/voxelization.cpp @@ -0,0 +1,85 @@ +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +int HardVoxelizeForwardCUDAKernelLauncher( + const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors, + at::Tensor &num_points_per_voxel, const std::vector voxel_size, + const std::vector coors_range, const int max_points, + const int max_voxels, const int NDim = 3); + +int hard_voxelize_forward_cuda(const at::Tensor &points, at::Tensor &voxels, + at::Tensor &coors, + at::Tensor &num_points_per_voxel, + const std::vector voxel_size, + const std::vector coors_range, + const int max_points, const int max_voxels, + const int NDim = 3) { + return HardVoxelizeForwardCUDAKernelLauncher( + points, voxels, coors, num_points_per_voxel, voxel_size, coors_range, + max_points, max_voxels, NDim); +}; + +void DynamicVoxelizeForwardCUDAKernelLauncher( + const at::Tensor &points, at::Tensor &coors, + const std::vector voxel_size, const std::vector coors_range, + const int NDim = 3); + +void dynamic_voxelize_forward_cuda(const at::Tensor &points, at::Tensor &coors, + const std::vector voxel_size, + const std::vector coors_range, + const int NDim = 3) { + DynamicVoxelizeForwardCUDAKernelLauncher(points, coors, voxel_size, + coors_range, NDim); +}; +#endif + +int hard_voxelize_forward_cpu(const at::Tensor &points, at::Tensor &voxels, + at::Tensor &coors, + at::Tensor &num_points_per_voxel, + const std::vector voxel_size, + const std::vector coors_range, + const int max_points, const int max_voxels, + const int NDim = 3); + +void dynamic_voxelize_forward_cpu(const at::Tensor &points, at::Tensor &coors, + const std::vector voxel_size, + const std::vector coors_range, + const 
int NDim = 3){ + +}; + +inline int hard_voxelize_forward(const at::Tensor &points, at::Tensor &voxels, + at::Tensor &coors, + at::Tensor &num_points_per_voxel, + const std::vector voxel_size, + const std::vector coors_range, + const int max_points, const int max_voxels, + const int NDim = 3) { + if (points.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + return hard_voxelize_forward_cuda( + points, voxels, coors, num_points_per_voxel, voxel_size, coors_range, + max_points, max_voxels, NDim); +#else + AT_ERROR("hard_voxelize is not compiled with GPU support"); +#endif + } + return hard_voxelize_forward_cpu(points, voxels, coors, num_points_per_voxel, + voxel_size, coors_range, max_points, + max_voxels, NDim); +} + +inline void dynamic_voxelize_forward(const at::Tensor &points, + at::Tensor &coors, + const std::vector voxel_size, + const std::vector coors_range, + const int NDim = 3) { + if (points.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + dynamic_voxelize_forward_cuda(points, coors, voxel_size, coors_range, NDim); +#else + AT_ERROR("dynamic_voxelize is not compiled with GPU support"); +#endif + } + dynamic_voxelize_forward_cpu(points, coors, voxel_size, coors_range, NDim); +} diff --git a/mmcv/ops/csrc/pytorch/voxelization_cpu.cpp b/mmcv/ops/csrc/pytorch/voxelization_cpu.cpp new file mode 100644 index 0000000000..63d83318c9 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/voxelization_cpu.cpp @@ -0,0 +1,157 @@ +#include "pytorch_cpp_helper.hpp" + +template +void dynamic_voxelize_forward_cpu_kernel( + const torch::TensorAccessor points, + torch::TensorAccessor coors, const std::vector voxel_size, + const std::vector coors_range, const std::vector grid_size, + const int num_points, const int num_features, const int NDim) { + const int ndim_minus_1 = NDim - 1; + bool failed = false; + // int coor[NDim]; + int* coor = new int[NDim](); + int c; + + for (int i = 0; i < num_points; ++i) { + failed = false; + for (int j = 0; j < NDim; ++j) { + c = floor((points[i][j] - coors_range[j]) / voxel_size[j]); + // necessary to rm points out of range + if ((c < 0 || c >= grid_size[j])) { + failed = true; + break; + } + coor[ndim_minus_1 - j] = c; + } + + for (int k = 0; k < NDim; ++k) { + if (failed) + coors[i][k] = -1; + else + coors[i][k] = coor[k]; + } + } + + delete[] coor; +} + +template +void hard_voxelize_forward_cpu_kernel( + const torch::TensorAccessor points, + torch::TensorAccessor voxels, torch::TensorAccessor coors, + torch::TensorAccessor num_points_per_voxel, + torch::TensorAccessor coor_to_voxelidx, int& voxel_num, + const std::vector voxel_size, const std::vector coors_range, + const std::vector grid_size, const int max_points, + const int max_voxels, const int num_points, const int num_features, + const int NDim) { + // declare a temp coors + at::Tensor temp_coors = at::zeros( + {num_points, NDim}, at::TensorOptions().dtype(at::kInt).device(at::kCPU)); + + // First use dynamic voxelization to get coors, + // then check max points/voxels constraints + dynamic_voxelize_forward_cpu_kernel( + points, temp_coors.accessor(), voxel_size, coors_range, grid_size, + num_points, num_features, NDim); + + int voxelidx, num; + auto coor = temp_coors.accessor(); + + for (int i = 0; i < num_points; ++i) { + // T_int* coor = temp_coors.data_ptr() + i * NDim; + + if (coor[i][0] == -1) continue; + + voxelidx = coor_to_voxelidx[coor[i][0]][coor[i][1]][coor[i][2]]; + + // record voxel + if (voxelidx == -1) { + voxelidx = voxel_num; + if (max_voxels != -1 && voxel_num >= max_voxels) continue; + voxel_num += 1; + + 
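+      // register the newly allocated voxel slot for this grid cell so that
+      // later points with the same coordinate are appended to the same voxel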
coor_to_voxelidx[coor[i][0]][coor[i][1]][coor[i][2]] = voxelidx;
+
+      for (int k = 0; k < NDim; ++k) {
+        coors[voxelidx][k] = coor[i][k];
+      }
+    }
+
+    // put points into voxel
+    num = num_points_per_voxel[voxelidx];
+    if (max_points == -1 || num < max_points) {
+      for (int k = 0; k < num_features; ++k) {
+        voxels[voxelidx][num][k] = points[i][k];
+      }
+      num_points_per_voxel[voxelidx] += 1;
+    }
+  }
+
+  return;
+}
+
+void dynamic_voxelize_forward_cpu(const at::Tensor& points, at::Tensor& coors,
+                                  const std::vector<float> voxel_size,
+                                  const std::vector<float> coors_range,
+                                  const int NDim = 3) {
+  // check device
+  AT_ASSERTM(points.device().is_cpu(), "points must be a CPU tensor");
+
+  std::vector<int> grid_size(NDim);
+  const int num_points = points.size(0);
+  const int num_features = points.size(1);
+
+  for (int i = 0; i < NDim; ++i) {
+    grid_size[i] =
+        round((coors_range[NDim + i] - coors_range[i]) / voxel_size[i]);
+  }
+
+  // coors, num_points_per_voxel, coor_to_voxelidx are int Tensor
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      points.scalar_type(), "dynamic_voxelize_forward_cuda_kernel", [&] {
+        dynamic_voxelize_forward_cpu_kernel<scalar_t, int>(
+            points.accessor<scalar_t, 2>(), coors.accessor<int, 2>(),
+            voxel_size, coors_range, grid_size, num_points, num_features,
+            NDim);
+      });
+}
+
+int hard_voxelize_forward_cpu(const at::Tensor& points, at::Tensor& voxels,
+                              at::Tensor& coors,
+                              at::Tensor& num_points_per_voxel,
+                              const std::vector<float> voxel_size,
+                              const std::vector<float> coors_range,
+                              const int max_points, const int max_voxels,
+                              const int NDim = 3) {
+  // the current version takes about 0.02s~0.03s for one frame on cpu
+  // check device
+  AT_ASSERTM(points.device().is_cpu(), "points must be a CPU tensor");
+
+  std::vector<int> grid_size(NDim);
+  const int num_points = points.size(0);
+  const int num_features = points.size(1);
+
+  for (int i = 0; i < NDim; ++i) {
+    grid_size[i] =
+        round((coors_range[NDim + i] - coors_range[i]) / voxel_size[i]);
+  }
+
+  // coors, num_points_per_voxel, coor_to_voxelidx are int Tensor
+  // printf("cpu coor_to_voxelidx size: [%d, %d, %d]\n", grid_size[2],
+  // grid_size[1], grid_size[0]);
+  at::Tensor coor_to_voxelidx =
+      -at::ones({grid_size[2], grid_size[1], grid_size[0]}, coors.options());
+
+  int voxel_num = 0;
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      points.scalar_type(), "hard_voxelize_forward_cuda_kernel", [&] {
+        hard_voxelize_forward_cpu_kernel<scalar_t, int>(
+            points.accessor<scalar_t, 2>(), voxels.accessor<scalar_t, 3>(),
+            coors.accessor<int, 2>(), num_points_per_voxel.accessor<int, 1>(),
+            coor_to_voxelidx.accessor<int, 3>(), voxel_num, voxel_size,
+            coors_range, grid_size, max_points, max_voxels, num_points,
+            num_features, NDim);
+      });
+
+  return voxel_num;
+}
diff --git a/mmcv/ops/scatter_points.py b/mmcv/ops/scatter_points.py
new file mode 100644
index 0000000000..a37079e0c5
--- /dev/null
+++ b/mmcv/ops/scatter_points.py
@@ -0,0 +1,110 @@
+import torch
+from torch import nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+    '_ext',
+    ['dynamic_point_to_voxel_forward', 'dynamic_point_to_voxel_backward'])
+
+
+class _dynamic_scatter(Function):
+
+    @staticmethod
+    def forward(ctx, feats, coors, reduce_type='max'):
+        """Scatter (reduce) point features onto their voxels.
+
+        Args:
+            feats: [N, C] float tensor. point features to be reduced
+                into voxels.
+            coors: [N, ndim] int tensor. corresponding voxel coordinates
+                (specifically multi-dim voxel index) of each point.
+            reduce_type: str. reduce op. supports 'max', 'sum' and 'mean'.
+        Returns:
+            tuple
+            voxel_feats: [M, C] float tensor. reduced features.
+                Input features that share the same voxel coordinates are
+                reduced to one row.
+            coordinates: [M, ndim] int tensor, voxel coordinates.
+        """
+        results = ext_module.dynamic_point_to_voxel_forward(
+            feats, coors, reduce_type)
+        (voxel_feats, voxel_coors, point2voxel_map,
+         voxel_points_count) = results
+        ctx.reduce_type = reduce_type
+        ctx.save_for_backward(feats, voxel_feats, point2voxel_map,
+                              voxel_points_count)
+        ctx.mark_non_differentiable(voxel_coors)
+        return voxel_feats, voxel_coors
+
+    @staticmethod
+    def backward(ctx, grad_voxel_feats, grad_voxel_coors=None):
+        (feats, voxel_feats, point2voxel_map,
+         voxel_points_count) = ctx.saved_tensors
+        grad_feats = torch.zeros_like(feats)
+        # TODO: whether to use index put or use cuda_backward
+        # To use index put, need point to voxel index
+        ext_module.dynamic_point_to_voxel_backward(
+            grad_feats, grad_voxel_feats.contiguous(), feats, voxel_feats,
+            point2voxel_map, voxel_points_count, ctx.reduce_type)
+        return grad_feats, None, None
+
+
+dynamic_scatter = _dynamic_scatter.apply
+
+
+class DynamicScatter(nn.Module):
+
+    def __init__(self, voxel_size, point_cloud_range, average_points: bool):
+        """Scatters points into voxels, used in the voxel encoder with
+        dynamic voxelization.
+
+        **Note**: The CPU and GPU implementations get the same output, but
+        they have a small numerical difference after summation and division
+        (e.g., 5e-7).
+
+        Args:
+            voxel_size (list): list [x, y, z] with the size of each of the
+                three dimensions.
+            point_cloud_range (list):
+                [x_min, y_min, z_min, x_max, y_max, z_max]
+            average_points (bool): whether to use avg pooling to scatter
+                points into voxels.
+        """
+        super(DynamicScatter, self).__init__()
+        self.voxel_size = voxel_size
+        self.point_cloud_range = point_cloud_range
+        self.average_points = average_points
+
+    def forward_single(self, points, coors):
+        reduce = 'mean' if self.average_points else 'max'
+        return dynamic_scatter(points.contiguous(), coors.contiguous(),
+                               reduce)
+
+    def forward(self, points, coors):
+        """
+        Args:
+            points: [N, C] float tensor of point features.
+            coors: [N, 3] int tensor of voxel coordinates, or [N, 4] with a
+                leading batch-index column for batched inputs.
+        """
+        if coors.size(-1) == 3:
+            return self.forward_single(points, coors)
+        else:
+            batch_size = coors[-1, 0] + 1
+            voxels, voxel_coors = [], []
+            for i in range(batch_size):
+                inds = torch.where(coors[:, 0] == i)
+                voxel, voxel_coor = self.forward_single(
+                    points[inds], coors[inds][:, 1:])
+                coor_pad = nn.functional.pad(
+                    voxel_coor, (1, 0), mode='constant', value=i)
+                voxel_coors.append(coor_pad)
+                voxels.append(voxel)
+            features = torch.cat(voxels, dim=0)
+            feature_coors = torch.cat(voxel_coors, dim=0)
+
+            return features, feature_coors
+
+    def __repr__(self):
+        tmpstr = self.__class__.__name__ + '('
+        tmpstr += 'voxel_size=' + str(self.voxel_size)
+        tmpstr += ', point_cloud_range=' + str(self.point_cloud_range)
+        tmpstr += ', average_points=' + str(self.average_points)
+        tmpstr += ')'
+        return tmpstr
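+
+
+# A minimal usage sketch (illustrative only; the voxel size and point cloud
+# range below mirror the values used in tests/test_ops/test_scatter_points.py
+# and the random inputs are made up):
+#
+#     import torch
+#     from mmcv.ops import DynamicScatter
+#
+#     feats = torch.rand(1000, 4, device='cuda')
+#     coors = torch.randint(
+#         0, 10, (1000, 3), dtype=torch.int32, device='cuda')
+#     scatter = DynamicScatter([0.32, 0.32, 6],
+#                              [-74.88, -74.88, -2, 74.88, 74.88, 4],
+#                              average_points=True)
+#     voxel_feats, voxel_coors = scatter(feats, coors)  # one row per voxel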
diff --git a/mmcv/ops/voxelization.py b/mmcv/ops/voxelization.py
new file mode 100644
index 0000000000..f143d94142
--- /dev/null
+++ b/mmcv/ops/voxelization.py
@@ -0,0 +1,126 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import torch
+from torch import nn
+from torch.autograd import Function
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+    '_ext', ['dynamic_voxelize_forward', 'hard_voxelize_forward'])
+
+
+class _Voxelization(Function):
+
+    @staticmethod
+    def forward(ctx,
+                points,
+                voxel_size,
+                coors_range,
+                max_points=35,
+                max_voxels=20000):
+        """Convert kitti points (N, >=3) to voxels.
+
+        Args:
+            points: [N, ndim] float tensor. points[:, :3] contain xyz points
+                and points[:, 3:] contain other information such as
+                reflectivity.
+            voxel_size: [3] list/tuple or array, float. xyz, indicates voxel
+                size.
+            coors_range: [6] list/tuple or array, float. indicates voxel
+                range. format: xyzxyz, minmax.
+            max_points: int. indicates the maximum number of points contained
+                in a voxel. if max_points=-1, dynamic voxelization is used.
+            max_voxels: int. indicates the maximum number of voxels this
+                function creates. For the SECOND detector, 20000 is a good
+                choice. Users should shuffle points before calling this
+                function because max_voxels may drop points.
+
+        Returns:
+            voxels: [M, max_points, ndim] float tensor. only contains points
+                and is only returned when max_points != -1.
+            coordinates: [M, 3] int32 tensor, always returned.
+            num_points_per_voxel: [M] int32 tensor. only returned when
+                max_points != -1.
+        """
+        if max_points == -1 or max_voxels == -1:
+            coors = points.new_zeros(size=(points.size(0), 3), dtype=torch.int)
+            ext_module.dynamic_voxelize_forward(points, coors, voxel_size,
+                                                coors_range, 3)
+            return coors
+        else:
+            voxels = points.new_zeros(
+                size=(max_voxels, max_points, points.size(1)))
+            coors = points.new_zeros(size=(max_voxels, 3), dtype=torch.int)
+            num_points_per_voxel = points.new_zeros(
+                size=(max_voxels, ), dtype=torch.int)
+            voxel_num = ext_module.hard_voxelize_forward(
+                points, voxels, coors, num_points_per_voxel, voxel_size,
+                coors_range, max_points, max_voxels, 3)
+            # select the valid voxels
+            voxels_out = voxels[:voxel_num]
+            coors_out = coors[:voxel_num]
+            num_points_per_voxel_out = num_points_per_voxel[:voxel_num]
+            return voxels_out, coors_out, num_points_per_voxel_out
+
+
+voxelization = _Voxelization.apply
+
+
+class Voxelization(nn.Module):
+
+    def __init__(self,
+                 voxel_size,
+                 point_cloud_range,
+                 max_num_points,
+                 max_voxels=20000):
+        """
+        Args:
+            voxel_size (list): list [x, y, z] with the size of each of the
+                three dimensions.
+            point_cloud_range (list):
+                [x_min, y_min, z_min, x_max, y_max, z_max]
+            max_num_points (int): max number of points per voxel
+            max_voxels (tuple or int): max number of voxels at
+                (training, testing) time
+        """
+        super(Voxelization, self).__init__()
+        self.voxel_size = voxel_size
+        self.point_cloud_range = point_cloud_range
+        self.max_num_points = max_num_points
+        if isinstance(max_voxels, tuple):
+            self.max_voxels = max_voxels
+        else:
+            self.max_voxels = _pair(max_voxels)
+
+        point_cloud_range = torch.tensor(
+            point_cloud_range, dtype=torch.float32)
+        # e.g. [0, -40, -3, 70.4, 40, 1]
+        voxel_size = torch.tensor(voxel_size, dtype=torch.float32)
+        grid_size = (point_cloud_range[3:] -
+                     point_cloud_range[:3]) / voxel_size
+        grid_size = torch.round(grid_size).long()
+        input_feat_shape = grid_size[:2]
+        self.grid_size = grid_size
+        # the original shape is [x-len, y-len, z-len]
+        # [w, h, d] -> [d, h, w]
+        self.pcd_shape = [*input_feat_shape, 1][::-1]
+
+    def forward(self, input):
+        """
+        Args:
+            input: [N, ndim] float tensor of points to voxelize.
+        """
+        if self.training:
+            max_voxels = self.max_voxels[0]
+        else:
+            max_voxels = self.max_voxels[1]
+
+        return voxelization(input, self.voxel_size, self.point_cloud_range,
+                            self.max_num_points, max_voxels)
+
+    def __repr__(self):
+        tmpstr = self.__class__.__name__ + '('
+        tmpstr += 'voxel_size=' + str(self.voxel_size)
+        tmpstr += ', point_cloud_range=' + str(self.point_cloud_range)
+        tmpstr += ', max_num_points=' + str(self.max_num_points)
+        tmpstr += ', max_voxels=' + str(self.max_voxels)
+        tmpstr += ')'
+        return tmpstr
diff --git a/tests/test_ops/test_scatter_points.py b/tests/test_ops/test_scatter_points.py
new file mode 100644
index 0000000000..8610124e30
--- /dev/null
+++ b/tests/test_ops/test_scatter_points.py
@@ -0,0 +1,93 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+import torch
+from torch.autograd import gradcheck
+
+from mmcv.ops import DynamicScatter
+
+
+@pytest.mark.skipif(
+    not torch.cuda.is_available(), reason='requires CUDA support')
+def test_dynamic_scatter():
+    feats = torch.rand(
+        size=(200000, 3), dtype=torch.float32, device='cuda') * 100 - 50
+    coors = torch.randint(
+        low=-1, high=20, size=(200000, 3), dtype=torch.int32, device='cuda')
+
+    dsmean = DynamicScatter([0.32, 0.32, 6],
+                            [-74.88, -74.88, -2, 74.88, 74.88, 4], True)
+    dsmax = DynamicScatter([0.32, 0.32, 6],
+                           [-74.88, -74.88, -2, 74.88, 74.88, 4], False)
+
+    # test empty input
+    empty_feats = torch.empty(size=(0, 3), dtype=torch.float32, device='cuda')
+    empty_coors = torch.empty(size=(0, 3), dtype=torch.int32, device='cuda')
+
+    empty_feats.requires_grad_()
+    empty_feats_out_mean, empty_coors_out_mean = dsmean(
+        empty_feats, empty_coors)
+    empty_feats_out_mean.sum().backward()
+    empty_feats_out_max, empty_coors_out_max = dsmax(empty_feats, empty_coors)
+    empty_feats_out_max.sum().backward()
+
+    assert empty_feats_out_mean.shape == empty_feats.shape
+    assert empty_feats_out_max.shape == empty_feats.shape
+    assert empty_coors_out_mean.shape == empty_coors.shape
+    assert empty_coors_out_max.shape == empty_coors.shape
+
+    # test empty reduced output
+    empty_o_feats = torch.rand(
+        size=(200000, 3), dtype=torch.float32, device='cuda') * 100 - 50
+    empty_o_coors = torch.randint(
+        low=-1, high=0, size=(200000, 3), dtype=torch.int32, device='cuda')
+
+    empty_o_feats.requires_grad_()
+    empty_o_feats_out_mean, empty_o_coors_out_mean = dsmean(
+        empty_o_feats, empty_o_coors)
+    empty_o_feats_out_mean.sum().backward()
+    assert (empty_o_feats.grad == 0).all()
+
+    empty_o_feats_out_max, empty_o_coors_out_max = dsmax(
+        empty_o_feats, empty_o_coors)
+    empty_o_feats_out_max.sum().backward()
+    assert (empty_o_feats.grad == 0).all()
+
+    # test non-empty input
+    ref_voxel_coors = coors.unique(dim=0, sorted=True)
+    ref_voxel_coors = ref_voxel_coors[ref_voxel_coors.min(dim=-1).values >= 0]
+    ref_voxel_feats_mean = []
+    ref_voxel_feats_max = []
+    for ref_voxel_coor in ref_voxel_coors:
+        voxel_mask = (coors == ref_voxel_coor).all(dim=-1)
+        ref_voxel_feats_mean.append(feats[voxel_mask].mean(dim=0))
+        ref_voxel_feats_max.append(feats[voxel_mask].max(dim=0).values)
+    ref_voxel_feats_mean = torch.stack(ref_voxel_feats_mean)
+    ref_voxel_feats_max = torch.stack(ref_voxel_feats_max)
+
+    feats_out_mean, coors_out_mean = dsmean(feats, coors)
+    seq_mean = (coors_out_mean[:, 0] * 400 + coors_out_mean[:, 1] * 20 +
+                coors_out_mean[:, 2]).argsort()
+    feats_out_mean = feats_out_mean[seq_mean]
+    coors_out_mean = coors_out_mean[seq_mean]
+
+    feats_out_max, coors_out_max = dsmax(feats, coors)
+    seq_max = (coors_out_max[:, 0] * 400 + coors_out_max[:, 1] * 20 +
+               coors_out_max[:, 2]).argsort()
+    feats_out_max = feats_out_max[seq_max]
+    coors_out_max = coors_out_max[seq_max]
+
+    assert (coors_out_mean == ref_voxel_coors).all()
+    assert torch.allclose(
+        feats_out_mean, ref_voxel_feats_mean, atol=1e-2, rtol=1e-5)
+    assert (coors_out_max == ref_voxel_coors).all()
+    assert torch.allclose(
+        feats_out_max, ref_voxel_feats_max, atol=1e-2, rtol=1e-5)
+
+    # test grad
+    feats = torch.rand(
+        size=(100, 4), dtype=torch.float32, device='cuda') * 100 - 50
+    coors = torch.randint(
+        low=-1,
high=3, size=(100, 3), dtype=torch.int32, device='cuda') + feats.requires_grad_() + gradcheck(dsmean, (feats, coors), eps=1e-2, atol=1e-2, rtol=1e-5) + gradcheck(dsmax, (feats, coors), eps=1e-2, atol=1e-2, rtol=1e-5) From b1e706f43a2e712f08aea5e2969c845def685e43 Mon Sep 17 00:00:00 2001 From: "hudingchang.vendor" Date: Mon, 27 Sep 2021 22:20:57 +0800 Subject: [PATCH 02/12] add ops (voxel) in mmdet3d --- .../cuda/scatter_points_cuda_kernel.cuh | 1 + .../common/cuda/voxelization_cuda_kernel.cuh | 4 --- .../csrc/pytorch/cuda/scatter_points_cuda.cu | 32 +++++++------------ .../csrc/pytorch/cuda/voxelization_cuda.cu | 6 ++-- mmcv/ops/csrc/pytorch/scatter_points.cpp | 16 ++++++++-- mmcv/ops/csrc/pytorch/voxelization.cpp | 8 +++-- mmcv/ops/{voxelization.py => voxelize.py} | 0 7 files changed, 33 insertions(+), 34 deletions(-) rename mmcv/ops/{voxelization.py => voxelize.py} (100%) diff --git a/mmcv/ops/csrc/common/cuda/scatter_points_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/scatter_points_cuda_kernel.cuh index 7deade0547..df5b34c86b 100644 --- a/mmcv/ops/csrc/common/cuda/scatter_points_cuda_kernel.cuh +++ b/mmcv/ops/csrc/common/cuda/scatter_points_cuda_kernel.cuh @@ -9,6 +9,7 @@ #endif typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t; +int const maxGridDim = 50000; __device__ __forceinline__ static void reduceMax(float *address, float val) { int *address_as_i = reinterpret_cast(address); diff --git a/mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh index 638f327057..0669a8f436 100644 --- a/mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh +++ b/mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh @@ -10,10 +10,6 @@ typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t; -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ - i += blockDim.x * gridDim.x) - template __global__ void dynamic_voxelize_kernel( const T* points, T_int* coors, const float voxel_x, const float voxel_y, diff --git a/mmcv/ops/csrc/pytorch/cuda/scatter_points_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/scatter_points_cuda.cu index 3710bb7a6e..3692f8785f 100644 --- a/mmcv/ops/csrc/pytorch/cuda/scatter_points_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/scatter_points_cuda.cu @@ -1,5 +1,6 @@ #include #include +#include #include "pytorch_cuda_helper.hpp" #include "scatter_points_cuda_kernel.cuh" @@ -7,9 +8,6 @@ std::vector DynamicPointToVoxelForwardCUDAKernelLauncher( const at::Tensor &feats, const at::Tensor &coors, const reduce_t reduce_type) { - CHECK_CUDA_INPUT(feats); - CHECK_CUDA_INPUT(coors); - const int num_input = feats.size(0); const int num_feats = feats.size(1); @@ -42,9 +40,9 @@ std::vector DynamicPointToVoxelForwardCUDAKernelLauncher( else reduced_feats.fill_(static_cast(0)); - dim3 blocks(std::min(at::cuda::ATenCeilDiv(num_input, threadsPerBlock), - maxGridDim)); - dim3 threads(threadsPerBlock); + dim3 blocks(std::min( + at::cuda::ATenCeilDiv(num_input, THREADS_PER_BLOCK), maxGridDim)); + dim3 threads(THREADS_PER_BLOCK); feats_reduce_kernel<<>>( feats.data_ptr(), coors_map.data_ptr(), reduced_feats.data_ptr(), num_input, num_feats, @@ -63,13 +61,6 @@ void DynamicPointToVoxelBackwardCUDAKernelLauncher( const at::Tensor &feats, const at::Tensor &reduced_feats, const at::Tensor &coors_map, const at::Tensor &reduce_count, const reduce_t reduce_type) { - CHECK_CUDA_INPUT(grad_feats); - CHECK_CUDA_INPUT(grad_reduced_feats); - CHECK_CUDA_INPUT(feats); - CHECK_CUDA_INPUT(reduced_feats); - 
CHECK_CUDA_INPUT(coors_map); - CHECK_CUDA_INPUT(reduce_count); - const int num_input = feats.size(0); const int num_reduced = reduced_feats.size(0); const int num_feats = feats.size(1); @@ -84,8 +75,8 @@ void DynamicPointToVoxelBackwardCUDAKernelLauncher( grad_reduced_feats.scalar_type(), "add_reduce_traceback_grad_kernel", ([&] { dim3 blocks(std::min( - at::cuda::ATenCeilDiv(num_input, threadsPerBlock), maxGridDim)); - dim3 threads(threadsPerBlock); + at::cuda::ATenCeilDiv(num_input, THREADS_PER_BLOCK), maxGridDim)); + dim3 threads(THREADS_PER_BLOCK); add_reduce_traceback_grad_kernel<<>>( grad_feats.data_ptr(), grad_reduced_feats.data_ptr(), @@ -101,8 +92,8 @@ void DynamicPointToVoxelBackwardCUDAKernelLauncher( grad_reduced_feats.scalar_type(), "max_reduce_traceback_scatter_idx_kernel", ([&] { dim3 blocks(std::min( - at::cuda::ATenCeilDiv(num_input, threadsPerBlock), maxGridDim)); - dim3 threads(threadsPerBlock); + at::cuda::ATenCeilDiv(num_input, THREADS_PER_BLOCK), maxGridDim)); + dim3 threads(THREADS_PER_BLOCK); max_reduce_traceback_scatter_idx_kernel<<>>( feats.data_ptr(), reduced_feats.data_ptr(), reduce_from.data_ptr(), coors_map.data_ptr(), @@ -114,9 +105,10 @@ void DynamicPointToVoxelBackwardCUDAKernelLauncher( AT_DISPATCH_FLOATING_TYPES( grad_reduced_feats.scalar_type(), "max_reduce_traceback_scatter_idx_kernel", ([&] { - dim3 blocks(std::min( - at::cuda::ATenCeilDiv(num_reduced, threadsPerBlock), maxGridDim)); - dim3 threads(threadsPerBlock); + dim3 blocks( + std::min(at::cuda::ATenCeilDiv(num_reduced, THREADS_PER_BLOCK), + maxGridDim)); + dim3 threads(THREADS_PER_BLOCK); max_reduce_scatter_grad_kernel<<>>( grad_feats.data_ptr(), grad_reduced_feats.data_ptr(), diff --git a/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu index a3edb8f71e..67852594aa 100644 --- a/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu @@ -11,7 +11,6 @@ int HardVoxelizeForwardCUDAKernelLauncher( const int max_voxels, const int NDim = 3) { // current version tooks about 0.04s for one frame on cpu // check device - CHECK_CUDA_INPUT(points); at::cuda::CUDAGuard device_guard(points.device()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -155,7 +154,6 @@ void DynamicVoxelizeForwardCUDAKernelLauncher( const int NDim = 3) { // current version tooks about 0.04s for one frame on cpu // check device - CHECK_CUDA_INPUT(points); at::cuda::CUDAGuard device_guard(points.device()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -177,9 +175,9 @@ void DynamicVoxelizeForwardCUDAKernelLauncher( const int grid_y = round((coors_y_max - coors_y_min) / voxel_y); const int grid_z = round((coors_z_max - coors_z_min) / voxel_z); - const int col_blocks = at::cuda::ATenCeilDiv(num_points, threadsPerBlock); + const int col_blocks = at::cuda::ATenCeilDiv(num_points, THREADS_PER_BLOCK); dim3 blocks(col_blocks); - dim3 threads(threadsPerBlock); + dim3 threads(THREADS_PER_BLOCK); AT_DISPATCH_ALL_TYPES(points.scalar_type(), "dynamic_voxelize_kernel", [&] { dynamic_voxelize_kernel<<>>( diff --git a/mmcv/ops/csrc/pytorch/scatter_points.cpp b/mmcv/ops/csrc/pytorch/scatter_points.cpp index 81aceb7ad4..2e407a13bd 100644 --- a/mmcv/ops/csrc/pytorch/scatter_points.cpp +++ b/mmcv/ops/csrc/pytorch/scatter_points.cpp @@ -24,9 +24,11 @@ void dynamic_point_to_voxel_backward_cuda( torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats, const torch::Tensor &feats, const torch::Tensor &reduced_feats, const 
diff --git a/mmcv/ops/csrc/pytorch/scatter_points.cpp b/mmcv/ops/csrc/pytorch/scatter_points.cpp
index 81aceb7ad4..2e407a13bd 100644
--- a/mmcv/ops/csrc/pytorch/scatter_points.cpp
+++ b/mmcv/ops/csrc/pytorch/scatter_points.cpp
@@ -24,9 +24,11 @@ void dynamic_point_to_voxel_backward_cuda(
     torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats,
     const torch::Tensor &feats, const torch::Tensor &reduced_feats,
     const torch::Tensor &coors_idx, const torch::Tensor &reduce_count,
-    const reduce_t reduce_type){DynamicPointToVoxelBackwardCUDAKernelLauncher(
-    grad_feats, grad_reduced_feats, feats, reduced_feats, coors_idx,
-    reduce_count, reduce_type)};
+    const reduce_t reduce_type) {
+  DynamicPointToVoxelBackwardCUDAKernelLauncher(grad_feats, grad_reduced_feats,
+                                                feats, reduced_feats, coors_idx,
+                                                reduce_count, reduce_type);
+};
 #endif

 std::vector<torch::Tensor> dynamic_point_to_voxel_forward_cpu(
@@ -50,6 +52,8 @@ inline std::vector<torch::Tensor> dynamic_point_to_voxel_forward(
     const std::string &reduce_type) {
   if (feats.device().is_cuda()) {
 #ifdef WITH_CUDA
+    CHECK_CUDA_INPUT(feats);
+    CHECK_CUDA_INPUT(coors);
     return dynamic_point_to_voxel_forward_cuda(
         feats, coors, convert_reduce_type(reduce_type));
 #else
@@ -68,6 +72,12 @@ inline void dynamic_point_to_voxel_backward(
     const std::string &reduce_type) {
   if (grad_feats.device().is_cuda()) {
 #ifdef WITH_CUDA
+    CHECK_CUDA_INPUT(grad_feats);
+    CHECK_CUDA_INPUT(grad_reduced_feats);
+    CHECK_CUDA_INPUT(feats);
+    CHECK_CUDA_INPUT(reduced_feats);
+    CHECK_CUDA_INPUT(coors_idx);
+    CHECK_CUDA_INPUT(reduce_count);
     dynamic_point_to_voxel_backward_cuda(grad_feats, grad_reduced_feats, feats,
                                          reduced_feats, coors_idx, reduce_count,
                                          convert_reduce_type(reduce_type));
diff --git a/mmcv/ops/csrc/pytorch/voxelization.cpp b/mmcv/ops/csrc/pytorch/voxelization.cpp
index 66d71d6de8..8799a0234b 100644
--- a/mmcv/ops/csrc/pytorch/voxelization.cpp
+++ b/mmcv/ops/csrc/pytorch/voxelization.cpp
@@ -44,9 +44,7 @@ int hard_voxelize_forward_cpu(const at::Tensor &points, at::Tensor &voxels,
 void dynamic_voxelize_forward_cpu(const at::Tensor &points, at::Tensor &coors,
                                   const std::vector<float> voxel_size,
                                   const std::vector<float> coors_range,
-                                  const int NDim = 3){
-
-};
+                                  const int NDim = 3);

 inline int hard_voxelize_forward(const at::Tensor &points, at::Tensor &voxels,
                                  at::Tensor &coors,
@@ -57,6 +55,8 @@ inline int hard_voxelize_forward(const at::Tensor &points, at::Tensor &voxels,
                                  const int NDim = 3) {
   if (points.device().is_cuda()) {
 #ifdef MMCV_WITH_CUDA
+    CHECK_CUDA_INPUT(points);
+
     return hard_voxelize_forward_cuda(
         points, voxels, coors, num_points_per_voxel, voxel_size, coors_range,
         max_points, max_voxels, NDim);
@@ -76,6 +76,8 @@ inline void dynamic_voxelize_forward(const at::Tensor &points,
                                      const int NDim = 3) {
   if (points.device().is_cuda()) {
 #ifdef MMCV_WITH_CUDA
+    CHECK_CUDA_INPUT(points);
+
     dynamic_voxelize_forward_cuda(points, coors, voxel_size, coors_range,
                                   NDim);
 #else
     AT_ERROR("dynamic_voxelize is not compiled with GPU support");
diff --git a/mmcv/ops/voxelization.py b/mmcv/ops/voxelize.py
similarity index 100%
rename from mmcv/ops/voxelization.py
rename to mmcv/ops/voxelize.py

From 29598d8b07d617a3307949bdb605af0d73aaddb2 Mon Sep 17 00:00:00 2001
From: "hudingchang.vendor"
Date: Tue, 28 Sep 2021 03:01:50 +0800
Subject: [PATCH 03/12] add ops (voxel) in mmdet3d

---
 mmcv/ops/csrc/pytorch/scatter_points.cpp | 19 ++++++++++---------
 mmcv/ops/csrc/pytorch/voxelization.cpp   | 22 ++++++++++------------
 2 files changed, 20 insertions(+), 21 deletions(-)

diff --git a/mmcv/ops/csrc/pytorch/scatter_points.cpp b/mmcv/ops/csrc/pytorch/scatter_points.cpp
index 2e407a13bd..c825d708e4 100644
--- a/mmcv/ops/csrc/pytorch/scatter_points.cpp
+++ b/mmcv/ops/csrc/pytorch/scatter_points.cpp
@@ -47,11 +47,11 @@ inline reduce_t convert_reduce_type(const std::string &reduce_type) {
   return reduce_t::SUM;
 }

-inline std::vector<torch::Tensor> dynamic_point_to_voxel_forward(
+std::vector<torch::Tensor> dynamic_point_to_voxel_forward(
     const torch::Tensor &feats, const torch::Tensor &coors,
     const std::string &reduce_type) {
   if (feats.device().is_cuda()) {
-#ifdef WITH_CUDA
+#ifdef MMCV_WITH_CUDA
     CHECK_CUDA_INPUT(feats);
     CHECK_CUDA_INPUT(coors);
     return dynamic_point_to_voxel_forward_cuda(
@@ -65,13 +65,15 @@ inline std::vector<torch::Tensor> dynamic_point_to_voxel_forward(
   }
 }

-inline void dynamic_point_to_voxel_backward(
-    torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats,
-    const torch::Tensor &feats, const torch::Tensor &reduced_feats,
-    const torch::Tensor &coors_idx, const torch::Tensor &reduce_count,
-    const std::string &reduce_type) {
+void dynamic_point_to_voxel_backward(torch::Tensor &grad_feats,
+                                     const torch::Tensor &grad_reduced_feats,
+                                     const torch::Tensor &feats,
+                                     const torch::Tensor &reduced_feats,
+                                     const torch::Tensor &coors_idx,
+                                     const torch::Tensor &reduce_count,
+                                     const std::string &reduce_type) {
   if (grad_feats.device().is_cuda()) {
-#ifdef WITH_CUDA
+#ifdef MMCV_WITH_CUDA
     CHECK_CUDA_INPUT(grad_feats);
     CHECK_CUDA_INPUT(grad_reduced_feats);
     CHECK_CUDA_INPUT(feats);
@@ -81,7 +83,6 @@ inline void dynamic_point_to_voxel_backward(
     dynamic_point_to_voxel_backward_cuda(grad_feats, grad_reduced_feats, feats,
                                          reduced_feats, coors_idx, reduce_count,
                                          convert_reduce_type(reduce_type));
-    return;
 #else
     AT_ERROR("dynamic_point_to_voxel is not compiled with GPU support");
 #endif
diff --git a/mmcv/ops/csrc/pytorch/voxelization.cpp b/mmcv/ops/csrc/pytorch/voxelization.cpp
index 8799a0234b..de0c451913 100644
--- a/mmcv/ops/csrc/pytorch/voxelization.cpp
+++ b/mmcv/ops/csrc/pytorch/voxelization.cpp
@@ -46,13 +46,12 @@ void dynamic_voxelize_forward_cpu(const at::Tensor &points, at::Tensor &coors,
                                   const std::vector<float> coors_range,
                                   const int NDim = 3);

-inline int hard_voxelize_forward(const at::Tensor &points, at::Tensor &voxels,
-                                 at::Tensor &coors,
-                                 at::Tensor &num_points_per_voxel,
-                                 const std::vector<float> voxel_size,
-                                 const std::vector<float> coors_range,
-                                 const int max_points, const int max_voxels,
-                                 const int NDim = 3) {
+int hard_voxelize_forward(const at::Tensor &points, at::Tensor &voxels,
+                          at::Tensor &coors, at::Tensor &num_points_per_voxel,
+                          const std::vector<float> voxel_size,
+                          const std::vector<float> coors_range,
+                          const int max_points, const int max_voxels,
+                          const int NDim = 3) {
   if (points.device().is_cuda()) {
 #ifdef MMCV_WITH_CUDA
     CHECK_CUDA_INPUT(points);
@@ -69,11 +68,10 @@ inline int hard_voxelize_forward(const at::Tensor &points, at::Tensor &voxels,
                              max_voxels, NDim);
 }

-inline void dynamic_voxelize_forward(const at::Tensor &points,
-                                     at::Tensor &coors,
-                                     const std::vector<float> voxel_size,
-                                     const std::vector<float> coors_range,
-                                     const int NDim = 3) {
+void dynamic_voxelize_forward(const at::Tensor &points, at::Tensor &coors,
+                              const std::vector<float> voxel_size,
+                              const std::vector<float> coors_range,
+                              const int NDim = 3) {
   if (points.device().is_cuda()) {
 #ifdef MMCV_WITH_CUDA
     CHECK_CUDA_INPUT(points);
From 809693d23ed776defea64ff31d3acc543095a42c Mon Sep 17 00:00:00 2001
From: "hudingchang.vendor"
Date: Sun, 3 Oct 2021 17:56:08 +0800
Subject: [PATCH 04/12] refactor code

---
 .../csrc/common/cuda/scatter_points_cuda_kernel.cuh | 12 ++++--------
 mmcv/ops/voxelize.py                                |  6 +++---
 2 files changed, 7 insertions(+), 11 deletions(-)

diff --git a/mmcv/ops/csrc/common/cuda/scatter_points_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/scatter_points_cuda_kernel.cuh
index df5b34c86b..ffac86244a 100644
--- a/mmcv/ops/csrc/common/cuda/scatter_points_cuda_kernel.cuh
+++ b/mmcv/ops/csrc/common/cuda/scatter_points_cuda_kernel.cuh
@@ -74,8 +74,7 @@ __global__ void feats_reduce_kernel(
     const T *feats, const int32_t *coors_map,
     T *reduced_feats,  // shall be 0 at initialization
     const int num_input, const int num_feats, const reduce_t reduce_type) {
-  for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < num_input;
-       x += gridDim.x * blockDim.x) {
+  CUDA_1D_KERNEL_LOOP(x, num_input) {
     int32_t reduce_to = coors_map[x];
     if (reduce_to == -1) continue;

@@ -98,8 +97,7 @@ __global__ void add_reduce_traceback_grad_kernel(
     T *grad_feats, const T *grad_reduced_feats, const int32_t *coors_map,
     const int32_t *reduce_count, const int num_input, const int num_feats,
     const reduce_t reduce_type) {
-  for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < num_input;
-       x += gridDim.x * blockDim.x) {
+  CUDA_1D_KERNEL_LOOP(x, num_input) {
     int32_t reduce_to = coors_map[x];
     if (reduce_to == -1) {
       continue;
@@ -127,8 +125,7 @@ template <typename T>
 __global__ void max_reduce_traceback_scatter_idx_kernel(
     const T *feats, const T *reduced_feats, int32_t *reduce_from,
     const int32_t *coors_map, const int num_input, const int num_feats) {
-  for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < num_input;
-       x += gridDim.x * blockDim.x) {
+  CUDA_1D_KERNEL_LOOP(x, num_input) {
     int32_t reduce_to = coors_map[x];
     const int input_offset = x * num_feats;

@@ -156,8 +153,7 @@ __global__ void max_reduce_scatter_grad_kernel(T *grad_feats,
                                                const int32_t *reduce_from,
                                                const int num_reduced,
                                                const int num_feats) {
-  for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < num_reduced;
-       x += gridDim.x * blockDim.x) {
+  CUDA_1D_KERNEL_LOOP(x, num_reduced) {
     const int reduced_offset = x * num_feats;
     const int32_t *scatter_to_offset = reduce_from + reduced_offset;
     const T *grad_reduced_feats_offset = grad_reduced_feats + reduced_offset;
diff --git a/mmcv/ops/voxelize.py b/mmcv/ops/voxelize.py
index f143d94142..500c861fc0 100644
--- a/mmcv/ops/voxelize.py
+++ b/mmcv/ops/voxelize.py
@@ -35,10 +35,10 @@ def forward(ctx,
             before call this function because max_voxels may drop points.

         Returns:
-            voxels: [M, max_points, ndim] float tensor. only contain points
+            voxels_out: [M, max_points, ndim] float tensor. only contain points
                 and returned when max_points != -1.
-            coordinates: [M, 3] int32 tensor, always returned.
-            num_points_per_voxel: [M] int32 tensor. Only returned when
+            coors_out: [M, 3] int32 tensor, always returned.
+            num_points_per_voxel_out: [M] int32 tensor. Only returned when
                 max_points != -1.
         """
         if max_points == -1 or max_voxels == -1:
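A minimal usage sketch of the voxelize API whose docstring the patch above rewrites. This sketch is not part of the patch series: the point counts, shapes and values are illustrative assumptions, and it presumes an mmcv build with the new ops compiled (at this point in the series the ops are CUDA-only; CPU support lands in a later patch):

    import torch
    from mmcv.ops import Voxelization

    voxel_size = [0.5, 0.5, 0.5]
    point_cloud_range = [0, -40, -3, 70.4, 40, 1]
    # x, y, z plus one extra feature such as reflectivity (assumed shape)
    points = torch.rand(1000, 4, device='cuda')

    # hard voxelization: at most max_num_points points are kept per voxel;
    # per the Returns section above: voxels, coordinates, per-voxel counts
    hard = Voxelization(voxel_size, point_cloud_range, max_num_points=35)
    voxels_out, coors_out, num_points_per_voxel_out = hard(points)

    # dynamic voxelization: max_num_points=-1, only coordinates come back
    dynamic = Voxelization(voxel_size, point_cloud_range, max_num_points=-1)
    coors_out = dynamic(points)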
""" if max_points == -1 or max_voxels == -1: From e19d87c082687eff226c5eb6982b3845f157ceec Mon Sep 17 00:00:00 2001 From: "hudingchang.vendor" Date: Sun, 3 Oct 2021 20:08:45 +0800 Subject: [PATCH 05/12] update test --- tests/data/for_3d_ops/test_voxel.npy | Bin 0 -> 1663049 bytes tests/test_ops/test_voxelization.py | 54 +++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) create mode 100644 tests/data/for_3d_ops/test_voxel.npy create mode 100644 tests/test_ops/test_voxelization.py diff --git a/tests/data/for_3d_ops/test_voxel.npy b/tests/data/for_3d_ops/test_voxel.npy new file mode 100644 index 0000000000000000000000000000000000000000..0ca96590dae57544c5b07de1189a0c837b5b5ab9 GIT binary patch literal 1663049 zcmeF)d6-qzy(fHvNJxc?NGd80pdwHNkr0B2P-hyyr8z12L1*RJLgu8m_9G)}ko-dMaO&->iu-k4m`56^Qv z>$mqhYpwk`CzZdx&u`_B?+w1>`(LRyqu$)MW3QTe<&?G;p4oQZ#cj_$vu)|5DbuE0 zG4Y3!ri{HRz1-)D2~)31u1_6z#pJ7!x~Sc`XU;$KFZ)-p;J7O%PQP~Y?DMXiH07%E zuAM$%+SOM~nR3PKf*($~>dHwIr%szP{mN+tW%UlvDwsI-@ZGYn46IjPR{x@}j43;A zaan_5Wyco{{7Sn|Cykvpd-7EUWerExuUD_$#TS%)b;(Z_mo*wz_O+sc$Bq0-a&2B& z!H^+C2K-C%FZ<8Hcv{)l$1M3lS>yR-e^)f}_+-kJlO|1>TGnJ-y>Z!2CyYzBX*#Uz z#AKU+KN|ec#|><-xUAW@!&{Vox9l6~7KOvgzFD_LkFt~gaEp4=%1$1$r00Ux&Fd|@ zZAIv|v0w4XZo}hbt9py_E5hQFb|0y`E?E&?7(6DbyT9+b6`|Rp1ySvs;qhxjtBS85 zzal(%$79{^?|pvu`KjYpgh|a771z~EC#(p&pLy>{J!H&^@X^C%QS%1p#}(TS_88lJ zMOgm&M_HZj?;imI1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBly zK!5-N0t5&UAV7cs0RjXF5FkK+009C72>izg%q^T8+AqE}x@1|`c-PKBQO7+~!sVNv zj255NEz6UaA1Di>#{NSzb8E}2{*O7OT?r5%K!5-N0t5&UAV7cs0RjXF5FkK+009C7 z2oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZ;IA!k>qV2o+``JJ zxS(rXJEt&iS28IK>bfKvR@EXIpAd~`GAUH{y)SxlMvJ&#?FZ3c`#jbsK!5-N0t5&U zAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C7 z2oNAZfB=F27=aJ2TaolT|5`k<+wiQvd1>j2@Y>K*;%DExG^-z&ydtE%&yTK`-FDuQ z`CF!~2&Z){K2lE|w<2_EH0Vg3uDA7`QAg?>16L&Z_N$Mq_rl;6VNBCmN9uIFnOm-dP^%)yv2C_dY)^s96)u+`2pz7rd9% zm3_~WS^wOA(+);}009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C7 z2oNAZfB*pk1PBlyK!5-N0t5&UAV7csfqz}#rCC1>_n!J}JY>uzS$4eX$e)JWCvJ-$ zdHXwAz5l~yVNlm6 zWyLo|>rS5%o}95YUJ}lXE4EFKj@9RKZUO`d5FkK+009C72oNAZfB*pk1PBlyK!5-N z0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&U__qnPdbuo&-||?z z_@r)Wzw@}^!m`l1VRbS-GcK&$5k31}S-51`jx3M<+RzK+(f@6{W)mPlfB*pk1PBly zK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF z5FkL{m=P$fTp5zX=I-x3BCefN7@xd*d02OP{oKfI!_z!^z3#94EG&P$L9TP>OSAfc z$ty#*jrDT7pXr*_-8TL#3}{?G*KlE%tRB$#XQ8Ua@wswlW`KmTg3 z&!)~KNTk@K!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk z1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZ;0rGB^nHs%ucd{#%QyFrhgIDj z^;-Jlu*Sk##3^g8l9WfTkp9s4C>lEck4y% zv;OK|a9)o`fB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBly zK!5-N0t5&UAV7cs0RjXF5FkK+0D)shVCjTKq37r`ay>`){e{mlo-m;WK3iU0uu1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZ zfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009DD7=eYiTpw!ZoRjOiB*e9IZjLtG zH9xdp+$Ps-Q7&#*QW1Uh@Z7L|N>T2Od2yCMU;Fs=VPv;+a!V()jmr-_AI;o4KRkHH zIl0o(Hd(*)P4CPNZ#~o|_w;?OIMfbN_R5&!62ru4-|7{MJL$!jLiF z&K)Z15f{|7iW;q(5}GagcCKITg>lRBYod~hDPiF)owNG2p;ghOW>Z7`_FZ$M#&(Yn z75ye!l%Em?G`=ufFYUMf}0<+(F2obhF31`kW1#qNsspEgX^Y+Cuf|O^lMM=w~88!m=Zd7?wqTB>a65E2clj} zr-bFNfA`b+n&|e4Q$oFZ7i4uw#kQzs&6M!|rQLG_8lRoy=j+AOS58URi?TY+hUd&~A^M17-L{%-Oh8G6+%<9#vRwa4wsiDJ$ z-dR4qUcCiT(~;@?ep&t6&$nxn` zEy|tPV?#m8cz!w?)qVtpO;^+`;~boK!5-N0t5&U zAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C7 z2oNAZ;8+m2cK_5cb=(!X<*)Zh`kR+WuMM4+^dlde<-OaLR3v%sX-PhNT&}~0;;g@U zqjl3lLCrN;z5U?YXyX&p!r1PUa+}^M&ibdP^(QY*&g!&3dB56eVPv;yS-pBy^JvSo 
z=}A6(Rxar=p5)CN$DKP*50jeB&FZutdef29L;J^Y}2%Xg=B+OPcXH-D7n!^?^b;`Dr{b-XsW@|1J3^M3N;v@m_;B?(^s#Wp5*0k z%=H}II_XD#K{Rvg^klvAtWNr;N4uYy9xD4T$vu5v>*PF@(fz$=gwCCp=7#TWo#fwl zCjH81gz;PcKC9FHH}{_rN-BPu+jCBvq(Ay)(WtRglRW%SKh2x}IilteAV7cs0RjXF z5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk z1PBlyK;YOE7*;hSG+OtwEbm=VvnJ}eXGXYv;)>jzkL8kn<|jn08_rDnm*1Rg-7sYJ z%fFuyhVQ*4SK0S)ed|Rt!`jDh%XM86vV8Zwr_KyLN3Y82G_SsN!pvm7)w$}YLR{JR ziX_iHGyLkzJF`0NU*2ca%rKzw-MJSA_ssgQ@BR7A&}Y*(8y+-S1BzgB0xwjrVT#w%}BXrw%b5^H$^Rw2^2wg6}H8=a!HgVd|{E@e3BbaDB$8S)s#*hjR;WxhU%o{@T!4Vd;eboYiUn@F8Png+X2aORk{i@cl*k zSz&DV|C(EQO0T$f&dt%7rn5rJ@`rNkryS0s?>T2ySU+WLuG_|@>yZ!DeH3E z4;CeP_nM^N`K+YB`hUp{>e?>ewR2~2x?cBJ{%h`~S?%JqfBAoekv$0zAV7cs0RjXF z5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk z1PBly@MRDfzhzd^-+X=Uj8VOk{Q1G+`t4^YdGrmr2k+>W=7#T0@2`pqYG#M~dvDF{Ij2vy|JHkEg?mqZ^vL{SRkOnC zRsSu^pBEQg7IoY+D{Ptei`?a#(|r1%C_V4=l^b*WKWvxf*-z^@J9N2xQ&xBG{EMjP z=-FX&|EgSB@wbwG>{mpi#?DUqwQtUSaNW1!?FZLJO-IfSyPx@otWNv2r}Lk_@0Ynn z`EOD=c2=Vtq-=bInz%Jo`$Zg&3ke9xc#Y;M=i zbCZ01Rq>beSKz4$5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+ z009C72oNAZfB*pk1PBlyK!5-N0t5&UAW*N~*zR*fw~f!`x^3(i7Z=n-?|wWd=^y^< ztWNXc#RYT2lJH_~`@w!$|MO8}=O+2^e@yz3_fPtd?~Dp7=Y|b;)g=AM`)B>dw@jNG zD*Nuv>a^eav+t$%Urzdu56JqdmmioL9+n|R6nU~yuCFwUlAnQk-=EJui zd@jp_C;iup?>%*HnAGfr+=wO}vOM|xcy4Ij@TJ@<7pL{~=(LV=!^ykIytt&I zXVmKDx#8OVujH29c3zgp_ZuNVfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C7 z2oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+z?Vg!Q=@sI>B!oo|M}S4CUznHS!A=(kC~@qt-hdr^K~l3)Mlq~G|U zEDyf#xq0EN_5Ua7H$Eu4|MKtWg*zX6E$KHtC|j@No_S%}ZEqy`@^t>e;<<%s|MLHp z<-=3+yWgA}W^Ub^J9+ncS^x9eIrBo34!_N{E9sck_xGL`X21IX)o(pCFAVDXYOeb$ zhwrE78QcB9k-E{kd13KMujhVs=C|XtKl=DB^O8LKo4FxlzMZZ2W%-^xCjkNk2oNAZ zfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&U zAV7cs0Rja6lt9nX^TUWHZzui42WS1sADBEp%q{#~(m#A~oc2H8aM%1WvfH~!K74SN zFW)k4erVp{P|`nqNLF{)kluek=^s8M%bQm}H9x#C_=BW>_>iJ6 zAzAdGd?nVO4iWof=&iT9*GyR#*0|Ek1eoyinER?Oe^8!};*#ug?qh z+rN_=z3DDsaZAv0>EM{keMe za!G!CTeSbfE5j!*{&Q9r7t}9ln3e19<=f1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C7 z2oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72z*Hd=Epw@mn?fHx9_>0 z@u8yML#kt$0wQeA0_$o_j3zxxhSsKHa*(-#F%jJsqf_~`-UW+ zUJ;!!sx--}f0W%XspuIEXgntD+WAqg>ynW5ORs+FiX@-@Nv`^-5VtHp5KSF-Rgz^t zl)L-QwsCR6Wzp_ut_+i!y_Z`$p>14#;Q6Suv^31z`eCly#-g~ea!0iFp0S~@^21!) zPDOE9@l8>e%g2WK@yA(Rv*xX+xZuk0(Ze6-3M<>jt5-G8&iBILPqO+;@)dea0t5&U zAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C7 z2oNAZfB*pk1PFX40%gVHLjCq1=FS+^E6Imn6OHUPKFN3gKDYnFc5&^T!uT`U#3BR; z5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk z1PBlyK!5-N0t5&UAn>^tsA{nwRBZdZ@a%h+CjHD$NcxvA2yvGtS)KGhk3RPg;D7`O z5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk z1PBlyK!5-N0t5&UAn=(8l$I_CYac%`>3ROW__d)`(e>jOgnqTnLc@jM%kt6xOw$Xp{qB5h zL3sbtlfx%3UJzIItu6jc9?~KN2oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+ z009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5Fqd;1s=R(Y?##S`0(Hzz2fZ$ z*GAJZguNGUaXK@&nK7hyEu;O&~yk009C72oNAZfB*pk z1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7e? zu_>_cxdkEaa!S&l{IaA!`Jia`GYgX5<;|1+HWs>$-5VX zW{XY8A=Ar4xZt>j0^P;(h3&RV8TZFCmT$uGUKQ>jIng9U;1PBly zK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF z5FkK+009DDK7oFh58abe{S{nEdD zKPgX4fB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N z0t5&UAV7cs0RjXF5FkK+z%eSY_@srQTb3^j?H8XN`qhrg^5{?B zw=n62-YSe5`=8SLQJ+lAENahg{jR<$r}xT|eAtz&VzUhy%? 
z=-31Z5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZ zfB*pk1PBlyK!5-N0t5&UIOYWwpL9dGePU6P|NcRoKFpY=H-t-;wF?hS{y|)^?O^ft zgA2p5+rAY#Z1_Q3*|)ZM!(9u*j)Cp7I-Ng1z9GDFafdLh>hdga{>Ho;!qN#H!_2Ke zj1Lw4CVJ%Ug`sv%Q8K?c$)B%D^5r*#KAYNwP4678t6JO;iVNC@xJwkzEj%x3+v$ez z!r*h0^`d0IE26l|4WZ$}4&mkBAFh{9xFH#zmt@-Gr2l$P{ngVi!9M~72oNAZfB*pk z1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs z0RjXFd@M6;^KnK zqUxt^2yZ>~-OzQ(;kx6V8^W{ibqNC+kIC}jkGy?Dl27j%N-D<04@~}LwEXoOl3wYD z>(#58M-NQCA++uE?IfQb$J-CCjV@VsL(&U9-9Hzn_uqQxhH!uHbpKqof4bfagU=7u zPldR&^tai14HtGz_UkEs{uzA#mLNcY009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs z0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7dXoxq^3i;^sPxA6X@S7v$d z`SGH#@3{*@K}}k(DIPLrQCJdsgfm8s&HA4&ovPo#P85K2o+lYZo@qP34N3hfv7OnRP|X8G_{ojx>-Uap=iD6i zt6dbDbm$d^@9mxC*+1K-u^Isa1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZ zfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72>g2mIyJg6wC&U@ES>P9EFXUF zsW*mxwSB_+DdXbeg3F>w&29`cxAso*=zX$$dh-T1hV2LYBzg2cSwHiC@4h}qfB*pk z1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs z0RjXF5FkK+009D@FM<6ZE(-PP^$XoLj!XKP?=1d&oz7_p5FkK+009C72oNAZfB*pk z1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAVA=6 zCD5|`#;|MWfUy6=tK-$Hnn&Z7-x%U9gTe>bjgJdzT18!#+!)@NH#jt|e@)i^{H=#> zOvV?7mgU#PuMMq=7N2xu=yG|#P``cOtRMPoLvIY78Vv}Am51vmXWSTOzd9i4vEDDc zpRRxT=D|r1^?uoUTkp9syxQjC@Y1aQal4WeqQBL1+lK%F0t5&UAV7cs0RjXF5FkK+ z009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1pZK9 z*=;w5CE=1JGd>~9n_s*C#<22~VM!)^LXruH<;cD z>|0yhbYyv0`}k#HbN@lvdH?WfR0IeRAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBly zK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBly@b48kV^n$4&;0vg!(Ef& zT{{OwA3a>2WXVgy=zWvp($e2X1vTa2{Y!rkrj9Gi?%(}pd2;Xa(4@nZte<)M`O`Xn zKlEBUDD8({eEY=m&~4+WFrvxeBp<$}xMg{Hl5Z~wX&A)2oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+ z009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5Fqe*6&Sa?dua9Y30eNUwDh-8 z`^7!N-k+aPU9qi8TvmKjbXv#mq0gqjtFE2%-MFNpe!Sp z4jaB+eW>X4?EIViUl?8(Tu?n~>}grOD1TwNeDl|`dTwE5)NAR5q0zcV)z6=OT70PJ zH}d$O*B`QT5ga*5=C+=7KL9`_F2!pzQwfd2_ zzZ0i%*#ggzBfhlI=IF>h37LU)lF?o#yBN;xE7g1PBlyK!5-N0t5&UAV7cs z0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oU(3 z2&DP%b|wFk@BT{1tpE3`ZMui%4L-^0v|stH7j+Nor+k#(a977HkDm4uU;Frn+5PeZ zC&rEI7l)_s`+YvmkEi+U_b=@pmfdzJfAa40;s+-GGAgOKFkHL;o&4Ow^Rj&UvfI*r zw>~}C)c3Y3^ zyy^V3|N7_TYjX?&1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBly zK!5-N0t5&UAV7cs0RjXF5FkK+009C72z=%OyLRS6+fMbe{^zxGZjQ!p>6!HZ{z~Q{3{oC zNcx?BHSSm2Bk526de)D8`@ywQ`^7!N-k%@H<|hvnfAV75-+X_**V1#7y!y>i{r2g4 zzs>5jzxlQMyNCTB?#u7md2T$nurg}Wp-1Sp@t?B$hl+j^^;+5^v@GA7)w_1?jD}V9 z2(w@P-&vjRm+s$iVQp3yR_;j7R~*h5^}n*ZwDh;pBX3_A+An@1zwwEa;^Kmu=-yMi zhY?L)&+7Dkdfq`@U(fFtcyhM?n5I3#%fJ6W+5L3Av>*FxL;pG7yunG?{mz}chb`0g z=kM=*Qq~Xtto0WrpWm0S*!ImhJsz zta&RcKTsStKJkxP{o2r~BwrpSdGVKy+@HBM4l7T2G2f}txoI!;xMExPZ2kH1H{+@n z<=# zeb42>fX4rC*3Uf6d%yKi&oFx5|IE)VJTJ+LPmk_=tY`S$uV2dUr}^|{xAhDi_q>=N zw><5Sz9Y$Z_Y7O_`E{00PxI%`z88}I-p?P&v)8QY6J~CGCBJ%A^CU06Ci>{%5SlIe zpZPS)UR-cl)No<%(7E%=`A=RvB~JT=@7mcjEW7O=vwr1iKlQnVxuk#k^I8A!gDEfzuFJ< z5C1EeTao|)0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs z0RjXF5FkK+009C72oNAZfB*pkpGSc^AIpU%9iGXreY|~Kv*xX6bN`;9(Yl@a&Hc}f zODeWS>!r|6&J)~yI&OEdg$@&e#`P}qQ><@=vTWf|G{Ez%JSf8-n>bND6C$! zD_=XOFy4M}ZM5vRDAdl`nbm2Yd+pU`(U--sWW{L3USUL01Q zlFvW#cC)Phn>>U~2oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZ zfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkL{FBWKAKMt*4{$+mq!J@3+`LL=eeDdO! 
ze7B87amBXj(SvtHVQlw*$giJL6qlC%Hrn@G6gIu{Ke9UQhd!c79J+0+%Ib=32cvPz z<4{(-Dc`nJQP%%__})0YHgscFr}^~_cg11rJsa|Ozu7kHx88nn6q+sClI7LY^p>36<9`+QmPP0?QrXaNEQ2oNAZfB*pk z1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs z0RjXFd_e@3-ImMpq3cNAG@|3+qnbl;zRWJb1s_TzFydh9mj&&HZypo_u|l zA1|m`6TS6NE>yMnWxmg*Z^Wxtt%{mA$R+vn%~?P6mgU#TcYi^CCmw?U0RjXF5FkK+ z009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBly zK!5;&FTTL0chdgimHEQTwpqXMlXpj9TB^BGEN8XOYqWmNIw;pN}w<|dz zUX+hQ^9Ge!zwwsk2cpWpaTw9$zhrgi&cBG#{^BFMtxLn0Ke#r#->#%0+V@;8v@3ZiyI-;GV3IcvNq+so?0%Y8 zU;B6nuU!1|BlYmTAw2uunygOyqd$Lk2wU&DKg+idt7;M7dQm$6zN}8ypIeyrS6`Pe zKX77}f3H_B7dAfeNLC*z`c2fSQ7-fx{cx6dpIcZNH5ic#0~$Y=<=ab3e;bt*_Y9xB z_&~l%hh}j)zpyfd*{|N8%}?|1BfIqs11%L^;-$@-fY7t}-#-jNIUo_a@q_}oH;Z^yk?`xg)NAFjg3xm3@ z%I>H0M>GlH?l*7C>a^c^dcU;v)@**-AN_E?{FWp2{trW_opW-a5^TFP#vF zF-`Bx>NGFD`v@|5g=?K*-?{TQ;xwQB$%_NBe&!>)6~^hKoV;PwgL`Kf5B!n-|oyieGIrFx)=z zr}_OKwu@VqABdK}J|KMZ;@@Za^NMZLqh+@Z2%S4G&A&9OUEI0zFQSuo4+!N4mSlDL zffM6aFAq%e>PxcEr~TBYuN)MXgrDTgioccQ*PF+ceFukIFIt@C-P8W)?_WANyxOKb zyPxbIjo&gTv@2PZeLn4{{%V`SA?|WRR;T-2KYma+W7L9dze7c*$HS@ygkDSkKHq24 zH{!y|9nr3x1H!fYmu7Xll8PixKOmG9-<01n?Hh4faY6k2*#ndQ>`SuywR3Kc);>Nk zZ0`S){IlrjA>ZZ@BQAS$_V{WHgNc0RjXF5FkK+009C72oNAZfB*pk z1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0>`|-OS1-t z`t7gJ^4)2F^AD~Y6lT9VBR{BXyCkoDQ}o8Xf#H$2r{vd9DT)sj{U(|^ZeSSGbaK9J zr=mE`kJqa=Ang75HTk6z+Q$8AKZu$w8W3KZb#=b5vTfEcy`W}bl0P4p%}?{>wQ~j~ zdGgY1y-1m0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF z5FkK+009C72oNAZfB*pk1PBlyK!5-N0>_NNtrrapo8I|Be(8j^al4We;&(qD5ZZPc znXlizZT#BMs_6dSgTlrqM&$2)vu)h6{6O@D|NPVU zwa)5e=D{75009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZ zfB*pk1PBlyK!5-N0t5&UAV7cs0Rja6K%jp6ABDKf#Qer5&X*tm0~rMY0t5&UAV7cs z0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZ zfB*pk1ioAXO-HT_72B5OSFiehTwHKjG-~Y1(EXLA`KBYkAD0yu#J5je8D_tFQ&tz$ zw2Ho5I-ZsQ0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+ z009C72oNAZfB*pk1PBlyK!5-N0$(11l8OOg{FWa1J?FHEYvv;RYwb96K1Hu`jzMI{zS@TwO{rCZ){6MFy-nDaQR9M+R zEW7Qzd{v9qaruGgqX#DU4m9UmXf+y8l7_$%_{psprR6hu{6WOTKx7!+G_oSh9I{)^G&9eIy+YT0|_0kC^ zWOYGJt7zBG@u7KxCRshV@Vw~c-Qz>!`d`oL^z*k*93RGSX^^dVsOUFQy8h}_U&%LI z*d;D5sENLa=AMZF0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF z5FkK+009C72oNAZfB*pk1PBlyK!5-N0-r~LOO{<7dM*9v(RV-oZoF$}L$7ncF2d$f6j^RxW)Ozy!wpC3FEzf0CqRGz0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs z0RjXF5FkK+009C72oNAZfB*pk1PBlyK!CvKNMK1A5MFJw@X_}#Z53B+J6QZVI+~*p zAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C7 z2oNAZfB*pk1PBlyK!CvCP+<2ngTwXX=RI20qIK5){BO9mtq2exK!5-N0t5&UAV7cs z0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZ z;BzFf{osT!bL$TuZQkJgc-PLIQO%n1;iHF#K01EOcjI;?C&XK(T^rhV>i6jVz0Xhj zqfd{1b>_ry*7|0TwqJaHmRCPiG%?J6wc(NabM$DALVy4P0t5&UAV7cs0RjXF5FkK+ z009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk{|14= z$`#>_c|Ux#cFyo5e?BOBa>j~Kv*waVM|K+?uU^$W-v8msP*8KxqZQk}r$73?!PQg( z1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs z0RjXF5FkK+009C72pnqya| zydrFw_KhvizIUnq=AR9zRR|CuK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk z1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNA}ObW~`To(=%-5HN)GFblnnEVDE zng9U;1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&U zAV7cs0RjXF5FkK+009D@jlk+v4~EVC*Tro+^^-sUY)Gv_fB*pk1PBlyK!5-N0t5&U zAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+0D)st zV8_7g!wZ8Oh4EXy8^1R6g6Noh2!|#>fB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+ z009C72oNAZfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+0D-?)p!+L73{%H_HLRaf z6j!w~?FWnEb|oi7Tkk0eg_U0q 
zZ95gkB^CAK<_#_nU6=e_7}T|0T(Rw7@v-@oPECLS0RjXF5FkK+009C72oNAZfB*pk z1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N0)I+iQnM>V zr$(oTWw)J|<H5*W9 z<&Nm`%>%=hY3-`_f7mXYzy08#FlubO>ayZ*#r5hfh{}qGg!_9JRex~Zx8k{l=gG5w z;eIh5hyVcs1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZfB*pk1PBlyK!5-N z0t5&UAV7cs0RjXF5FkK+0D&*6z@ef(VcqFxR`(p;I?I=TQJ>Ru5g%huAyOm+_w7Gi`vI6%dd&XEx$C>tZ7rd_VM;vfAc0CMugR?+E$Nf(ji{G zs(I9H+uY1UN}rk0&LuHLw>46IjPcG|e(7MGnqtn7@UfnOQ;qk+f$xU9vvh9m1IGq3(` zS<59qSzLDJu(Gp?1|B!=q#;9w4EUGiU-qAY^`@1bJ!VPI7aP>8_tj+fCCTfY@KE0lPwr>6Odfj+yJpT)r~JZYCQ6&@%N7!kNTr=-J1U?xn4IeIcofa zqsEsXHU8mIkckenm1qGkK-+#~d{tos8?|UwPE^bbj6S zu}59M>Zoz)QR5#SH6Hgz*J3aUvt#BdopgGywcD1NM5IWIxbG$ zPyTcq9W{=R8t0B0r`PK~8;-i(^N+?y< zTPM9<_l$jy8uvSDoL;Z{F4BGLe%t+znv>42o0HD3n=|03IqCem-$Q!6Zk%4Pn?LZV z>w}IOr`PM|q}N*~uk?EvoV@DRNyl~1GUSiOb#v0|b?aPw)MwLi-SzZ;*X@&6x_&xt z_vtwOO}6}WJUV$l?bC63y>6WThU*+UCgZwuUYLyQzKe8Rch2-TR(CxepPIbVebYJV zxUAK*vUA3aEKG(Ir(gTS$&;?0IBn_=Cto$?huL4yvetEfK-;8$K-&&0E6V}h zx_jw+dadr9>ArPmNY_pE(mm48r*(RcbnWyU>3h0nI!=G_X`P-qJ!g8Bbf5H z_ekg0?VrA<^Xjgp_0i)xzx4Y5-z#13=-2q_{`%5Hu!Pbx5s_+Cq3(Zr(V5fx2*_^@+(4z?+otQb;*jb@3|FWZsFu` z>qV17$30WR2iL6#rKKyv1CtNGwoF?Qrj9#&ZO6bB;f28~LV7LkvLZZx_S!Ig<(hD) z=>G8T$IHXLr~WLwH0!6y=cj~LFPA0zEDwd1E5pgVmxu1J{4A`W^3!ly$7NyDJ4-{K zO*e(N9$FH{E&oY)`o6`X*U}#+`;>>0iW|ey35$~JH-wrs3&YlX7KDYjTpu>vH9vgx z@Z4mdX`xAnY2lTNrzW4D78Rgg8K#D&Bc~?o z9)2I&eQId9aB6afWk)w*Rx$ec`Bh&2_i#@kdJCTzG2WT8VAI%K)6>u0->N3*~!+m61Av*D4A4cYFnnrv{)cJFPF;~xJu8(ee1B?p^w z+%XwB*q!5^dMyW=a@^IQ=eS>enB%VPmkVaO?%iW@;gjnQnUISuxo+l;TsIZN1ofeVICOgx!VWkVMiV|=D8_#dFU?O^4z4z ze00di{(Nl9cMo37cgsG>hqeIP0(W#=0X7uCvjF)8&=+7s0rnKQ$>ZjO=X~pj9dqZq z-M7ql!+YnuBTmi79}A#gfc_conafXtH$xK^(1tfz?q5orqE4o zq%i>==ok6LU;YD4B{#S`GxMW1BK8RVn+skd=hNNx#ca2-1cQf#7z-+ z7NJ8CF;>KSu{)tjF*p^WXA$y?-Omb&v8kAM#Zxizim|8IedEhwx8hu}n-*Kbxf1x6 z*tq*{L5cgq^b&L}aW}kMf?cJ=YY9Fq#kLaUmJ*w#ZtS8`;;+ZSDh1mz>?(6B zTa^);W!PDUkIL{-8GbIq=T+$8kIl2IkWqzStFXNa9sPO9@+xq7V`DNlB)ji)NG6^iLO++hk)96nLc~l6{vqsg;A6nY zA?6LXI{4l}9|JFgFCAhfg#SZsU7jHZ4Y}%C-G(l4@wp3)i@q*+xOPwUqFcYpV4tzP zMmBcstTSgWyTpe}3?(C@uN${DmAD;%k5kwi$lgHWdx(3s%OGMd75!4tAr+quBF@wB z@nEnXOwOcnW-w>dh>gME@)Q_Xk{4;xqd9uzePpOW2=HU7khk&cbgM z;8{r={1&@@OU^#))*rb9JfGoxG59RP)?X8I<=Fo#`c4@&w-g&ncrPXImF5wYQY?0 zXD+?&bjc#Fv+#E|JhF+q zY~pGTI%H!*Hu0N{>}+x?+xEUg9dqE3gMa+7a)saTjvdIMPUnJguG?p9F7$rC8&;D` z&ig&DP0xAAod^GUVC+8!JuheHJp7vnk34Ep9=#*Kkt-b`%i51>|Z0KJj~==1#}QGv;G=2J$8lQw#9h z0x(}dy;?y3%mnj=dcCe<$F_c<@ie zuZiFmPb~Spc0v=a#goSg#JR7FbolDS1ac{n8j|EL-IYM^PjttfPJ~|~F{*V3{{3xF zzp~M#PV_@|GX6E3cWfVAvOGjS81#41SAEFA#=i8vzQj@z>q$Hte3A@K>YE0?_2u1A zmmKVH&?$r-A=?jA8-=jl0hH#-69}S)mYF1JhiY|XXl;6BMG}>BWS>D5G(=*BG z-7wAB`SKViEq1a)-C%E_Gw5aW=Kx>%4N;LFf3lJDtoO zZ#(my^ZNb7Y-$@}D$dn6?{sKrzFQD!-rwKIymHs=CU#L%6M6Iw(|2)mvvp4k^Va!R zW=uvKlRWNjgM3r=$vx)a#ZIO!ud|6R?_xI3?gm{~Gj@BF**@@>X2^tIrddfZQ&bgW zBBEnWYNL4b&G(7s>S1F(8k}spZ|TQ#im7asW-<~UW^I^Ry6aJMw##GY%;m?-d)Z^n zoJr%M$uLLPO+uF`oSOm7Ofx!euEAC_qgg((7nq(W3eB*ZVl#9}2{dKq__x0@FTVY2 z^TPDspvN=ly2PZ^J!=|n`YmhApnuMcI8|jfyt^EGR_wOTCVlAv6SwsP(`W35W|4c?Oq}>9v*o3u zrh33J=HuqonJ1V}nkJ*`kbMf8znU*z{}`ITnR_<>H#UB1mMs6wjBNa+$#4D@c70=7 z*T2Ag2|T~U_A8vbYTB3m9p9Y+yT6&5b~h4_^@;PZO;+VM$i0MrE~D#3Gw$>id~h3e zvN7>_1--6d;8C7jgkxfE4yVQ54Bbu8+{n3zaPNi@;e!`%VE=}2 z-^KO9QR#p8TGzkoZJvF_+wksXuWRpm;eBhq^_DEZ;7vR8wfE8Bue`C_zx2NO{tNG( z&7XPANk zd0yi+uYTDpZ}AdqtGrk4T7@kuytLRA-s;bndlM%<&zUNZ{~j;T%fO(*8&gLqYi;)}Cu$eD^}Gw+uGG>)YHQlfC&uO? 
z>-V+z2AU1E=}XtvHr(`L?b$9xwF@%~YD>J|)PD1QX>H@~#kDy*pQydE@sZlDy;EvG zD|n*z=(?e`kw+h@y=QZ?+IeRi)Gpn1erM#-k7{?!y<^AvRc|NEXjYPR?7+pO)t{dU zUD-G*loq=vR6Sr#XiUZ*Llx()gibzun=|c98>htkg##`Qeb z>RjmdqSNx&YG?bvP4qU;d13l4`r>=ed)YgkdmDtEp-XlNVfrLH%p9=D*@wxXW5riQkmHr{QTSNpvWjLrI0ovFuN=sjKN ziCqmfjlR{x#KgqVQ}3s5JV-xCU@c_u9lfV@^tGYL8A1IWWlFp;hB%?# zKVf#hJi(mkJkhNF{Atss=QQf{OzQ9)dQg_Rx4}I6YXR~W(o2faw*2YQBjYY^_ zY%+H|gFVmE8=s}0R?rWZVaIdS^XKUeFCb?H^sA8nB4<{!_6m5cHPgEb`S-Il(X2;w=pvzYHY(s|{XtvYS{s_&R=G1Go^d}F@c5!EToBsE% z>Dqe_XZF(n_L=$4`{q!`gWMxNpl^SK9Y;*dV}C-|KO^H9wjRe1C(z*}bbmp=Q{eQm znRoVM@IQ;4pP2F%=b-%@zkLa2U(-h~K!1t*$#>XU$JwjMI72V|n7$UlH`0yV+3NFc zeG~WFZ|R8*=^+>3aS{HP@#SUo`rcgG_&s}9>8DNj?rF;1_zv#Q&FL+7($j9FPqgG4 zpcUUpt@(br86FL}Z?~h5M!NIPHlp9$!*}GZ+!Jr5XWd4x>B3#TE8j2O_@;{DUKq{! zUvjtY&bMC=`c^N#pL+3absv4TEqf2J7fVli(Dppt=hK%a@|~MRuXE^~2DvV>lIgR3 z>AC&*M(B@iDbS~~uJ4Ay+;t!3PCJzQ>@fP=aK29-#hypGUykBi?=kL|W9W(L^ybI; zZg_$_X(9bCihdFm&e<6ij$ITL-Z3{SJSHP5{OZiE;kEs`g}ZO*6yDUf18eQ+ZC%6q z@0=f9*Co7iMb~h}xh~;xr@PRvZl^ysrVlozx7`*dkD$9X{L$b@_}>z4xTzui^5*d9 zxCZpZo5G8y-b8=AF+8K$jjTo3=cLFRsfRbvCvSk?-@Wy#u6RXNm%Ze1-+M`sH-v}P zd`CS$PapWo+dl9c+xMc&KchCDqvoIUj(_{#^pL+%1JCe&+Uwo$uhh9y^tL*B=U?bs zb+*^N`1XHNKaW#)kI_Gl+FAWDXXi!tJkny^g+GssGt)tGrN^e+8FYx@#*5HUAm(T8w_bfyW~7 z{55rb5wb@(8dncrvU6SD2x?|5H8a*JZxLtL4qhDYXar}2`N!*e-tg`ezC$KD$>W}~ z=Noq)<_wuI)M+w0?i#PP{YF|oF)_pWRvBvRd4BUD)bljx2T|)&ow3^oI-6&wIE$wa za3)Ud55IoS-DCPXMO6ti(8<6*%;V1o}XNbEsniy&%!y_gg#0WcX`Z z*S?0%YUo@r7G@^$-Q>UP+p6MRqW#7q4;_tdUtb^8boxrjpE{M(R&7^{bL;dfg* z7P-dJkd7MaTl}W#R1a{i?$V#$U zatB(xh54+cREOV6tq&SS408Al)oEWg$ofU!ocylp^j(~0eMK*GCRGe}zWQ*mBmdPO zIfU=iA@t%Q)~D*zd}`VuKaV>-)Y>LnWt03VJoFsIQGm;Q=TW|2A9Y&SAL%GwWOL6G zeTmCa&anfdoZW*)SzBehaM{#$3_c&@jLG;H{QED?!pwAhnvNeIcV3 z=LvKghhN6wlkxatyu;tS5Lf=VH2g{MeiHpMkeOj)dEc6+h{K7_7q3qQ%csHeX>gnb zhLfB#mnVZ+rqd{53Rq3y+kC3Mo1A=j8rVz&m+5?i&v3T9Gy@D~I&&t?avFD^ji1Wk zUFOiokyFb5L!y+tCX^ryW~f z@@LU#z2&hU___ynb;qtAcIcTq(6wj%|1dP(-|x@CdoXLA-wGV=x0cV*b-lnX`q~ix z-VJ|g_a;?zx93E&;oWF({)MF#z0PQeRt3U5* z_cb@K?n> zES;sV-jC$^L+`VaeqqmR*1LxGHLETwwq=j(lP%IsHt0E63)W?eY?N(V_dI}Y7p>ka zS3E{d_Q#3NU4CMV^si~>59z+P-vd7xiZ!ii4n|VsC?~7(LHi!;DHenMVAh)ARq+|D z1#>{B+*f{UJ*#pLHtqR`v)a2fypz>a&w-A5&fIaoeODh-pW^o=C%XKBYXj`Kn-0H5 zBifeyuqWDKHSMVH?Whs$|EWx!`S}>xc(ARLs;%Os+)}dF^!-&-)rr2*ky_P>+SQR>a?iDq z+NdKnt)t!784aB|{eeGbH0wl<>0~h!j^Y)}TCZ%?nfiBq@XbRF?o6HXhvGx=tr5h# z;$N|(GtyBprxD~$kXJ#xDrRMqVl{}}AV(D6K~Bj2AP<7vP(JCKuu(*`9l}a=Ks7=` z_^Cz&bx8F?xT;R6rf3Lj)f&|?{jREa)wihho^}K^P<2l=Qhri>RNa){^gSw!8h5|X z)&XH6->cqg9K0A~^;SFvzLR}{y|Q0JdC6~IHn$Zc6Pi3|&-1Ns?@YQoDX)de&X?ou zEW35TQ9Thq?fv-PqC1Pz=lUHbxSOaS>pmiTb^l0h6OR=5Q5 zCEv=A@~eCp__0mTyQzh3>GN%`%@?}01(!B<9O~Exyzb&|;Ez6I@1o|lwxhg7YrFU3 zeeGpcwz8ws@z&fA+d3;(w8kH;xC^(mzrpL7-^K0NJ*cIfTi0(v-nXF6-1!ef>+_vE z$$Ni%^L=ybQ*(P>GQ?lJuO4ni479ZLI$xL9+zy==U(stVy1a!QF)?=#Yt5){K0iJ8 z-*$&Rr#*gev7>SKX2fb!Vznu8+mu-K--9)+>sfoj78HVOKXyKN(bt#1kNC--PammC;#(jjdw;U=|NnB0o^Kxg z+fv>Of5MmFqmqypTDmKX{X#qQZ&pIXYF3B%4KnoUBM*iybo z8j_xxCr&>Nv?3wzMeI7KA_bZ?MV;W1(P3e5*GWdE}{`{oei)T58)$y zrH|$x>mgo&uG)iFlIW#d^?W zUh=7r(F%*Xbk%vu)GT`Gr?c{j;`|4!f8f2_+OHTBAJtu*kv!=s>@0M_eNBfeYSv>JA-^1grZ23cUfjnWM zyQ$U{&%#1<(p|IGWRoz~5Z?fM{BQm7W4!SjeP@(CD|^L9`irm5i)jOH_Uf&;RMn z|F$l_2WR~JQ2ahfuHN$xdI%5sAs9L*{jb0Kbyq$NdWGUwF&~_fJk;0uwS$G z1N!UFsyE4&xA+HrBIpr*eE5At{G^l4Ngv^$IMrF{5yY2d1@Ry`dJc3~tYEK=7ug)d zoan`0x+p%yuY=^ z!_NcD$DS2m>FVcB8G1%ge{Q$6M|`xWcdg4_)sDbc*(*BXr6C)In`{o~l=E5-Yy)SD zne+(o6Mw~qbont30gcvlMmGL{fweEtJ&iRzM?X*O{l)TCoss>D4V{-A;A;D?c%IwD 0 + assert np.all( + points[indices] == expected_coors[i][:num_points_current_voxel]) + assert num_points_current_voxel == 
expected_num_points_per_voxel[i] From 4c80e8aa6024aa1e824879c57d010911e57ffc1f Mon Sep 17 00:00:00 2001 From: "hudingchang.vendor" Date: Sun, 3 Oct 2021 20:12:51 +0800 Subject: [PATCH 06/12] update test --- tests/test_ops/test_voxelization.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_ops/test_voxelization.py b/tests/test_ops/test_voxelization.py index 4e896dcf33..1309ce9bea 100644 --- a/tests/test_ops/test_voxelization.py +++ b/tests/test_ops/test_voxelization.py @@ -13,7 +13,7 @@ def _get_voxel_points_indices(points, coors, voxel): @pytest.mark.skipif( not torch.cuda.is_available(), reason='requires CUDA support') -def test_voxelization_cpu(): +def test_voxelization(): voxel_size = [0.5, 0.5, 0.5] point_cloud_range = [0, -40, -3, 70.4, 40, 1] @@ -52,3 +52,6 @@ def test_voxelization_cpu(): assert np.all( points[indices] == expected_coors[i][:num_points_current_voxel]) assert num_points_current_voxel == expected_num_points_per_voxel[i] + + +test_voxelization() From c8baff0056d636cb6b6a69d2821a8da4e2de2c67 Mon Sep 17 00:00:00 2001 From: "hudingchang.vendor" Date: Sun, 3 Oct 2021 20:13:35 +0800 Subject: [PATCH 07/12] update test --- tests/test_ops/test_voxelization.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_ops/test_voxelization.py b/tests/test_ops/test_voxelization.py index 1309ce9bea..e5b7eb51c5 100644 --- a/tests/test_ops/test_voxelization.py +++ b/tests/test_ops/test_voxelization.py @@ -52,6 +52,3 @@ def test_voxelization(): assert np.all( points[indices] == expected_coors[i][:num_points_current_voxel]) assert num_points_current_voxel == expected_num_points_per_voxel[i] - - -test_voxelization() From 14d9a812ddf8a6c0069cf3f723904d46a8cb92c7 Mon Sep 17 00:00:00 2001 From: "hudingchang.vendor" Date: Sun, 10 Oct 2021 04:05:11 +0800 Subject: [PATCH 08/12] refactor code --- .../common/cuda/voxelization_cuda_kernel.cuh | 4 +- .../csrc/pytorch/cuda/scatter_points_cuda.cu | 14 +++-- .../csrc/pytorch/cuda/voxelization_cuda.cu | 5 -- mmcv/ops/scatter_points.py | 45 +++++++------- mmcv/ops/voxelize.py | 58 ++++++++++--------- 5 files changed, 65 insertions(+), 61 deletions(-) diff --git a/mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh index 0669a8f436..3407d3aed7 100644 --- a/mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh +++ b/mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh @@ -26,14 +26,14 @@ __global__ void dynamic_voxelize_kernel( int c_x = floor((points_offset[0] - coors_x_min) / voxel_x); if (c_x < 0 || c_x >= grid_x) { coors_offset[0] = -1; - return; + continue; } int c_y = floor((points_offset[1] - coors_y_min) / voxel_y); if (c_y < 0 || c_y >= grid_y) { coors_offset[0] = -1; coors_offset[1] = -1; - return; + continue; } int c_z = floor((points_offset[2] - coors_z_min) / voxel_z); diff --git a/mmcv/ops/csrc/pytorch/cuda/scatter_points_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/scatter_points_cuda.cu index 3692f8785f..884d38db67 100644 --- a/mmcv/ops/csrc/pytorch/cuda/scatter_points_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/scatter_points_cuda.cu @@ -33,6 +33,9 @@ std::vector DynamicPointToVoxelForwardCUDAKernelLauncher( auto reduced_feats = at::empty({out_coors.size(0), num_feats}, feats.options()); + at::cuda::CUDAGuard device_guard(feats.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES( feats.scalar_type(), "feats_reduce_kernel", ([&] { if (reduce_type == reduce_t::MAX) @@ -43,7 +46,7 @@ std::vector 
         dim3 blocks(std::min(
             at::cuda::ATenCeilDiv(num_input, THREADS_PER_BLOCK), maxGridDim));
         dim3 threads(THREADS_PER_BLOCK);
-        feats_reduce_kernel<<<blocks, threads>>>(
+        feats_reduce_kernel<<<blocks, threads, 0, stream>>>(
             feats.data_ptr<scalar_t>(), coors_map.data_ptr<int32_t>(),
             reduced_feats.data_ptr<scalar_t>(), num_input, num_feats,
@@ -69,6 +72,8 @@ void DynamicPointToVoxelBackwardCUDAKernelLauncher(
   // copy voxel grad to points

   if (num_input == 0 || num_reduced == 0) return;
+  at::cuda::CUDAGuard device_guard(feats.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

   if (reduce_type == reduce_t::MEAN || reduce_type == reduce_t::SUM) {
     AT_DISPATCH_FLOATING_TYPES(
@@ -77,7 +82,7 @@ void DynamicPointToVoxelBackwardCUDAKernelLauncher(
           dim3 blocks(std::min(
               at::cuda::ATenCeilDiv(num_input, THREADS_PER_BLOCK), maxGridDim));
           dim3 threads(THREADS_PER_BLOCK);
-          add_reduce_traceback_grad_kernel<<<blocks, threads>>>(
+          add_reduce_traceback_grad_kernel<<<blocks, threads, 0, stream>>>(
              grad_feats.data_ptr<scalar_t>(),
              grad_reduced_feats.data_ptr<scalar_t>(),
              coors_map.data_ptr<int32_t>(), reduce_count.data_ptr<int32_t>(),
@@ -94,7 +99,8 @@ void DynamicPointToVoxelBackwardCUDAKernelLauncher(
           dim3 blocks(std::min(
               at::cuda::ATenCeilDiv(num_input, THREADS_PER_BLOCK), maxGridDim));
           dim3 threads(THREADS_PER_BLOCK);
-          max_reduce_traceback_scatter_idx_kernel<<<blocks, threads>>>(
+          max_reduce_traceback_scatter_idx_kernel<<<blocks, threads, 0,
+                                                    stream>>>(
              feats.data_ptr<scalar_t>(), reduced_feats.data_ptr<scalar_t>(),
              reduce_from.data_ptr<int32_t>(), coors_map.data_ptr<int32_t>(),
              num_input, num_feats);
@@ -109,7 +115,7 @@ void DynamicPointToVoxelBackwardCUDAKernelLauncher(
           dim3 blocks(
               std::min(at::cuda::ATenCeilDiv(num_reduced, THREADS_PER_BLOCK),
                        maxGridDim));
           dim3 threads(THREADS_PER_BLOCK);
-          max_reduce_scatter_grad_kernel<<<blocks, threads>>>(
+          max_reduce_scatter_grad_kernel<<<blocks, threads, 0, stream>>>(
              grad_feats.data_ptr<scalar_t>(),
              grad_reduced_feats.data_ptr<scalar_t>(),
              reduce_from.data_ptr<int32_t>(), num_reduced, num_feats);
diff --git a/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu
index 67852594aa..7275ca5f98 100644
--- a/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu
+++ b/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu
@@ -50,7 +50,6 @@ int HardVoxelizeForwardCUDAKernelLauncher(
             NDim);
       }));

-  cudaDeviceSynchronize();
   AT_CUDA_CHECK(cudaGetLastError());

   // 2. map point to the idx of the corresponding voxel, find duplicate coor
@@ -78,7 +77,6 @@ int HardVoxelizeForwardCUDAKernelLauncher(
             max_voxels, num_points, NDim);
       }));

-  cudaDeviceSynchronize();
   AT_CUDA_CHECK(cudaGetLastError());

   // 3. determin voxel num and voxel's coor index
@@ -104,7 +102,6 @@ int HardVoxelizeForwardCUDAKernelLauncher(
             max_points, max_voxels, num_points);
       }));

-  cudaDeviceSynchronize();
   AT_CUDA_CHECK(cudaGetLastError());

   // 4. copy point features to voxels
@@ -139,7 +136,6 @@ int HardVoxelizeForwardCUDAKernelLauncher(
             coors.contiguous().data_ptr<int>(), num_points, NDim);
       }));

-  cudaDeviceSynchronize();
   AT_CUDA_CHECK(cudaGetLastError());

   auto voxel_num_cpu = voxel_num.to(at::kCPU);
@@ -187,6 +183,5 @@ void DynamicVoxelizeForwardCUDAKernelLauncher(
             coors_z_max, grid_x, grid_y, grid_z, num_points, num_features,
             NDim);
       });

-  cudaDeviceSynchronize();
   AT_CUDA_CHECK(cudaGetLastError());
 }
diff --git a/mmcv/ops/scatter_points.py b/mmcv/ops/scatter_points.py
index a37079e0c5..c1ca89f222 100644
--- a/mmcv/ops/scatter_points.py
+++ b/mmcv/ops/scatter_points.py
@@ -16,16 +16,18 @@ def forward(ctx, feats, coors, reduce_type='max'):
         """convert kitti points(N, >=3) to voxels.

         Args:
-            feats: [N, C] float tensor. points features to be reduced
-                into voxels.
-            coors: [N, ndim] int tensor. corresponding voxel coordinates
+            feats (torch.Tensor): [N, C]. Points features to be reduced
+                into voxels.
+            coors (torch.Tensor): [N, ndim]. Corresponding voxel coordinates
                 (specifically multi-dim voxel index) of each points.
-            reduce_type: str. reduce op. support 'max', 'sum' and 'mean'
+            reduce_type (str, optional): Reduce op. support 'max', 'sum' and
+                'mean'. Default: 'max'.
+
         Returns:
-            tuple
-            voxel_feats: [M, C] float tensor. reduced features. input features
-            that shares the same voxel coordinates are reduced to one row
-            coordinates: [M, ndim] int tensor, voxel coordinates.
+            voxel_feats (torch.Tensor): [M, C]. Reduced features, input
+                features that shares the same voxel coordinates are reduced to
+                one row.
+            voxel_coors (torch.Tensor): [M, ndim]. Voxel coordinates.
         """
         results = ext_module.dynamic_point_to_voxel_forward(
             feats, coors, reduce_type)
@@ -54,22 +56,23 @@ def backward(ctx, grad_voxel_feats, grad_voxel_coors=None):


 class DynamicScatter(nn.Module):
+    """Scatters points into voxels, used in the voxel encoder with dynamic
+    voxelization.
+
+    **Note**: The CPU and GPU implementation get the same output, but have
+    numerical difference after summation and division (e.g., 5e-7).
+
+    Args:
+        voxel_size (list): list [x, y, z] size of three dimension.
+        point_cloud_range (list): The coordinate range of points, [x_min,
+            y_min, z_min, x_max, y_max, z_max].
+        average_points (bool): whether to use avg pooling to scatter points
+            into voxel.
+    """

     def __init__(self, voxel_size, point_cloud_range, average_points: bool):
         super(DynamicScatter, self).__init__()
-        """Scatters points into voxels, used in the voxel encoder with
-        dynamic voxelization
-        **Note**: The CPU and GPU implementation get the same output, but
-        have numerical difference after summation and division (e.g., 5e-7).
-
-        Args:
-            average_points (bool): whether to use avg pooling to scatter
-                points into voxel voxel_size (list): list [x, y, z] size
-                of three dimension
-            point_cloud_range (list):
-                [x_min, y_min, z_min, x_max, y_max, z_max]
-        """
         self.voxel_size = voxel_size
         self.point_cloud_range = point_cloud_range
         self.average_points = average_points
@@ -79,10 +82,6 @@ def forward_single(self, points, coors):
         return dynamic_scatter(points.contiguous(), coors.contiguous(),
                                reduce)

     def forward(self, points, coors):
-        """
-        Args:
-            input: NC points
-        """
         if coors.size(-1) == 3:
             return self.forward_single(points, coors)
         else:
diff --git a/mmcv/ops/voxelize.py b/mmcv/ops/voxelize.py
index 500c861fc0..c020247ecc 100644
--- a/mmcv/ops/voxelize.py
+++ b/mmcv/ops/voxelize.py
@@ -22,24 +22,27 @@ def forward(ctx,
         """convert kitti points(N, >=3) to voxels.

         Args:
-            points: [N, ndim] float tensor. points[:, :3] contain xyz points
-                and points[:, 3:] contain other information like reflectivity
-            voxel_size: [3] list/tuple or array, float. xyz, indicate voxel
-                size
-            coors_range: [6] list/tuple or array, float. indicate voxel
-                range. format: xyzxyz, minmax
-            max_points: int. indicate maximum points contained in a voxel. if
-                max_points=-1, it means using dynamic_voxelize
-            max_voxels: int. indicate maximum voxels this function create.
+            points (torch.Tensor): [N, ndim]. Points[:, :3] contain xyz points
+                and points[:, 3:] contain other information like reflectivity.
+            voxel_size (tuple or float): The size of voxel with the shape of
+                [3].
+            coors_range (tuple or float): The coordinate range of voxel with
+                the shape of [6].
+            max_points (int, optional): maximum points contained in a voxel. if
+                max_points=-1, it means using dynamic_voxelize. Default: 35.
+            max_voxels (int, optional): maximum voxels this function create.
                 for second, 20000 is a good choice. Users should shuffle points
                 before call this function because max_voxels may drop points.
+                Default: 20000.

         Returns:
-            voxels_out: [M, max_points, ndim] float tensor. only contain points
-                and returned when max_points != -1.
-            coors_out: [M, 3] int32 tensor, always returned.
-            num_points_per_voxel_out: [M] int32 tensor. Only returned when
+            voxels_out (torch.Tensor): Output voxels with the shape of [M,
+                max_points, ndim]. Only contain points and returned when
                 max_points != -1.
+            coors_out (torch.Tensor): Output coordinates with the shape of
+                [M, 3].
+            num_points_per_voxel_out (torch.Tensor): Num points per voxel with
+                the shape of [M]. Only returned when max_points != -1.
         """
         if max_points == -1 or max_voxels == -1:
             coors = points.new_zeros(size=(points.size(0), 3), dtype=torch.int)
@@ -66,22 +69,27 @@ def forward(ctx,


 class Voxelization(nn.Module):
+    """Paper reference: https://arxiv.org/abs/1907.03739.
+
+    Args:
+        voxel_size (tuple or float): The size of voxel with the shape of [3].
+        point_cloud_range (tuple or float): The coordinate range of voxel with
+            the shape of [6].
+        max_num_points (int): maximum points contained in a voxel. if
+            max_points=-1, it means using dynamic_voxelize.
+        max_voxels (int, optional): maximum voxels this function create.
+            for second, 20000 is a good choice. Users should shuffle points
+            before call this function because max_voxels may drop points.
+            Default: 20000.
+    """

     def __init__(self,
                  voxel_size,
                  point_cloud_range,
                  max_num_points,
                  max_voxels=20000):
-        super(Voxelization, self).__init__()
-        """
-        Args:
-            voxel_size (list): list [x, y, z] size of three dimension
-            point_cloud_range (list):
-                [x_min, y_min, z_min, x_max, y_max, z_max]
-            max_num_points (int): max number of points per voxel
-            max_voxels (tuple or int): max number of voxels in
-                (training, testing) time
-        """
+        super().__init__()
+
         self.voxel_size = voxel_size
         self.point_cloud_range = point_cloud_range
         self.max_num_points = max_num_points
@@ -104,10 +112,6 @@ def __init__(self,
         self.pcd_shape = [*input_feat_shape, 1][::-1]

     def forward(self, input):
-        """
-        Args:
-            input: NC points
-        """
         if self.training:
             max_voxels = self.max_voxels[0]
         else:
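Between these patches, a sketch of how the DynamicScatter module documented above combines with dynamic voxelization in a voxel encoder. Every concrete value here is an illustrative assumption rather than code from the series, and a compiled mmcv build is presumed:

    import torch
    from mmcv.ops import DynamicScatter, Voxelization

    voxel_size = [0.5, 0.5, 0.5]
    point_cloud_range = [0, -40, -3, 70.4, 40, 1]

    points = torch.rand(1000, 4, device='cuda')
    # dynamic voxelization assigns one voxel coordinate to every point
    dynamic = Voxelization(voxel_size, point_cloud_range, max_num_points=-1)
    coors = dynamic(points)

    # average_points=True reduces with 'mean', False with 'max'
    scatter = DynamicScatter(voxel_size, point_cloud_range, average_points=False)
    voxel_feats, voxel_coors = scatter(points, coors)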
From cae9a0e805d077d7f430fda526920bc7e012f57d Mon Sep 17 00:00:00 2001
From: "hudingchang.vendor"
Date: Sun, 10 Oct 2021 04:05:33 +0800
Subject: [PATCH 09/12] refactor code

---
 mmcv/ops/voxelize.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mmcv/ops/voxelize.py b/mmcv/ops/voxelize.py
index c020247ecc..066517a493 100644
--- a/mmcv/ops/voxelize.py
+++ b/mmcv/ops/voxelize.py
@@ -1,4 +1,3 @@
-# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 import torch
 from torch import nn
 from torch.autograd import Function

From 29cb57114865f98040eb4fa483b6433f134dd54e Mon Sep 17 00:00:00 2001
From: hdc
Date: Mon, 11 Oct 2021 03:59:40 +0800
Subject: [PATCH 10/12] refactor code

---
 .../csrc/pytorch/cuda/voxelization_cuda.cu    |  4 +-
 mmcv/ops/csrc/pytorch/voxelization_cpu.cpp    |  4 +-
 mmcv/ops/scatter_points.py                    | 47 ++++++++++++++-----
 mmcv/ops/voxelize.py                          | 18 +++----
 tests/test_ops/test_voxelization.py           | 12 ++---
 5 files changed, 56 insertions(+), 29 deletions(-)

diff --git a/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu
index 7275ca5f98..4303063639 100644
--- a/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu
+++ b/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu
@@ -79,7 +79,7 @@ int HardVoxelizeForwardCUDAKernelLauncher(

   AT_CUDA_CHECK(cudaGetLastError());

-  // 3. determin voxel num and voxel's coor index
+  // 3. determine voxel num and voxel's coor index
   // make the logic in the CUDA device could accelerate about 10 times
   auto coor_to_voxelidx = -at::ones(
       {
@@ -90,7 +90,7 @@ int HardVoxelizeForwardCUDAKernelLauncher(
       {
           1,
       },
-      points.options().dtype(at::kInt));  // must be zero from the begining
+      points.options().dtype(at::kInt));  // must be zero from the beginning

   AT_DISPATCH_ALL_TYPES(temp_coors.scalar_type(), "determin_duplicate", ([&] {
                           determin_voxel_num<<<1, 1, 0, stream>>>(
diff --git a/mmcv/ops/csrc/pytorch/voxelization_cpu.cpp b/mmcv/ops/csrc/pytorch/voxelization_cpu.cpp
index 63d83318c9..15dd625985 100644
--- a/mmcv/ops/csrc/pytorch/voxelization_cpu.cpp
+++ b/mmcv/ops/csrc/pytorch/voxelization_cpu.cpp
@@ -109,7 +109,7 @@ void dynamic_voxelize_forward_cpu(const at::Tensor& points, at::Tensor& coors,
   // coors, num_points_per_voxel, coor_to_voxelidx are int Tensor
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      points.scalar_type(), "dynamic_voxelize_forward_cuda_kernel", [&] {
+      points.scalar_type(), "dynamic_voxelize_forward_cpu_kernel", [&] {
         dynamic_voxelize_forward_cpu_kernel<scalar_t, int>(
             points.accessor<scalar_t, 2>(), coors.accessor<int, 2>(),
             voxel_size, coors_range, grid_size, num_points, num_features,
             NDim);
@@ -144,7 +144,7 @@ int hard_voxelize_forward_cpu(const at::Tensor& points, at::Tensor& voxels,
   int voxel_num = 0;
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      points.scalar_type(), "hard_voxelize_forward_cuda_kernel", [&] {
+      points.scalar_type(), "hard_voxelize_forward_cpu_kernel", [&] {
         hard_voxelize_forward_cpu_kernel<scalar_t, int>(
             points.accessor<scalar_t, 2>(), voxels.accessor<scalar_t, 3>(),
             coors.accessor<int, 2>(), num_points_per_voxel.accessor<int, 1>(),
diff --git a/mmcv/ops/scatter_points.py b/mmcv/ops/scatter_points.py
index c1ca89f222..70449d9ab5 100644
--- a/mmcv/ops/scatter_points.py
+++ b/mmcv/ops/scatter_points.py
@@ -9,7 +9,7 @@
     ['dynamic_point_to_voxel_forward', 'dynamic_point_to_voxel_backward'])


-class _dynamic_scatter(Function):
+class _DynamicScatter(Function):

     @staticmethod
     def forward(ctx, feats, coors, reduce_type='max'):
@@ -52,15 +52,16 @@ def backward(ctx, grad_voxel_feats, grad_voxel_coors=None):
         return grad_feats, None, None


-dynamic_scatter = _dynamic_scatter.apply
+dynamic_scatter = _DynamicScatter.apply


 class DynamicScatter(nn.Module):
     """Scatters points into voxels, used in the voxel encoder with dynamic
     voxelization.

-    **Note**: The CPU and GPU implementation get the same output, but have
-    numerical difference after summation and division (e.g., 5e-7).
+    Note:
+        The CPU and GPU implementation get the same output, but have numerical
+        difference after summation and division (e.g., 5e-7).

     Args:
         voxel_size (list): list [x, y, z] size of three dimension.
@@ -71,17 +72,41 @@ class DynamicScatter(nn.Module):
     """

     def __init__(self, voxel_size, point_cloud_range, average_points: bool):
-        super(DynamicScatter, self).__init__()
+        super().__init__()

         self.voxel_size = voxel_size
         self.point_cloud_range = point_cloud_range
         self.average_points = average_points

     def forward_single(self, points, coors):
+        """Scatters points into voxels.
+
+        Args:
+            points (torch.Tensor): Points to be reduced into voxels.
+            coors (torch.Tensor): Corresponding voxel coordinates (specifically
+                multi-dim voxel index) of each points.
+
+        Returns:
+            voxel_feats (torch.Tensor): Reduced features, input features that
+                shares the same voxel coordinates are reduced to one row.
+            voxel_coors (torch.Tensor): Voxel coordinates.
+        """
         reduce = 'mean' if self.average_points else 'max'
         return dynamic_scatter(points.contiguous(), coors.contiguous(),
                                reduce)

     def forward(self, points, coors):
+        """Scatters points/features into voxels.
+
+        Args:
+            points (torch.Tensor): Points to be reduced into voxels.
+            coors (torch.Tensor): Corresponding voxel coordinates (specifically
+                multi-dim voxel index) of each points.
+
+        Returns:
+            voxel_feats (torch.Tensor): Reduced features, input features that
+                shares the same voxel coordinates are reduced to one row.
+            voxel_coors (torch.Tensor): Voxel coordinates.
+        """
         if coors.size(-1) == 3:
             return self.forward_single(points, coors)
         else:
@@ -101,9 +126,9 @@ def forward(self, points, coors):
             return features, feature_coors

     def __repr__(self):
-        tmpstr = self.__class__.__name__ + '('
-        tmpstr += 'voxel_size=' + str(self.voxel_size)
-        tmpstr += ', point_cloud_range=' + str(self.point_cloud_range)
-        tmpstr += ', average_points=' + str(self.average_points)
-        tmpstr += ')'
-        return tmpstr
+        s = self.__class__.__name__ + '('
+        s += 'voxel_size=' + str(self.voxel_size)
+        s += ', point_cloud_range=' + str(self.point_cloud_range)
+        s += ', average_points=' + str(self.average_points)
+        s += ')'
+        return s
diff --git a/mmcv/ops/voxelize.py b/mmcv/ops/voxelize.py
index 066517a493..2b7acfc82b 100644
--- a/mmcv/ops/voxelize.py
+++ b/mmcv/ops/voxelize.py
@@ -68,7 +68,9 @@ def forward(ctx,


 class Voxelization(nn.Module):
-    """Paper reference: https://arxiv.org/abs/1907.03739.
+    """convert kitti points(N, >=3) to voxels.
+
+    Paper reference: https://arxiv.org/abs/1907.03739.

     Args:
         voxel_size (tuple or float): The size of voxel with the shape of [3].
@@ -120,10 +122,10 @@ def forward(self, input):
             self.max_num_points, max_voxels)

     def __repr__(self):
-        tmpstr = self.__class__.__name__ + '('
-        tmpstr += 'voxel_size=' + str(self.voxel_size)
-        tmpstr += ', point_cloud_range=' + str(self.point_cloud_range)
-        tmpstr += ', max_num_points=' + str(self.max_num_points)
-        tmpstr += ', max_voxels=' + str(self.max_voxels)
-        tmpstr += ')'
-        return tmpstr
+        s = self.__class__.__name__ + '('
+        s += 'voxel_size=' + str(self.voxel_size)
+        s += ', point_cloud_range=' + str(self.point_cloud_range)
+        s += ', max_num_points=' + str(self.max_num_points)
+        s += ', max_voxels=' + str(self.max_voxels)
+        s += ')'
+        return s
diff --git a/tests/test_ops/test_voxelization.py b/tests/test_ops/test_voxelization.py
index e5b7eb51c5..29e871946a 100644
--- a/tests/test_ops/test_voxelization.py
+++ b/tests/test_ops/test_voxelization.py
@@ -1,6 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import numpy as np
-import pytest
 import torch

 from mmcv.ops import Voxelization
@@ -11,8 +10,6 @@ def _get_voxel_points_indices(points, coors, voxel):
     return result_form[:, 0] & result_form[:, 1] & result_form[:, 2]


-@pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
 def test_voxelization():
     voxel_size = [0.5, 0.5, 0.5]
     point_cloud_range = [0, -40, -3, 70.4, 40, 1]
@@ -31,8 +28,11 @@ def test_voxelization():
     hard_voxelization = Voxelization(voxel_size, point_cloud_range,
                                      max_num_points)

-    # test hard_voxelization on gpu
-    points = torch.tensor(points).contiguous().to(device='cuda:0')
+    device = torch.device(
+        'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+
+    # test hard_voxelization on cpu/gpu
+    points = torch.tensor(points).contiguous().to(device)
     coors, voxels, num_points_per_voxel = hard_voxelization.forward(points)
     coors = coors.cpu().detach().numpy()
     voxels = voxels.cpu().detach().numpy()
@@ -41,7 +41,7 @@ def test_voxelization():
     assert np.all(voxels == expected_voxels)
     assert np.all(num_points_per_voxel == expected_num_points_per_voxel)

-    # test dynamic_voxelization on gpu
+    # test dynamic_voxelization on cpu/gpu
     coors = dynamic_voxelization.forward(points)
     coors = coors.cpu().detach().numpy()
     points = points.cpu().detach().numpy()

From 524721b91fa8adc84993d5094e30a13db8bfd066 Mon Sep 17 00:00:00 2001
From: hdc
Date: Mon, 11 Oct 2021 05:03:00 +0800
Subject: [PATCH 11/12] fix typo

---
 tests/test_ops/test_voxelization.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/test_ops/test_voxelization.py b/tests/test_ops/test_voxelization.py
index 29e871946a..8abdb256ad 100644
--- a/tests/test_ops/test_voxelization.py
+++ b/tests/test_ops/test_voxelization.py
@@ -14,7 +14,8 @@ def test_voxelization():
     voxel_size = [0.5, 0.5, 0.5]
     point_cloud_range = [0, -40, -3, 70.4, 40, 1]

-    voxel_dict = np.load('tests/data/for_3d_ops/test_voxel.npy').item()
+    voxel_dict = np.load(
+        'tests/data/for_3d_ops/test_voxel.npy', allow_pickle=True).item()
     expected_coors = voxel_dict['coors']
     expected_voxels = voxel_dict['voxels']
     expected_num_points_per_voxel = voxel_dict['num_points_per_voxel']

From 17d89f305d162b921b5d39dd0b59d00c05cdb905 Mon Sep 17 00:00:00 2001
From: hdc
Date: Wed, 20 Oct 2021 22:55:19 +0800
Subject: [PATCH 12/12] fix typo

---
 .../common/cuda/voxelization_cuda_kernel.cuh  |  4 +---
 .../csrc/pytorch/cuda/scatter_points_cuda.cu  |  1 +
 .../csrc/pytorch/cuda/voxelization_cuda.cu    |  1 +
 mmcv/ops/csrc/pytorch/scatter_points.cpp      |  1 +
 mmcv/ops/csrc/pytorch/voxelization.cpp        |  1 +
 mmcv/ops/csrc/pytorch/voxelization_cpu.cpp    | 21 +++++++------------
 mmcv/ops/scatter_points.py                    |  1 +
 mmcv/ops/voxelize.py                          | 10 ++++-----
 tests/test_ops/test_scatter_points.py         |  1 -
 tests/test_ops/test_voxelization.py           | 14 +++++++++----
 10 files changed, 29 insertions(+), 26 deletions(-)

diff --git a/mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh
index 3407d3aed7..62e118b352 100644
--- a/mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh
+++ b/mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh
@@ -1,4 +1,4 @@
-// Copyright (c) OpenMMLab. All rights reserved
+// Copyright (c) OpenMMLab. All rights reserved.
 #ifndef VOXELIZATION_CUDA_KERNEL_CUH
 #define VOXELIZATION_CUDA_KERNEL_CUH
@@ -143,8 +143,6 @@ __global__ void determin_voxel_num(
     const int max_points, const int max_voxels, const int num_points) {
   // only calculate the coors before this coor[index]
   for (int i = 0; i < num_points; ++i) {
-    // if (coor[i][0] == -1)
-    //    continue;
     int point_pos_in_voxel = point_to_voxelidx[i];
     // record voxel
     if (point_pos_in_voxel == -1) {
diff --git a/mmcv/ops/csrc/pytorch/cuda/scatter_points_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/scatter_points_cuda.cu
index 884d38db67..4939fe40a0 100644
--- a/mmcv/ops/csrc/pytorch/cuda/scatter_points_cuda.cu
+++ b/mmcv/ops/csrc/pytorch/cuda/scatter_points_cuda.cu
@@ -1,3 +1,4 @@
+// Copyright (c) OpenMMLab. All rights reserved.
 #include <stdio.h>
 #include <stdlib.h>
 #include <torch/types.h>
diff --git a/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu
index 4303063639..a8946d2b93 100644
--- a/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu
+++ b/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu
@@ -1,3 +1,4 @@
+// Copyright (c) OpenMMLab. All rights reserved.
 #include <ATen/ATen.h>
 #include <ATen/TensorUtils.h>
diff --git a/mmcv/ops/csrc/pytorch/scatter_points.cpp b/mmcv/ops/csrc/pytorch/scatter_points.cpp
index c825d708e4..468503fe3f 100644
--- a/mmcv/ops/csrc/pytorch/scatter_points.cpp
+++ b/mmcv/ops/csrc/pytorch/scatter_points.cpp
@@ -1,3 +1,4 @@
+// Copyright (c) OpenMMLab. All rights reserved.
 #include "pytorch_cpp_helper.hpp"

 typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t;
diff --git a/mmcv/ops/csrc/pytorch/voxelization.cpp b/mmcv/ops/csrc/pytorch/voxelization.cpp
index f6522000cd..8f389e1b6b 100644
--- a/mmcv/ops/csrc/pytorch/voxelization.cpp
+++ b/mmcv/ops/csrc/pytorch/voxelization.cpp
@@ -1,3 +1,4 @@
+// Copyright (c) OpenMMLab. All rights reserved.
 #include "pytorch_cpp_helper.hpp"

 #ifdef MMCV_WITH_CUDA
diff --git a/mmcv/ops/csrc/pytorch/voxelization_cpu.cpp b/mmcv/ops/csrc/pytorch/voxelization_cpu.cpp
index 15dd625985..59eb86f543 100644
--- a/mmcv/ops/csrc/pytorch/voxelization_cpu.cpp
+++ b/mmcv/ops/csrc/pytorch/voxelization_cpu.cpp
@@ -1,3 +1,4 @@
+// Copyright (c) OpenMMLab. All rights reserved.
 #include "pytorch_cpp_helper.hpp"

 template <typename T, typename T_int>
@@ -24,12 +25,10 @@ void dynamic_voxelize_forward_cpu_kernel(
       coor[ndim_minus_1 - j] = c;
     }

-    for (int k = 0; k < NDim; ++k) {
-      if (failed)
-        coors[i][k] = -1;
-      else
-        coors[i][k] = coor[k];
-    }
+    if (failed)
+      memset(&coors[i][0], -1, NDim * sizeof(T_int));
+    else
+      memcpy(&coors[i][0], &coor[0], NDim * sizeof(T_int));
   }

   delete[] coor;
@@ -72,18 +71,14 @@ void hard_voxelize_forward_cpu_kernel(
       voxel_num += 1;

       coor_to_voxelidx[coor[i][0]][coor[i][1]][coor[i][2]] = voxelidx;
-
-      for (int k = 0; k < NDim; ++k) {
-        coors[voxelidx][k] = coor[i][k];
-      }
+      memcpy(&coors[voxelidx][0], &coor[i][0], NDim * sizeof(T_int));
     }

     // put points into voxel
     num = num_points_per_voxel[voxelidx];
     if (max_points == -1 || num < max_points) {
-      for (int k = 0; k < num_features; ++k) {
-        voxels[voxelidx][num][k] = points[i][k];
-      }
+      memcpy(&voxels[voxelidx][num][0], &points[i][0],
+             num_features * sizeof(T));
       num_points_per_voxel[voxelidx] += 1;
     }
   }
diff --git a/mmcv/ops/scatter_points.py b/mmcv/ops/scatter_points.py
index 70449d9ab5..2b8aa4169e 100644
--- a/mmcv/ops/scatter_points.py
+++ b/mmcv/ops/scatter_points.py
@@ -1,3 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
 import torch
 from torch import nn
 from torch.autograd import Function
diff --git a/mmcv/ops/voxelize.py b/mmcv/ops/voxelize.py
index 1e7d4ea0e8..ca3226a4fb 100644
--- a/mmcv/ops/voxelize.py
+++ b/mmcv/ops/voxelize.py
@@ -1,3 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
 import torch
 from torch import nn
 from torch.autograd import Function
@@ -18,7 +19,7 @@ def forward(ctx,
                 coors_range,
                 max_points=35,
                 max_voxels=20000):
-        """convert kitti points(N, >=3) to voxels.
+        """Convert kitti points(N, >=3) to voxels.

         Args:
             points (torch.Tensor): [N, ndim]. Points[:, :3] contain xyz points
             and points[:, 3:] contain other information like reflectivity.
@@ -68,10 +69,10 @@ def forward(ctx,


 class Voxelization(nn.Module):
-    """convert kitti points(N, >=3) to voxels.
+    """Convert kitti points(N, >=3) to voxels.

-    Please refer to `Paper of PVCNN <https://arxiv.org/abs/1907.03739>`_
-    for more details.
+    Please refer to `PVCNN <https://arxiv.org/abs/1907.03739>`_ for more
+    details.

     Args:
         voxel_size (tuple or float): The size of voxel with the shape of [3].
@@ -102,7 +103,6 @@ def __init__(self,

         point_cloud_range = torch.tensor(
             point_cloud_range, dtype=torch.float32)
-        # [0, -40, -3, 70.4, 40, 1]
         voxel_size = torch.tensor(voxel_size, dtype=torch.float32)
         grid_size = (point_cloud_range[3:] -
                      point_cloud_range[:3]) / voxel_size
diff --git a/tests/test_ops/test_scatter_points.py b/tests/test_ops/test_scatter_points.py
index 8610124e30..8fe1fe8cd0 100644
--- a/tests/test_ops/test_scatter_points.py
+++ b/tests/test_ops/test_scatter_points.py
@@ -1,4 +1,3 @@
-# Copyright (c) OpenMMLab. All rights reserved.
 import pytest
 import torch
 from torch.autograd import gradcheck
diff --git a/tests/test_ops/test_voxelization.py b/tests/test_ops/test_voxelization.py
index 8abdb256ad..ad3253f952 100644
--- a/tests/test_ops/test_voxelization.py
+++ b/tests/test_ops/test_voxelization.py
@@ -1,5 +1,5 @@
-# Copyright (c) OpenMMLab. All rights reserved.
 import numpy as np
+import pytest
 import torch

 from mmcv.ops import Voxelization
@@ -10,7 +10,14 @@ def _get_voxel_points_indices(points, coors, voxel):
     return result_form[:, 0] & result_form[:, 1] & result_form[:, 2]


-def test_voxelization():
+@pytest.mark.parametrize('device_type', [
+    'cpu',
+    pytest.param(
+        'cuda:0',
+        marks=pytest.mark.skipif(
+            not torch.cuda.is_available(), reason='requires CUDA support'))
+])
+def test_voxelization(device_type):
     voxel_size = [0.5, 0.5, 0.5]
     point_cloud_range = [0, -40, -3, 70.4, 40, 1]

@@ -29,8 +36,11 @@ def test_voxelization(device_type):
     hard_voxelization = Voxelization(voxel_size, point_cloud_range,
                                      max_num_points)

-    device = torch.device(
-        'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+    device = torch.device(device_type)

     # test hard_voxelization on cpu/gpu
     points = torch.tensor(points).contiguous().to(device)
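To round off the series, a hedged sketch of a numerical gradient check for the scatter op, mirroring the gradcheck calls of the first patch. The voxel sizes, ranges, shapes and tolerances below are illustrative assumptions rather than values taken from the series, and the sketch requires a CUDA device:

    import torch
    from torch.autograd import gradcheck
    from mmcv.ops import DynamicScatter

    # one 'mean' and one 'max' reduction instance (assumed geometry)
    dsmean = DynamicScatter([0.32, 0.32, 6],
                            [-74.88, -74.88, -2, 74.88, 74.88, 4], True)
    dsmax = DynamicScatter([0.32, 0.32, 6],
                           [-74.88, -74.88, -2, 74.88, 74.88, 4], False)

    feats = torch.rand(100, 4, dtype=torch.float32, device='cuda') * 100
    coors = torch.randint(
        low=-1, high=3, size=(100, 3), dtype=torch.int32, device='cuda')
    feats.requires_grad_()

    # loose eps/atol because the op runs in float32
    gradcheck(dsmean, (feats, coors), eps=1e-2, atol=1e-2, rtol=1e-5)
    gradcheck(dsmax, (feats, coors), eps=1e-2, atol=1e-2, rtol=1e-5)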