From b8cb09076d7a0e54f9074a1e0add8958a41b5f6f Mon Sep 17 00:00:00 2001 From: wangruohui <12756472+wangruohui@users.noreply.github.com> Date: Tue, 20 Jul 2021 00:43:10 +0800 Subject: [PATCH 1/2] use type long long and dynamic memory allocation --- mmdet3d/ops/iou3d/src/iou3d.cpp | 12 ++++++------ mmdet3d/ops/iou3d/src/iou3d_kernel.cu | 2 +- mmdet3d/ops/paconv/src/assign_score_withk_cuda.cu | 10 +++++----- mmdet3d/ops/voxel/src/voxelization_cpu.cpp | 4 +++- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/mmdet3d/ops/iou3d/src/iou3d.cpp b/mmdet3d/ops/iou3d/src/iou3d.cpp index 706cddbb80..f9f50ef076 100644 --- a/mmdet3d/ops/iou3d/src/iou3d.cpp +++ b/mmdet3d/ops/iou3d/src/iou3d.cpp @@ -103,7 +103,7 @@ int nms_gpu(at::Tensor boxes, at::Tensor keep, int boxes_num = boxes.size(0); const float *boxes_data = boxes.data_ptr(); - long *keep_data = keep.data_ptr(); + long long *keep_data = keep.data_ptr(); const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); @@ -124,8 +124,7 @@ int nms_gpu(at::Tensor boxes, at::Tensor keep, cudaFree(mask_data); - unsigned long long remv_cpu[col_blocks]; - memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long)); + unsigned long long* remv_cpu = new unsigned long long[col_blocks](); int num_to_keep = 0; @@ -141,6 +140,7 @@ int nms_gpu(at::Tensor boxes, at::Tensor keep, } } } + delete[] remv_cpu; if (cudaSuccess != cudaGetLastError()) printf("Error!\n"); return num_to_keep; @@ -157,7 +157,7 @@ int nms_normal_gpu(at::Tensor boxes, at::Tensor keep, int boxes_num = boxes.size(0); const float *boxes_data = boxes.data_ptr(); - long *keep_data = keep.data_ptr(); + long long *keep_data = keep.data_ptr(); const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); @@ -178,8 +178,7 @@ int nms_normal_gpu(at::Tensor boxes, at::Tensor keep, cudaFree(mask_data); - unsigned long long remv_cpu[col_blocks]; - memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long)); + unsigned long long* remv_cpu = new unsigned long long[col_blocks](); int num_to_keep = 0; @@ -195,6 +194,7 @@ int nms_normal_gpu(at::Tensor boxes, at::Tensor keep, } } } + delete[] remv_cpu; if (cudaSuccess != cudaGetLastError()) printf("Error!\n"); return num_to_keep; diff --git a/mmdet3d/ops/iou3d/src/iou3d_kernel.cu b/mmdet3d/ops/iou3d/src/iou3d_kernel.cu index fce3f78825..861aea3c5a 100644 --- a/mmdet3d/ops/iou3d/src/iou3d_kernel.cu +++ b/mmdet3d/ops/iou3d/src/iou3d_kernel.cu @@ -13,7 +13,7 @@ All Rights Reserved 2019-2020. //#define DEBUG const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8; -const float EPS = 1e-8; +__device__ const float EPS = 1e-8; struct Point { float x, y; __device__ Point() {} diff --git a/mmdet3d/ops/paconv/src/assign_score_withk_cuda.cu b/mmdet3d/ops/paconv/src/assign_score_withk_cuda.cu index 20f7d2cf99..af51421463 100644 --- a/mmdet3d/ops/paconv/src/assign_score_withk_cuda.cu +++ b/mmdet3d/ops/paconv/src/assign_score_withk_cuda.cu @@ -49,7 +49,7 @@ __global__ void assign_score_withk_forward_kernel(const int B, const int N0, con const float* points, const float* centers, const float* scores, - const long* knn_idx, + const long long* knn_idx, float* output) { // ----- parallel loop for B, N1, K and O --------- @@ -82,7 +82,7 @@ __global__ void assign_score_withk_backward_points_kernel(const int B, const int const int K, const int O, const int aggregate, const float* grad_out, const float* scores, - const long* knn_idx, + const long long* knn_idx, float* grad_points, float* grad_centers) { @@ -116,7 +116,7 @@ __global__ void assign_score_withk_backward_scores_kernel(const int B, const int const float* grad_out, const float* points, const float* centers, - const long* knn_idx, + const long long* knn_idx, float* grad_scores) { // ----- parallel loop for B, N, K, M --------- @@ -156,7 +156,7 @@ void assign_score_withk_forward_wrapper(int B, int N0, int N1, int M, int K, int const float* points_data = points.data_ptr(); const float* centers_data = centers.data_ptr(); const float* scores_data = scores.data_ptr(); - const long* knn_idx_data = knn_idx.data_ptr(); + const long long* knn_idx_data = knn_idx.data_ptr(); float* output_data = output.data_ptr(); dim3 blocks(DIVUP(B*O*N1*K, THREADS_PER_BLOCK)); @@ -191,7 +191,7 @@ void assign_score_withk_backward_wrapper(int B, int N0, int N1, int M, int K, in const float* points_data = points.data_ptr(); const float* centers_data = centers.data_ptr(); const float* scores_data = scores.data_ptr(); - const long* knn_idx_data = knn_idx.data_ptr(); + const long long* knn_idx_data = knn_idx.data_ptr(); float* grad_points_data = grad_points.data_ptr(); float* grad_centers_data = grad_centers.data_ptr(); float* grad_scores_data = grad_scores.data_ptr(); diff --git a/mmdet3d/ops/voxel/src/voxelization_cpu.cpp b/mmdet3d/ops/voxel/src/voxelization_cpu.cpp index c0dad46950..6bcec4019c 100644 --- a/mmdet3d/ops/voxel/src/voxelization_cpu.cpp +++ b/mmdet3d/ops/voxel/src/voxelization_cpu.cpp @@ -14,7 +14,8 @@ void dynamic_voxelize_kernel(const torch::TensorAccessor points, const int NDim) { const int ndim_minus_1 = NDim - 1; bool failed = false; - int coor[NDim]; + // int coor[NDim]; + int* coor = new int[NDim](); int c; for (int i = 0; i < num_points; ++i) { @@ -37,6 +38,7 @@ void dynamic_voxelize_kernel(const torch::TensorAccessor points, } } + delete[] coor; return; } From 6e5725f1700d98c91af427b67fd97d7bf088d796 Mon Sep 17 00:00:00 2001 From: wangruohui <12756472+wangruohui@users.noreply.github.com> Date: Tue, 20 Jul 2021 16:59:22 +0800 Subject: [PATCH 2/2] use int64_t instead of long long --- mmdet3d/ops/iou3d/src/iou3d.cpp | 9 +++++---- mmdet3d/ops/paconv/src/assign_score_withk_cuda.cu | 11 ++++++----- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/mmdet3d/ops/iou3d/src/iou3d.cpp b/mmdet3d/ops/iou3d/src/iou3d.cpp index f9f50ef076..25a5cd98f1 100644 --- a/mmdet3d/ops/iou3d/src/iou3d.cpp +++ b/mmdet3d/ops/iou3d/src/iou3d.cpp @@ -12,6 +12,7 @@ All Rights Reserved 2019-2020. #include #include +#include #include #define CHECK_CUDA(x) \ @@ -103,7 +104,7 @@ int nms_gpu(at::Tensor boxes, at::Tensor keep, int boxes_num = boxes.size(0); const float *boxes_data = boxes.data_ptr(); - long long *keep_data = keep.data_ptr(); + int64_t *keep_data = keep.data_ptr(); const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); @@ -124,7 +125,7 @@ int nms_gpu(at::Tensor boxes, at::Tensor keep, cudaFree(mask_data); - unsigned long long* remv_cpu = new unsigned long long[col_blocks](); + unsigned long long *remv_cpu = new unsigned long long[col_blocks](); int num_to_keep = 0; @@ -157,7 +158,7 @@ int nms_normal_gpu(at::Tensor boxes, at::Tensor keep, int boxes_num = boxes.size(0); const float *boxes_data = boxes.data_ptr(); - long long *keep_data = keep.data_ptr(); + int64_t *keep_data = keep.data_ptr(); const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); @@ -178,7 +179,7 @@ int nms_normal_gpu(at::Tensor boxes, at::Tensor keep, cudaFree(mask_data); - unsigned long long* remv_cpu = new unsigned long long[col_blocks](); + unsigned long long *remv_cpu = new unsigned long long[col_blocks](); int num_to_keep = 0; diff --git a/mmdet3d/ops/paconv/src/assign_score_withk_cuda.cu b/mmdet3d/ops/paconv/src/assign_score_withk_cuda.cu index af51421463..7ae56f24b2 100644 --- a/mmdet3d/ops/paconv/src/assign_score_withk_cuda.cu +++ b/mmdet3d/ops/paconv/src/assign_score_withk_cuda.cu @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -49,7 +50,7 @@ __global__ void assign_score_withk_forward_kernel(const int B, const int N0, con const float* points, const float* centers, const float* scores, - const long long* knn_idx, + const int64_t* knn_idx, float* output) { // ----- parallel loop for B, N1, K and O --------- @@ -82,7 +83,7 @@ __global__ void assign_score_withk_backward_points_kernel(const int B, const int const int K, const int O, const int aggregate, const float* grad_out, const float* scores, - const long long* knn_idx, + const int64_t* knn_idx, float* grad_points, float* grad_centers) { @@ -116,7 +117,7 @@ __global__ void assign_score_withk_backward_scores_kernel(const int B, const int const float* grad_out, const float* points, const float* centers, - const long long* knn_idx, + const int64_t* knn_idx, float* grad_scores) { // ----- parallel loop for B, N, K, M --------- @@ -156,7 +157,7 @@ void assign_score_withk_forward_wrapper(int B, int N0, int N1, int M, int K, int const float* points_data = points.data_ptr(); const float* centers_data = centers.data_ptr(); const float* scores_data = scores.data_ptr(); - const long long* knn_idx_data = knn_idx.data_ptr(); + const int64_t* knn_idx_data = knn_idx.data_ptr(); float* output_data = output.data_ptr(); dim3 blocks(DIVUP(B*O*N1*K, THREADS_PER_BLOCK)); @@ -191,7 +192,7 @@ void assign_score_withk_backward_wrapper(int B, int N0, int N1, int M, int K, in const float* points_data = points.data_ptr(); const float* centers_data = centers.data_ptr(); const float* scores_data = scores.data_ptr(); - const long long* knn_idx_data = knn_idx.data_ptr(); + const int64_t* knn_idx_data = knn_idx.data_ptr(); float* grad_points_data = grad_points.data_ptr(); float* grad_centers_data = grad_centers.data_ptr(); float* grad_scores_data = grad_scores.data_ptr();