Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Adds windows compilation support #551

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mmdet3d/ops/iou3d/src/iou3d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ int nms_gpu(at::Tensor boxes, at::Tensor keep,

int boxes_num = boxes.size(0);
const float *boxes_data = boxes.data_ptr<float>();
long *keep_data = keep.data_ptr<long>();
long long *keep_data = keep.data_ptr<long long>();

const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);

Expand All @@ -124,7 +124,7 @@ int nms_gpu(at::Tensor boxes, at::Tensor keep,

cudaFree(mask_data);

unsigned long long remv_cpu[col_blocks];
unsigned long long *remv_cpu = new unsigned long long [col_blocks];
memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));

int num_to_keep = 0;
Expand Down Expand Up @@ -157,7 +157,7 @@ int nms_normal_gpu(at::Tensor boxes, at::Tensor keep,

int boxes_num = boxes.size(0);
const float *boxes_data = boxes.data_ptr<float>();
long *keep_data = keep.data_ptr<long>();
long long *keep_data = keep.data_ptr<long long>();

const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);

Expand All @@ -178,7 +178,7 @@ int nms_normal_gpu(at::Tensor boxes, at::Tensor keep,

cudaFree(mask_data);

unsigned long long remv_cpu[col_blocks];
unsigned long long *remv_cpu = new unsigned long long [col_blocks];
memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));

int num_to_keep = 0;
Expand Down
2 changes: 1 addition & 1 deletion mmdet3d/ops/iou3d/src/iou3d_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ All Rights Reserved 2019-2020.

//#define DEBUG
const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8;
const float EPS = 1e-8;
__device__ const float EPS = 1e-8;
struct Point {
float x, y;
__device__ Point() {}
Expand Down
4 changes: 2 additions & 2 deletions mmdet3d/ops/knn/src/knn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void knn_kernels_launcher(
int dim,
int k,
float* dist_dev,
long* ind_dev,
long long* ind_dev,
cudaStream_t stream
);

Expand All @@ -39,7 +39,7 @@ void knn_wrapper(
int dim = query.size(0);
auto dist = at::empty({ref_nb, query_nb}, query.options().dtype(at::kFloat));
float * dist_dev = dist.data_ptr<float>();
long * ind_dev = ind.data_ptr<long>();
long long * ind_dev = ind.data_ptr<long long>();

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

Expand Down
14 changes: 7 additions & 7 deletions mmdet3d/ops/knn/src/knn_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,14 @@ __global__ void cuComputeDistanceGlobal(const float* A, int wA,
* @param height height of the distance matrix and of the index matrix
* @param k number of neighbors to consider
*/
__global__ void cuInsertionSort(float *dist, long *ind, int width, int height, int k){
__global__ void cuInsertionSort(float *dist, long long *ind, int width, int height, int k){

// Variables
int l, i, j;
float *p_dist;
long *p_ind;
long long *p_ind;
float curr_dist, max_dist;
long curr_row, max_row;
long long curr_row, max_row;
unsigned int xIndex = blockIdx.x * blockDim.x + threadIdx.x;
if (xIndex<width){
// Pointer shift, initialization, and max value
Expand Down Expand Up @@ -182,16 +182,16 @@ __global__ void cuParallelSqrt(float *dist, int width, int k){
}


void debug(float * dist_dev, long * ind_dev, const int query_nb, const int k){
void debug(float * dist_dev, long long * ind_dev, const int query_nb, const int k){
float* dist_host = new float[query_nb * k];
long* idx_host = new long[query_nb * k];
long long* idx_host = new long long[query_nb * k];

// Memory copy of output from device to host
cudaMemcpy(dist_host, dist_dev,
query_nb * k * sizeof(float), cudaMemcpyDeviceToHost);

cudaMemcpy(idx_host, ind_dev,
query_nb * k * sizeof(long), cudaMemcpyDeviceToHost);
query_nb * k * sizeof(long long), cudaMemcpyDeviceToHost);

int i, j;
for(i = 0; i < k; i++){
Expand Down Expand Up @@ -229,7 +229,7 @@ void debug(float * dist_dev, long * ind_dev, const int query_nb, const int k){
*
*/
void knn_kernels_launcher(const float* ref_dev, int ref_nb, const float* query_dev, int query_nb,
int dim, int k, float* dist_dev, long* ind_dev, cudaStream_t stream){
int dim, int k, float* dist_dev, long long* ind_dev, cudaStream_t stream){

// Grids ans threads
dim3 g_16x16(query_nb / BLOCK_DIM, ref_nb / BLOCK_DIM, 1);
Expand Down
2 changes: 1 addition & 1 deletion mmdet3d/ops/voxel/src/voxelization_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ void dynamic_voxelize_kernel(const torch::TensorAccessor<T, 2> points,
const int NDim) {
const int ndim_minus_1 = NDim - 1;
bool failed = false;
int coor[NDim];
int *coor = new int[NDim];
int c;

for (int i = 0; i < num_points; ++i) {
Expand Down