diff --git a/mmcv/ops/csrc/common/cuda/nms_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/nms_cuda_kernel.cuh index 3f95114994..0a5c2505f5 100644 --- a/mmcv/ops/csrc/common/cuda/nms_cuda_kernel.cuh +++ b/mmcv/ops/csrc/common/cuda/nms_cuda_kernel.cuh @@ -72,4 +72,46 @@ __global__ void nms_cuda(const int n_boxes, const float iou_threshold, } } } + +__global__ void gather_keep_from_mask(bool *keep, + const unsigned long long *dev_mask, + const int n_boxes) { + const int col_blocks = (n_boxes + threadsPerBlock - 1) / threadsPerBlock; + const int tid = threadIdx.x; + + // mark the bboxes which have been removed. + extern __shared__ unsigned long long removed[]; + + // initialize removed. + for (int i = tid; i < col_blocks; i += blockDim.x) { + removed[i] = 0; + } + __syncthreads(); + + for (int nblock = 0; nblock < col_blocks; ++nblock) { + auto removed_val = removed[nblock]; + __syncthreads(); + const int i_offset = nblock * threadsPerBlock; +#pragma unroll + for (int inblock = 0; inblock < threadsPerBlock; ++inblock) { + const int i = i_offset + inblock; + if (i >= n_boxes) break; + // select a candidate, check if it should be kept. + if (!(removed_val & (1ULL << inblock))) { + if (tid == 0) { + // mark the output. + keep[i] = true; + } + auto p = dev_mask + i * col_blocks; + // remove all bboxes which overlap the candidate. 
+ for (int j = tid; j < col_blocks; j += blockDim.x) { + if (j >= nblock) removed[j] |= p[j]; + } + __syncthreads(); + removed_val = removed[nblock]; + } + } + } +} + #endif // NMS_CUDA_KERNEL_CUH diff --git a/mmcv/ops/csrc/pytorch/cuda/nms_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/nms_cuda.cu index 8798ec1740..1b87e0fa75 100644 --- a/mmcv/ops/csrc/pytorch/cuda/nms_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/nms_cuda.cu @@ -24,31 +24,13 @@ Tensor NMSCUDAKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold, boxes_num, iou_threshold, offset, boxes_sorted.data_ptr(), (unsigned long long*)mask.data_ptr()); - at::Tensor mask_cpu = mask.to(at::kCPU); - unsigned long long* mask_host = - (unsigned long long*)mask_cpu.data_ptr(); - - std::vector remv(col_blocks); - memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks); - - at::Tensor keep_t = - at::zeros({boxes_num}, boxes.options().dtype(at::kBool).device(at::kCPU)); - bool* keep = keep_t.data_ptr(); - - for (int i = 0; i < boxes_num; i++) { - int nblock = i / threadsPerBlock; - int inblock = i % threadsPerBlock; - - if (!(remv[nblock] & (1ULL << inblock))) { - keep[i] = true; - // set every overlap box with bit 1 in remv - unsigned long long* p = mask_host + i * col_blocks; - for (int j = nblock; j < col_blocks; j++) { - remv[j] |= p[j]; - } - } - } - + // Filter the boxes which should be kept. + at::Tensor keep_t = at::zeros( + {boxes_num}, boxes.options().dtype(at::kBool).device(at::kCUDA)); + gather_keep_from_mask<<<1, min(col_blocks, THREADS_PER_BLOCK), + col_blocks * sizeof(unsigned long long), stream>>>( + keep_t.data_ptr(), (unsigned long long*)mask.data_ptr(), + boxes_num); AT_CUDA_CHECK(cudaGetLastError()); - return order_t.masked_select(keep_t.to(at::kCUDA)); + return order_t.masked_select(keep_t); }