From 2b24198c86296d380cdbc5e7850bb3f83f5428dc Mon Sep 17 00:00:00 2001 From: Divyansh Jha Date: Mon, 17 May 2021 12:13:48 +0530 Subject: [PATCH] adds windows compilation support --- mmdet3d/ops/iou3d/src/iou3d.cpp | 8 ++++---- mmdet3d/ops/iou3d/src/iou3d_kernel.cu | 2 +- mmdet3d/ops/knn/src/knn.cpp | 4 ++-- mmdet3d/ops/knn/src/knn_cuda.cu | 14 +++++++------- mmdet3d/ops/voxel/src/voxelization_cpu.cpp | 2 +- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/mmdet3d/ops/iou3d/src/iou3d.cpp b/mmdet3d/ops/iou3d/src/iou3d.cpp index 706cddbb80..56d1987c63 100644 --- a/mmdet3d/ops/iou3d/src/iou3d.cpp +++ b/mmdet3d/ops/iou3d/src/iou3d.cpp @@ -103,7 +103,7 @@ int nms_gpu(at::Tensor boxes, at::Tensor keep, int boxes_num = boxes.size(0); const float *boxes_data = boxes.data_ptr(); - long *keep_data = keep.data_ptr(); + long long *keep_data = keep.data_ptr(); const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); @@ -124,7 +124,7 @@ int nms_gpu(at::Tensor boxes, at::Tensor keep, cudaFree(mask_data); - unsigned long long remv_cpu[col_blocks]; + unsigned long long *remv_cpu = new unsigned long long [col_blocks]; memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long)); int num_to_keep = 0; @@ -157,7 +157,7 @@ int nms_normal_gpu(at::Tensor boxes, at::Tensor keep, int boxes_num = boxes.size(0); const float *boxes_data = boxes.data_ptr(); - long *keep_data = keep.data_ptr(); + long long *keep_data = keep.data_ptr(); const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); @@ -178,7 +178,7 @@ int nms_normal_gpu(at::Tensor boxes, at::Tensor keep, cudaFree(mask_data); - unsigned long long remv_cpu[col_blocks]; + unsigned long long *remv_cpu = new unsigned long long [col_blocks]; memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long)); int num_to_keep = 0; diff --git a/mmdet3d/ops/iou3d/src/iou3d_kernel.cu b/mmdet3d/ops/iou3d/src/iou3d_kernel.cu index fce3f78825..861aea3c5a 100644 --- a/mmdet3d/ops/iou3d/src/iou3d_kernel.cu +++ b/mmdet3d/ops/iou3d/src/iou3d_kernel.cu @@ -13,7 +13,7 @@ All Rights Reserved 2019-2020. //#define DEBUG const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8; -const float EPS = 1e-8; +__device__ const float EPS = 1e-8; struct Point { float x, y; __device__ Point() {} diff --git a/mmdet3d/ops/knn/src/knn.cpp b/mmdet3d/ops/knn/src/knn.cpp index a86e13d0af..332106cd49 100644 --- a/mmdet3d/ops/knn/src/knn.cpp +++ b/mmdet3d/ops/knn/src/knn.cpp @@ -18,7 +18,7 @@ void knn_kernels_launcher( int dim, int k, float* dist_dev, - long* ind_dev, + long long* ind_dev, cudaStream_t stream ); @@ -39,7 +39,7 @@ void knn_wrapper( int dim = query.size(0); auto dist = at::empty({ref_nb, query_nb}, query.options().dtype(at::kFloat)); float * dist_dev = dist.data_ptr(); - long * ind_dev = ind.data_ptr(); + long long * ind_dev = ind.data_ptr(); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); diff --git a/mmdet3d/ops/knn/src/knn_cuda.cu b/mmdet3d/ops/knn/src/knn_cuda.cu index 074ce0dfb5..f59844f4db 100644 --- a/mmdet3d/ops/knn/src/knn_cuda.cu +++ b/mmdet3d/ops/knn/src/knn_cuda.cu @@ -101,14 +101,14 @@ __global__ void cuComputeDistanceGlobal(const float* A, int wA, * @param height height of the distance matrix and of the index matrix * @param k number of neighbors to consider */ -__global__ void cuInsertionSort(float *dist, long *ind, int width, int height, int k){ +__global__ void cuInsertionSort(float *dist, long long *ind, int width, int height, int k){ // Variables int l, i, j; float *p_dist; - long *p_ind; + long long *p_ind; float curr_dist, max_dist; - long curr_row, max_row; + long long curr_row, max_row; unsigned int xIndex = blockIdx.x * blockDim.x + threadIdx.x; if (xIndex points, const int NDim) { const int ndim_minus_1 = NDim - 1; bool failed = false; - int coor[NDim]; + int *coor = new int[NDim]; int c; for (int i = 0; i < num_points; ++i) {