Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix bbox overlap fp16 #1958

Merged
merged 2 commits into from
May 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mmcv/ops/csrc/common/cuda/bbox_overlaps_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ __global__ void bbox_overlaps_cuda_kernel(const T* bbox1, const T* bbox2,
}
}

#if __CUDA_ARCH__ >= 530
__device__ __forceinline__ __half __half_area(const __half x1, const __half y1,
const __half x2, const __half y2,
const __half offset) {
Expand Down Expand Up @@ -141,5 +142,6 @@ __device__ void bbox_overlaps_cuda_kernel_half(
ious[index] = __hdiv(interS, baseS);
}
}
#endif // __CUDA_ARCH__ >= 530

#endif // BBOX_OVERLAPS_CUDA_KERNEL_CUH
2 changes: 2 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/bbox_overlaps_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

// Disable fp16 on ROCm device
#ifndef HIP_DIFF
#if __CUDA_ARCH__ >= 530
template <>
__global__ void bbox_overlaps_cuda_kernel<at::Half>(
const at::Half* bbox1, const at::Half* bbox2, at::Half* ious,
Expand All @@ -14,6 +15,7 @@ __global__ void bbox_overlaps_cuda_kernel<at::Half>(
reinterpret_cast<__half*>(ious), num_bbox1,
num_bbox2, mode, aligned, offset);
}
#endif // __CUDA_ARCH__ >= 530
#endif // HIP_DIFF

void BBoxOverlapsCUDAKernelLauncher(const Tensor bboxes1, const Tensor bboxes2,
Expand Down