Skip to content

Commit

Permalink
temporarily fix issue #305
Browse files Browse the repository at this point in the history
  • Loading branch information
yezhen17 committed Feb 18, 2021
1 parent f29fb1b commit df57552
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions mmdet3d/ops/roiaware_pool3d/src/points_in_boxes_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,17 @@ __global__ void points_in_boxes_batch_kernel(int batch_size, int boxes_num,

void points_in_boxes_launcher(int batch_size, int boxes_num, int pts_num,
const float *boxes, const float *pts,
int *box_idx_of_points) {
int *box_idx_of_points, int device) {
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
// the bottom center, each box DO NOT overlaps params pts: (B, npoints, 3) [x,
// y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints), default
// -1
cudaError_t err;

cudaError_t err = cudaSetDevice(device);
if (cudaSuccess != err) {
fprintf(stderr, "Invalid CUDA device : %s\n", cudaGetErrorString(err));
exit(-1);
}

dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), batch_size);
dim3 threads(THREADS_PER_BLOCK);
Expand All @@ -131,11 +136,15 @@ void points_in_boxes_launcher(int batch_size, int boxes_num, int pts_num,

void points_in_boxes_batch_launcher(int batch_size, int boxes_num, int pts_num,
const float *boxes, const float *pts,
int *box_idx_of_points) {
int *box_idx_of_points, int device) {
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
// the bottom center, each box params pts: (B, npoints, 3) [x, y, z] in
// LiDAR coordinate params boxes_idx_of_points: (B, npoints), default -1
cudaError_t err;
cudaError_t err = cudaSetDevice(device);
if (cudaSuccess != err) {
fprintf(stderr, "Invalid CUDA device : %s\n", cudaGetErrorString(err));
exit(-1);
}

dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), batch_size);
dim3 threads(THREADS_PER_BLOCK);
Expand Down Expand Up @@ -172,8 +181,10 @@ int points_in_boxes_gpu(at::Tensor boxes_tensor, at::Tensor pts_tensor,
const float *pts = pts_tensor.data_ptr<float>();
int *box_idx_of_points = box_idx_of_points_tensor.data_ptr<int>();

int device = boxes_tensor.get_device();

points_in_boxes_launcher(batch_size, boxes_num, pts_num, boxes, pts,
box_idx_of_points);
box_idx_of_points, device);

return 1;
}
Expand All @@ -196,8 +207,10 @@ int points_in_boxes_batch(at::Tensor boxes_tensor, at::Tensor pts_tensor,
const float *pts = pts_tensor.data_ptr<float>();
int *box_idx_of_points = box_idx_of_points_tensor.data_ptr<int>();

int device = boxes_tensor.get_device();

points_in_boxes_batch_launcher(batch_size, boxes_num, pts_num, boxes, pts,
box_idx_of_points);
box_idx_of_points, device);

return 1;
}

0 comments on commit df57552

Please sign in to comment.