Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NMS with CUDA only #1824

Merged
merged 5 commits into from
Apr 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions mmcv/ops/csrc/common/cuda/nms_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,46 @@ __global__ void nms_cuda(const int n_boxes, const float iou_threshold,
}
}
}

__global__ void gather_keep_from_mask(bool *keep,
const unsigned long long *dev_mask,
const int n_boxes) {
const int col_blocks = (n_boxes + threadsPerBlock - 1) / threadsPerBlock;
const int tid = threadIdx.x;

// mark the bboxes which have been removed.
extern __shared__ unsigned long long removed[];

// initialize removed.
for (int i = tid; i < col_blocks; i += blockDim.x) {
removed[i] = 0;
}
__syncthreads();

for (int nblock = 0; nblock < col_blocks; ++nblock) {
auto removed_val = removed[nblock];
__syncthreads();
const int i_offset = nblock * threadsPerBlock;
#pragma unroll
for (int inblock = 0; inblock < threadsPerBlock; ++inblock) {
const int i = i_offset + inblock;
if (i >= n_boxes) break;
// select a candidate, check if it should kept.
if (!(removed_val & (1ULL << inblock))) {
if (tid == 0) {
// mark the output.
keep[i] = true;
}
auto p = dev_mask + i * col_blocks;
// remove all bboxes which overlap the candidate.
for (int j = tid; j < col_blocks; j += blockDim.x) {
if (j >= nblock) removed[j] |= p[j];
}
__syncthreads();
removed_val = removed[nblock];
}
}
}
}

#endif // NMS_CUDA_KERNEL_CUH
34 changes: 8 additions & 26 deletions mmcv/ops/csrc/pytorch/cuda/nms_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,31 +24,13 @@ Tensor NMSCUDAKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold,
boxes_num, iou_threshold, offset, boxes_sorted.data_ptr<float>(),
(unsigned long long*)mask.data_ptr<int64_t>());

at::Tensor mask_cpu = mask.to(at::kCPU);
unsigned long long* mask_host =
(unsigned long long*)mask_cpu.data_ptr<int64_t>();

std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);

at::Tensor keep_t =
at::zeros({boxes_num}, boxes.options().dtype(at::kBool).device(at::kCPU));
bool* keep = keep_t.data_ptr<bool>();

for (int i = 0; i < boxes_num; i++) {
int nblock = i / threadsPerBlock;
int inblock = i % threadsPerBlock;

if (!(remv[nblock] & (1ULL << inblock))) {
keep[i] = true;
// set every overlap box with bit 1 in remv
unsigned long long* p = mask_host + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}

// Filter the boxes which should be kept.
at::Tensor keep_t = at::zeros(
ZwwWayne marked this conversation as resolved.
Show resolved Hide resolved
{boxes_num}, boxes.options().dtype(at::kBool).device(at::kCUDA));
gather_keep_from_mask<<<1, min(col_blocks, THREADS_PER_BLOCK),
col_blocks * sizeof(unsigned long long), stream>>>(
keep_t.data_ptr<bool>(), (unsigned long long*)mask.data_ptr<int64_t>(),
boxes_num);
AT_CUDA_CHECK(cudaGetLastError());
return order_t.masked_select(keep_t.to(at::kCUDA));
return order_t.masked_select(keep_t);
}