Skip to content

Commit

Permalink
Remove bbox overlap fp16
Browse files Browse the repository at this point in the history
  • Loading branch information
grimoire committed May 14, 2022
1 parent a3b4640 commit f1353d0
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 68 deletions.
54 changes: 0 additions & 54 deletions mmcv/ops/csrc/common/cuda/bbox_overlaps_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -88,58 +88,4 @@ __global__ void bbox_overlaps_cuda_kernel(const T* bbox1, const T* bbox2,
}
}

__device__ __forceinline__ __half __half_area(const __half x1, const __half y1,
const __half x2, const __half y2,
const __half offset) {
const __half half_w = __hadd(__hsub(x2, x1), offset);
const __half half_h = __hadd(__hsub(y2, y1), offset);
return __hmul(half_w, half_h);
}

__device__ __forceinline__ __half __half_max(const __half a, const __half b) {
return __hge(a, b) ? a : b;
}

__device__ __forceinline__ __half __half_min(const __half a, const __half b) {
return __hle(a, b) ? a : b;
}

// fp16 won't provide much increase when aligned==true. It is useful when
// aligned==false, which would give you ~40% bonus.
__device__ void bbox_overlaps_cuda_kernel_half(
const __half* bbox1, const __half* bbox2, __half* ious, const int num_bbox1,
const int num_bbox2, const int mode, const bool aligned, const int offset) {
const int num_output = aligned ? num_bbox1 : num_bbox1 * num_bbox2;
const __half h_offset = __int2half_rn(offset);
CUDA_1D_KERNEL_LOOP(index, num_output) {
const int b1 = aligned ? index : index / num_bbox2;
const int b2 = aligned ? index : index % num_bbox2;

const int base1 = b1 << 2;
__half b1_x1, b1_y1, b1_x2, b1_y2;
load_bbox<__half>(bbox1, base1, b1_x1, b1_y1, b1_x2, b1_y2);
const __half b1_area = __half_area(b1_x1, b1_y1, b1_x2, b1_y2, h_offset);

const int base2 = b2 << 2;
__half b2_x1, b2_y1, b2_x2, b2_y2;
load_bbox<__half>(bbox2, base2, b2_x1, b2_y1, b2_x2, b2_y2);
const __half b2_area = __half_area(b2_x1, b2_y1, b2_x2, b2_y2, h_offset);

const __half left = __half_max(b1_x1, b2_x1),
right = __half_min(b1_x2, b2_x2);
const __half top = __half_max(b1_y1, b2_y1),
bottom = __half_min(b1_y2, b2_y2);
const __half width =
__half_max(__hadd(__hsub(right, left), h_offset), __float2half(0.f));
const __half height =
__half_max(__hadd(__hsub(bottom, top), h_offset), __float2half(0.f));
const __half interS = __hmul(width, height);

const __half baseS = __half_max(
mode == 0 ? __hsub(__hadd(b1_area, b2_area), interS) : b1_area,
h_offset);
ious[index] = __hdiv(interS, baseS);
}
}

#endif // BBOX_OVERLAPS_CUDA_KERNEL_CUH
14 changes: 0 additions & 14 deletions mmcv/ops/csrc/pytorch/cuda/bbox_overlaps_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,6 @@
#include "bbox_overlaps_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"

// Disable fp16 on ROCm device
#ifndef HIP_DIFF
template <>
__global__ void bbox_overlaps_cuda_kernel<at::Half>(
const at::Half* bbox1, const at::Half* bbox2, at::Half* ious,
const int num_bbox1, const int num_bbox2, const int mode,
const bool aligned, const int offset) {
bbox_overlaps_cuda_kernel_half(reinterpret_cast<const __half*>(bbox1),
reinterpret_cast<const __half*>(bbox2),
reinterpret_cast<__half*>(ious), num_bbox1,
num_bbox2, mode, aligned, offset);
}
#endif // HIP_DIFF

void BBoxOverlapsCUDAKernelLauncher(const Tensor bboxes1, const Tensor bboxes2,
Tensor ious, const int mode,
const bool aligned, const int offset) {
Expand Down

0 comments on commit f1353d0

Please sign in to comment.